jaxopt.AndersonWrapper
- class jaxopt.AndersonWrapper(solver, history_size=5, mixing_frequency=None, beta=1.0, ridge=1e-05, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')[source]
Wrapper for accelerating JAXopt solvers.
Note that the internal solver state can be accessed via the aux attribute of AndersonState. A usage sketch follows the __init__ entry below.
- Parameters
solver (IterativeSolver) –
history_size (int) –
mixing_frequency (Optional[int]) –
beta (float) –
ridge (float) –
verbose (bool) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- solver
solver object to accelerate. Must provide init() and update() methods.
- Type
jaxopt._src.base.IterativeSolver
- history_size
size of the history. Affects memory cost. (default: 5).
- Type
int
- mixing_frequency
frequency of Anderson updates (default: history_size). Only one out of every mixing_frequency updates uses Anderson acceleration, while the others use regular fixed-point iterations.
- Type
int
- beta
momentum in Anderson updates. (default: 1).
- Type
float
- ridge
ridge regularization in the solver. Consider increasing this value if the solver returns NaN.
- Type
float
- verbose
whether to print the error at each iteration. Warning: verbose=True automatically disables jit.
- Type
bool
- implicit_diff
whether to enable implicit diff or autodiff of unrolled iterations.
- Type
bool
- implicit_diff_solve
the linear system solver to use.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
- __init__(solver, history_size=5, mixing_frequency=None, beta=1.0, ridge=1e-05, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
- Parameters
solver (IterativeSolver) –
history_size (int) –
mixing_frequency (Optional[int]) –
beta (float) –
ridge (float) –
verbose (bool) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- Return type
None
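A minimal usage sketch (not part of the original reference; the least_squares objective, data shapes, and hyper-parameter values below are illustrative assumptions):

```python
import jax.numpy as jnp
from jaxopt import AndersonWrapper, GradientDescent

# Illustrative least-squares objective; X, y and all values below are
# made-up placeholders, not taken from the documentation.
def least_squares(params, X, y):
    return jnp.mean((X @ params - y) ** 2)

X = jnp.ones((10, 3))
y = jnp.zeros(10)

# Wrap a gradient-descent solver: Anderson acceleration is applied on
# top of its update() steps, keeping a 5-iterate history.
solver = GradientDescent(fun=least_squares, maxiter=200)
accelerated = AndersonWrapper(solver, history_size=5, beta=1.0, ridge=1e-5)
```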
Methods
- __init__(solver[, history_size, ...])
- attribute_names()
- attribute_values()
- init_state(init_params, *args, **kwargs): return type AndersonWrapperState.
- l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
- optimality_fun(params, *args, **kwargs): Optimality function mapping compatible with @custom_root.
- run(init_params, *args, **kwargs): Runs the optimization loop.
- update(params, state, *args, **kwargs): Performs one step of Anderson acceleration over the internal solver update.
Attributes
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
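For instance, continuing the hypothetical sketch above (with params being the solution returned by run(), shown further below):

```python
# At convergence the L2 optimality error should be close to zero;
# extra arguments are forwarded as in run() and update().
err = accelerated.l2_optimality_error(params, X, y)
```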
- optimality_fun(params, *args, **kwargs)[source]
Optimality function mapping compatible with @custom_root.
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
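Continuing the hypothetical sketch above, positional arguments after init_params are forwarded through the wrapped solver, and for a JAXopt solver they reach the objective:

```python
# init_params is a pytree (here a plain array); X and y end up as the
# arguments of least_squares through the wrapped GradientDescent.
params, state = accelerated.run(jnp.zeros(3), X, y)
# state is an AndersonWrapperState; state.solver_state holds the
# wrapped solver's own state (see update() below).
```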
- update(params, state, *args, **kwargs)[source]
Perform one step of Anderson acceleration over the internal solver update.
The reset_state attribute is used to update the internal solver state after the Anderson step.
- Parameters
params – parameters optimized by the solver. Only its pytree structure matters (its content is unused).
state – AndersonWrapperState. Crucially, state.params_history and state.residuals_history are the sequences used to generate the next iterate. Note: state.solver_state is the internal solver state.
*args – additional positional arguments passed to the update method of the internal solver. Note: these are sometimes hyper-parameters of the solver, but if the solver is a JAXopt solver they are forwarded to the underlying function being optimized.
**kwargs – additional keyword arguments passed to the update method of the internal solver, forwarded in the same way as *args.
- Return type
OptStep
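The loop can also be driven by hand. A minimal sketch continuing the hypothetical example above (the iteration count is an arbitrary assumption):

```python
params = jnp.zeros(3)
state = accelerated.init_state(params, X, y)
for _ in range(100):
    # Each call performs one step of the wrapped solver, applying
    # Anderson extrapolation and refreshing params_history and
    # residuals_history in the returned state.
    params, state = accelerated.update(params, state, X, y)
```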