jaxopt.AndersonWrapper
- class jaxopt.AndersonWrapper(solver, history_size=5, mixing_frequency=None, beta=1.0, ridge=1e-05, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
- Wrapper for accelerating JAXopt solvers with Anderson acceleration.
- Note that the internal solver state can be accessed via the `aux` attribute of `AndersonState`.
- Parameters
- solver (IterativeSolver)
- history_size (int)
- mixing_frequency (int)
- beta (float)
- ridge (float)
- verbose (bool)
- implicit_diff (bool)
- implicit_diff_solve (Optional[Callable])
- jit (Union[str, bool])
- unroll (Union[str, bool])
 
 - solver
- Solver object to accelerate. Must expose `init_state()` and `update()` methods.
- Type
- jaxopt._src.base.IterativeSolver 
 
 - history_size
- Number of past iterates kept in the history. Larger values increase memory cost (default: 5).
- Type
- int 
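To see why `history_size` drives memory cost, here is a hypothetical NumPy sketch of a fixed-size history buffer. It is not JAXopt's internal layout; it only assumes that one copy of the parameters and one residual are stored per history slot, with the oldest slot overwritten once the buffer is full.

```python
import numpy as np

class HistoryBuffer:
    """Hypothetical fixed-size history (ring buffer) of iterates and residuals."""

    def __init__(self, history_size: int, dim: int, dtype=np.float32):
        self.params = np.zeros((history_size, dim), dtype=dtype)
        self.residuals = np.zeros((history_size, dim), dtype=dtype)
        self.num_pushed = 0

    def push(self, params: np.ndarray, residual: np.ndarray) -> None:
        # Once the buffer is full, the oldest slot is overwritten.
        slot = self.num_pushed % self.params.shape[0]
        self.params[slot] = params
        self.residuals[slot] = residual
        self.num_pushed += 1

    def nbytes(self) -> int:
        # Storage is 2 * history_size * dim entries.
        return self.params.nbytes + self.residuals.nbytes

buf = HistoryBuffer(history_size=5, dim=1000)
print(buf.nbytes())  # 40000: 2 buffers * 5 slots * 1000 float32 values * 4 bytes

for k in range(7):
    buf.push(np.full(1000, k), np.zeros(1000))
print(buf.params[:, 0])  # [5. 6. 2. 3. 4.]
```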
 
 - mixing_frequency
- Frequency of Anderson updates (default: `history_size`). Only one update out of every `mixing_frequency` uses Anderson mixing; the other updates are regular fixed-point iterations.
- Type
- int 
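The sketch below illustrates which iterations would use Anderson mixing. The default `mixing_frequency = history_size` comes from the description above; the warm-up condition (no mixing until the history has been filled once) is an assumption, not something this page states, and `uses_anderson` is a hypothetical helper.

```python
from typing import Optional

def uses_anderson(iter_num: int, history_size: int,
                  mixing_frequency: Optional[int] = None) -> bool:
    """Hypothetical schedule: mix on one step out of every `mixing_frequency`."""
    if mixing_frequency is None:
        mixing_frequency = history_size  # documented default
    # Assumed warm-up: wait until the history buffer has been filled.
    return iter_num >= history_size and iter_num % mixing_frequency == 0

print([k for k in range(16) if uses_anderson(k, history_size=5)])
# [5, 10, 15]
print([k for k in range(8) if uses_anderson(k, history_size=5, mixing_frequency=1)])
# [5, 6, 7]
```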
 
 - beta
- Momentum in Anderson updates (default: 1.0).
- Type
- float 
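A sketch of how a momentum coefficient like `beta` typically enters an Anderson step, assuming the common form x_next = Σ α_i p_i + β Σ α_i r_i, with residuals r_i = T(p_i) - p_i and weights α_i summing to 1. The helper `anderson_extrapolate` and the 1-D map T(p) = 0.5 p + 1 are hypothetical.

```python
import numpy as np

def anderson_extrapolate(params_history, residuals_history, alphas, beta=1.0):
    """Hypothetical Anderson extrapolation step.

    params_history, residuals_history: arrays of shape (m, d), one row per slot.
    alphas: mixing weights of shape (m,), assumed to sum to 1.
    """
    p_bar = alphas @ params_history      # weighted average of past iterates
    r_bar = alphas @ residuals_history   # weighted average of their residuals
    return p_bar + beta * r_bar

def T(p):
    # Hypothetical contraction with fixed point p* = 2.
    return 0.5 * p + 1.0

P = np.array([[0.0], [1.0]])
R = T(P) - P  # [[1.0], [0.5]]

# Weights that cancel the residuals land exactly on the fixed point.
print(anderson_extrapolate(P, R, np.array([-1.0, 2.0])))            # [2.]
# Equal weights: beta = 1 gives sum_i alpha_i * T(p_i); beta < 1 damps the step.
print(anderson_extrapolate(P, R, np.array([0.5, 0.5]), beta=1.0))  # [1.25]
print(anderson_extrapolate(P, R, np.array([0.5, 0.5]), beta=0.5))  # [0.875]
```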
 
 - ridge
- Ridge regularization of the linear system solved for the Anderson mixing weights (default: 1e-5). Consider increasing this value if the solver returns `NaN`.
- Type
- float 
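A NumPy sketch of where a ridge term of this kind enters, assuming the mixing weights solve min ||Σ α_i r_i||² + ridge·||α||² subject to Σ α_i = 1 through its KKT system; `anderson_coefficients` is a hypothetical helper. With linearly dependent residuals the unregularized system is singular, and the ridge term keeps it solvable.

```python
import numpy as np

def anderson_coefficients(residuals, ridge=1e-5):
    """Hypothetical helper: mixing weights for one Anderson step.

    residuals: array of shape (m, d), one flattened residual per history slot.
    Solves  min_a ||residuals.T @ a||^2 + ridge * ||a||^2  s.t.  sum(a) = 1.
    """
    m = residuals.shape[0]
    gram = residuals @ residuals.T + ridge * np.eye(m)  # regularized Gram matrix
    # KKT system of the equality-constrained quadratic program:
    #   [0  1^T] [nu]   [1]
    #   [1  G  ] [a ] = [0]
    kkt = np.zeros((m + 1, m + 1))
    kkt[0, 1:] = 1.0
    kkt[1:, 0] = 1.0
    kkt[1:, 1:] = gram
    rhs = np.zeros(m + 1)
    rhs[0] = 1.0
    solution = np.linalg.solve(kkt, rhs)
    return solution[1:]

# Two identical residuals make the Gram matrix rank-deficient; with ridge=0
# this KKT system would be singular, but the ridge term keeps it solvable.
R = np.array([[1.0, 2.0], [1.0, 2.0]])
print(anderson_coefficients(R, ridge=1e-5))  # approximately [0.5, 0.5]
```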
 
 - verbose
- Whether to print the error at every iteration. Warning: `verbose=True` automatically disables jit.
- Type
- bool 
 
 - implicit_diff
- Whether to use implicit differentiation (True) or autodiff of the unrolled iterations (False).
- Type
- bool 
 
 - implicit_diff_solve
- The linear system solver to use for implicit differentiation.
- Type
- Optional[Callable] 
 
 - jit
- Whether to JIT-compile the optimization loop (default: “auto”).
- Type
- Union[str, bool] 
 
 - unroll
- Whether to unroll the optimization loop (default: “auto”).
- Type
- Union[str, bool] 
 
 - __init__(solver, history_size=5, mixing_frequency=None, beta=1.0, ridge=1e-05, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
- Parameters
- solver (IterativeSolver)
- history_size (int)
- mixing_frequency (Optional[int])
- beta (float)
- ridge (float)
- verbose (bool)
- implicit_diff (bool)
- implicit_diff_solve (Optional[Callable])
- jit (Union[str, bool])
- unroll (Union[str, bool])
 
- Return type
- None 
 
 - Methods
- __init__(solver[, history_size, ...])
- attribute_names()
- attribute_values()
- init_state(init_params, *args, **kwargs) – return type: AndersonWrapperState
- l2_optimality_error(params, *args, **kwargs) – Computes the L2 optimality error.
- optimality_fun(params, *args, **kwargs) – Optimality function mapping compatible with `@custom_root`.
- run(init_params, *args, **kwargs) – Runs the optimization loop.
- update(params, state, *args, **kwargs) – Perform one step of Anderson acceleration over the internal solver update.

 - l2_optimality_error(params, *args, **kwargs)
- Computes the L2 optimality error. 
 - optimality_fun(params, *args, **kwargs)
- Optimality function mapping compatible with `@custom_root`.
 - run(init_params, *args, **kwargs)
- Runs the optimization loop.
- Parameters
- init_params (Any) – pytree containing the initial parameters.
- *args – additional positional arguments to be passed to the update method. 
- **kwargs – additional keyword arguments to be passed to the update method. 
 
- Return type
- OptStep
- Returns
- (params, state) 
 
 - update(params, state, *args, **kwargs)
- Perform one step of Anderson acceleration over the internal solver update.
- The `reset_state` attribute is used to update the internal solver state after the Anderson step.
- Parameters
- params – parameters optimized by solver. Only its pytree structure matters (content unused). 
- state – AndersonWrapperState. Crucially, `state.params_history` and `state.residuals_history` are the sequences used to generate the next iterate. Note: `state.solver_state` is the internal solver state.
- args – additional positional arguments passed to the `update` method of the internal solver. Note: these are sometimes hyperparameters of the solver, but if the solver is a JAXopt solver they are forwarded to the underlying function being optimized.
- kwargs – additional keyword arguments passed to the `update` method of the internal solver. Note: these are sometimes hyperparameters of the solver, but if the solver is a JAXopt solver they are forwarded to the underlying function being optimized.
 
- Return type
- OptStep