jaxopt.AndersonAcceleration
- class jaxopt.AndersonAcceleration(fixed_point_fun, history_size=5, mixing_frequency=1, beta=1.0, maxiter=100, tol=1e-05, ridge=1e-05, has_aux=False, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Anderson acceleration for solving fixed-point problems x = fixed_point_fun(x, *args, **kwargs).
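For reference, the update below is the standard regularized form of Anderson acceleration (see the Pollock and Rebholz reference), written with m = history_size, β = beta, λ = ridge and T = fixed_point_fun. It is a sketch of the technique; that jaxopt combines these parameters in exactly this way is an assumption, not something stated on this page. All sums run over i = k−m+1, …, k.

```latex
\begin{aligned}
r_i &= T(x_i) - x_i, \qquad i = k-m+1, \dots, k, \\
\alpha^\star &= \operatorname*{arg\,min}_{\alpha \in \mathbb{R}^m,\; \sum_i \alpha_i = 1}
  \Big\| \sum_{i} \alpha_i r_i \Big\|_2^2 + \lambda \, \|\alpha\|_2^2, \\
x_{k+1} &= \sum_{i} \alpha^\star_i x_i + \beta \sum_{i} \alpha^\star_i r_i .
\end{aligned}
```

With β = 1 the last line reduces to x_{k+1} = Σ α*_i T(x_i). With m = 1 the constraint forces α* = 1, and the step becomes the damped iteration x_{k+1} = (1 − β) x_k + β T(x_k). On iterations that are not Anderson steps (see mixing_frequency), the plain update x_{k+1} = T(x_k) is used instead.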
- Parameters
fixed_point_fun (Callable)
history_size (int)
mixing_frequency (int)
beta (float)
maxiter (int)
tol (float)
ridge (float)
has_aux (bool)
verbose (bool)
implicit_diff (bool)
implicit_diff_solve (Optional[Callable])
jit (Union[str, bool])
unroll (Union[str, bool])
- fixed_point_fun
a function fixed_point_fun(x, *args, **kwargs) returning a pytree with the same structure and type as x. See the reference below for conditions that the function must fulfill in order to guarantee convergence. In particular, if the conditions of the Banach fixed-point theorem hold (i.e., fixed_point_fun is a contraction), Anderson acceleration will converge.
- Type
Callable
- history_size
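size of the history, i.e., the number of past iterates and their residuals used in each Anderson update. Larger values increase memory cost.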
- Type
int
- mixing_frequency
frequency of Anderson updates (default: 1). Only one in every mixing_frequency iterations uses an Anderson update; the other iterations use regular fixed-point iterations.
- Type
int
- beta
momentum in Anderson updates (default: 1.0).
- Type
float
- maxiter
maximum number of iterations.
- Type
int
- tol
tolerance (stopping criterion).
- Type
float
- ridge
ridge regularization in the solver that computes the Anderson mixing coefficients. Consider increasing this value if the solver returns NaN.
- Type
float
- has_aux
whether fixed_point_fun returns additional data alongside its output, as a pair (output, aux) (default: False). This additional data is not taken into account when computing the fixed point. The solver returns the aux associated with the last iterate (i.e., the fixed point).
- Type
bool
- verbose
whether to print the error at every iteration. Warning: verbose=True automatically disables jit.
- Type
bool
- implicit_diff
whether to differentiate the solution with implicit differentiation (True) or with autodiff through the unrolled iterations (False).
- Type
bool
- implicit_diff_solve
the linear system solver used for implicit differentiation.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
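The parameters above combine roughly as follows. This is a minimal NumPy sketch for 1-D arrays, not jaxopt's implementation: anderson_solve is a hypothetical helper, and both the schedule (plain fixed-point steps until history_size iterates are stored, then an Anderson step on every mixing_frequency-th iteration) and the residual-norm stopping rule are assumptions made for illustration.

```python
import numpy as np


def anderson_solve(T, x0, history_size=5, mixing_frequency=1, beta=1.0,
                   maxiter=100, tol=1e-5, ridge=1e-5):
    """Anderson acceleration for x = T(x) on 1-D NumPy arrays (sketch).

    Returns (x, number_of_iterations, residual_norm).
    """
    x = np.asarray(x0, dtype=float)
    xs, rs = [], []  # most recent iterates and their residuals
    for k in range(maxiter):
        r = T(x) - x                      # residual of the current iterate
        err = np.linalg.norm(r)
        if err < tol:                     # stop once the residual norm is below tol
            return x, k, err
        xs.append(x)
        rs.append(r)
        if len(xs) > history_size:        # keep at most `history_size` entries
            xs.pop(0)
            rs.pop(0)
        # Assumption: plain steps until the history is full, then an Anderson
        # step on every `mixing_frequency`-th iteration.
        if len(xs) == history_size and k % mixing_frequency == 0:
            X, R = np.stack(xs), np.stack(rs)     # both of shape (m, n)
            m = len(xs)
            G = R @ R.T + ridge * np.eye(m)       # regularized residual Gram matrix
            # KKT system of: minimize a^T G a subject to sum(a) == 1.
            H = np.zeros((m + 1, m + 1))
            H[0, 1:] = 1.0
            H[1:, 0] = 1.0
            H[1:, 1:] = G
            rhs = np.zeros(m + 1)
            rhs[0] = 1.0
            alpha = np.linalg.solve(H, rhs)[1:]   # drop the Lagrange multiplier
            x = alpha @ X + beta * (alpha @ R)    # Anderson (mixed) step
        else:
            x = x + r                             # plain step, i.e. x <- T(x)
    return x, maxiter, np.linalg.norm(T(x) - x)
```

For example, anderson_solve(np.cos, np.array([1.0])) approaches the fixed point of cos near 0.739. Solving the small KKT system directly keeps the sketch short; because ridge > 0 makes the Gram matrix positive definite, that system stays nonsingular even when residuals are nearly collinear, which is consistent with the advice to increase ridge when NaNs appear.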
References
Pollock, Sara, and Leo Rebholz. “Anderson acceleration for contractive and noncontractive operators.” arXiv preprint arXiv:1909.04638 (2019).
Methods
__init__(fixed_point_fun[, history_size, ...])
attribute_names()
attribute_values()
init_state(init_params, *args, **kwargs) – Initialize the solver state.
l2_optimality_error(params, *args, **kwargs) – Computes the L2 optimality error.
optimality_fun(params, *args, **kwargs) – Optimality function mapping compatible with @custom_root.
run(init_params, *args, **kwargs) – Runs the optimization loop.
update(params, state, *args, **kwargs) – Performs one iteration of Anderson acceleration.
- init_state(init_params, *args, **kwargs)
Initialize the solver state.
- Parameters
init_params – initial guess of the fixed point, pytree.
*args – additional positional arguments to be passed to fixed_point_fun.
**kwargs – additional keyword arguments to be passed to fixed_point_fun.
- Return type
AndersonState
- Returns
state
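To make the memory cost of history_size concrete, the sketch below allocates the buffers an Anderson solver typically keeps between iterations. SketchAndersonState and sketch_init_state are hypothetical names; the actual AndersonState returned here may use different fields.

```python
from typing import NamedTuple

import numpy as np


class SketchAndersonState(NamedTuple):
    """Hypothetical state layout; the field names are illustrative only."""
    iter_num: int
    params_history: np.ndarray     # (history_size, n): past iterates
    residuals_history: np.ndarray  # (history_size, n): past residuals T(x) - x
    residual_gram: np.ndarray      # (history_size, history_size): R @ R.T
    error: float


def sketch_init_state(init_params, history_size=5):
    """Allocate the buffers an Anderson solver needs before the first update."""
    x = np.asarray(init_params, dtype=float).ravel()
    m = history_size
    return SketchAndersonState(
        iter_num=0,
        params_history=np.tile(x, (m, 1)),         # every slot starts at the initial guess
        residuals_history=np.zeros((m, x.size)),   # no residuals evaluated yet
        residual_gram=np.zeros((m, m)),
        error=np.inf,  # a loop testing `error > tol` then performs at least one update
    )
```

For params with n entries this stores 2·history_size·n + history_size² numbers, so memory grows linearly with history_size when n is large.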
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error, i.e., the L2 norm of optimality_fun(params, *args, **kwargs).
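Assuming the optimality error is the L2 norm of the residual fixed_point_fun(params) - params taken over every leaf of the pytree, it can be sketched without JAX as follows. tree_leaves and l2_optimality_error here are hypothetical stand-ins (JAX itself provides jax.tree_util.tree_leaves for real pytrees), and shrink is a made-up fixed-point map that takes one extra positional argument.

```python
import numpy as np


def tree_leaves(tree):
    """Flatten nested dicts/lists/tuples into a list of float arrays."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for sub in tree for leaf in tree_leaves(sub)]
    return [np.asarray(tree, dtype=float)]


def l2_optimality_error(fixed_point_fun, params, *args, **kwargs):
    """L2 norm over all leaves of fixed_point_fun(params, ...) - params."""
    mapped = tree_leaves(fixed_point_fun(params, *args, **kwargs))
    current = tree_leaves(params)
    total = sum(float(np.sum((m - c) ** 2)) for m, c in zip(mapped, current))
    return float(np.sqrt(total))


def shrink(p, scale):
    # Hypothetical fixed-point map on a dict pytree; `scale` arrives via *args.
    return {"w": scale * p["w"], "b": p["b"]}


params = {"w": np.array([1.0, 2.0]), "b": 0.5}
print(l2_optimality_error(shrink, params, 0.5))  # prints approximately 1.118 (= sqrt(1.25))
```

The error is zero exactly when every leaf is left unchanged by fixed_point_fun, which is why it serves as a convergence measure.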
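- optimality_fun(params, *args, **kwargs)
Optimality function mapping compatible with @custom_root. For this solver, it is the residual fixed_point_fun(params, *args, **kwargs) - params, which vanishes exactly at a fixed point.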
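Writing the fixed point x*(θ) as a root of F(x, θ) = T(x, θ) − x is what makes implicit differentiation possible: by the implicit function theorem, dx*/dθ = (I − ∂ₓT)⁻¹ ∂_θT, so derivatives come from a linear solve (the role of implicit_diff_solve when implicit_diff=True) rather than from backpropagating through the iterations. The scalar NumPy sketch below, with a hypothetical map T, checks that formula against finite differences; it does not call jaxopt.

```python
import numpy as np


def T(x, theta):
    # Hypothetical map: a contraction in x (|dT/dx| <= 0.5) with scalar parameter theta.
    return 0.5 * np.cos(x) + theta


def solve_fixed_point(theta, x0=0.0, iters=200):
    # Plain fixed-point iteration; 200 steps at rate <= 0.5 reach machine precision.
    x = x0
    for _ in range(iters):
        x = T(x, theta)
    return x


def implicit_derivative(theta):
    """dx*/dtheta from the implicit function theorem applied to F(x) = T(x, theta) - x."""
    x_star = solve_fixed_point(theta)
    dT_dx = -0.5 * np.sin(x_star)     # partial derivative of T w.r.t. x at the fixed point
    dT_dtheta = 1.0                   # partial derivative of T w.r.t. theta
    return dT_dtheta / (1.0 - dT_dx)  # (I - dT/dx)^{-1} dT/dtheta in the scalar case


theta, h = 0.3, 1e-5
finite_diff = (solve_fixed_point(theta + h) - solve_fixed_point(theta - h)) / (2 * h)
print(implicit_derivative(theta), finite_diff)
```

In the vector case 1 − dT/dx becomes the matrix I − ∂ₓT, and the division becomes a linear system solve.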
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
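run ties init_state and update together. The sketch below mirrors that protocol with a hypothetical PlainFixedPoint class whose update is a plain fixed-point step rather than an Anderson step; OptStep and State are local stand-ins named after the documented return values, and the stopping rule (continue while error > tol and iter_num < maxiter) is an assumption.

```python
import math
from typing import Any, Callable, NamedTuple


class OptStep(NamedTuple):   # local stand-in for the documented (params, state) pair
    params: Any
    state: Any


class State(NamedTuple):     # hypothetical minimal solver state
    iter_num: int
    error: float


class PlainFixedPoint:
    """Hypothetical solver illustrating the init_state / update / run protocol.

    update() here is a plain fixed-point step, not an Anderson step.
    """

    def __init__(self, fixed_point_fun: Callable, maxiter: int = 100, tol: float = 1e-5):
        self.fixed_point_fun = fixed_point_fun
        self.maxiter = maxiter
        self.tol = tol

    def init_state(self, init_params, *args, **kwargs) -> State:
        return State(iter_num=0, error=math.inf)

    def update(self, params, state, *args, **kwargs) -> OptStep:
        next_params = self.fixed_point_fun(params, *args, **kwargs)
        error = abs(next_params - params)  # residual of the iterate just mapped (scalar case)
        return OptStep(next_params, State(state.iter_num + 1, error))

    def run(self, init_params, *args, **kwargs) -> OptStep:
        params, state = init_params, self.init_state(init_params, *args, **kwargs)
        # Iterate while not converged and the iteration budget is not exhausted;
        # *args and **kwargs are forwarded to every update call.
        while state.error > self.tol and state.iter_num < self.maxiter:
            params, state = self.update(params, state, *args, **kwargs)
        return OptStep(params, state)


solver = PlainFixedPoint(lambda x, a: a * math.cos(x))
params, state = solver.run(1.0, 1.0)  # the extra 1.0 is passed on as `a`
```

Because the result is a named tuple, it can be unpacked directly as (params, state), matching the documented return value.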
- update(params, state, *args, **kwargs)
Performs one iteration of Anderson acceleration.
- Parameters
params (Any) – pytree containing the parameters.
state (NamedTuple) – named tuple containing the solver state.
*args – additional positional arguments to be passed to fixed_point_fun.
**kwargs – additional keyword arguments to be passed to fixed_point_fun.
- Return type
OptStep
- Returns
(params, state)