jaxopt.AndersonWrapper

class jaxopt.AndersonWrapper(solver, history_size=5, mixing_frequency=None, beta=1.0, ridge=1e-05, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')

Wrapper that applies Anderson acceleration to a JAXopt solver.

Note that the internal solver state can be accessed via the solver_state attribute of AndersonWrapperState.

Parameters
  • solver (IterativeSolver)

  • history_size (int)

  • mixing_frequency (Optional[int])

  • beta (float)

  • ridge (float)

  • verbose (bool)

  • implicit_diff (bool)

  • implicit_diff_solve (Optional[Callable])

  • jit (Union[str, bool])

  • unroll (Union[str, bool])

solver

solver object to accelerate. Must expose init_state() and update() methods.

Type

jaxopt._src.base.IterativeSolver

history_size

number of past iterates (and their residuals) kept in the history; larger values increase memory cost (default: 5).

Type

int

mixing_frequency

frequency of Anderson updates (default: history_size). Only one in every mixing_frequency updates uses Anderson acceleration; the other updates are regular fixed-point iterations, i.e. plain updates of the internal solver.

Type

int

beta

momentum in Anderson updates (default: 1.0).

Type

float

ridge

ridge regularization of the least-squares problem solved at each Anderson step (default: 1e-5). Consider increasing this value if the solver returns NaN.

Type

float
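To make history_size, mixing_frequency, beta and ridge concrete, here is a standalone NumPy sketch of textbook (type-II) Anderson acceleration for a fixed-point map T. It is not JAXopt's implementation: the helpers anderson_step and accelerated_fixed_point are hypothetical, and the rule that update k is accelerated when k % mixing_frequency == 0 is an assumed schedule. The weights alpha minimize ||Σ alpha_i g_i||² + ridge·||alpha||² subject to Σ alpha_i = 1, so ridge plays a role analogous to the wrapper's parameter: it keeps a near-singular residual Gram matrix invertible. The tiny default ridge suits this float64 toy; the wrapper's own default is 1e-5.

```python
import numpy as np

def anderson_step(xs, gs, beta=1.0, ridge=1e-12):
    """Textbook (type-II) Anderson extrapolation (illustrative helper).

    xs, gs: arrays of shape (m, d) holding past iterates x_i and their
    residuals g_i = T(x_i) - x_i.
    """
    m = xs.shape[0]
    gram = gs @ gs.T                                   # (m, m) Gram matrix of the residuals
    # Weights minimizing ||sum_i a_i g_i||^2 + ridge * ||a||^2 with sum_i a_i = 1.
    w = np.linalg.solve(gram + ridge * np.eye(m), np.ones(m))
    alpha = w / w.sum()
    # With beta = 1 this is sum_i alpha_i * T(x_i).
    return alpha @ (xs + beta * gs)


def accelerated_fixed_point(T, x0, history_size=5, mixing_frequency=None,
                            beta=1.0, ridge=1e-12, maxiter=100, tol=1e-6):
    """Fixed-point iteration with periodic Anderson steps (hypothetical helper)."""
    if mixing_frequency is None:
        mixing_frequency = history_size                # same default rule as the wrapper
    xs, gs = [], []
    x = x0
    for k in range(maxiter):
        g = T(x) - x
        if np.linalg.norm(g) < tol:
            return x, k
        xs = (xs + [x])[-history_size:]                # keep the last history_size iterates
        gs = (gs + [g])[-history_size:]
        if k % mixing_frequency == 0:                  # assumed schedule
            x = anderson_step(np.array(xs), np.array(gs), beta, ridge)
        else:
            x = x + g                                  # plain step: x <- T(x)
    return x, maxiter


# Linear contraction T(x) = M x + c; plain iteration shrinks the error along
# the first axis by only a factor 0.9 per step.
M = np.diag([0.9, 0.5, -0.3])
c = np.ones(3)

def T(x):
    return M @ x + c

x_star = np.linalg.solve(np.eye(3) - M, c)             # exact fixed point
x, num_iters = accelerated_fixed_point(T, np.zeros(3), history_size=4)
print(num_iters, x, x_star)
```

Because the map is linear and three-dimensional, four stored iterates span the space, so the Anderson step at k = 4 lands on the fixed point up to the small ridge bias and the loop stops right after it.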

verbose

whether to print the error at every iteration. Warning: verbose=True automatically disables jit.

Type

bool

implicit_diff

whether to differentiate the solution with implicit differentiation (True) or with autodiff of the unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver used for implicit differentiation.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

__init__(solver, history_size=5, mixing_frequency=None, beta=1.0, ridge=1e-05, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Parameters
  • solver (IterativeSolver)

  • history_size (int)

  • mixing_frequency (Optional[int])

  • beta (float)

  • ridge (float)

  • verbose (bool)

  • implicit_diff (bool)

  • implicit_diff_solve (Optional[Callable])

  • jit (Union[str, bool])

  • unroll (Union[str, bool])

Return type

None

Methods

__init__(solver[, history_size, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initializes the state of the wrapper (return type: AndersonWrapperState).

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

update(params, state, *args, **kwargs)

Perform one step of Anderson acceleration over the internal solver update.

Attributes

beta

history_size

implicit_diff

implicit_diff_solve

jit

mixing_frequency

ridge

unroll

verbose

solver

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

update(params, state, *args, **kwargs)

Perform one step of Anderson acceleration over the internal solver update.

The reset_state attribute is used to update the internal solver state after the Anderson step.

Parameters
  • params – parameters optimized by the solver. Only their pytree structure matters (the content is unused).

  • state – AndersonWrapperState. Crucially, state.params_history and state.residuals_history are the sequences used to generate the next iterate. Note: state.solver_state is the internal solver state.

  • args – additional positional arguments passed to the update method of the internal solver. Note: these are sometimes hyper-parameters of the solver, but if the solver is a JAXopt solver they are forwarded to the underlying function being optimized.

  • kwargs – additional keyword arguments passed to the update method of the internal solver. Note: these are sometimes hyper-parameters of the solver, but if the solver is a JAXopt solver they are forwarded to the underlying function being optimized.

Return type

OptStep