jaxopt.AndersonAcceleration

class jaxopt.AndersonAcceleration(fixed_point_fun, history_size=5, mixing_frequency=1, beta=1.0, maxiter=100, tol=1e-05, ridge=1e-05, has_aux=False, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')[source]

Anderson acceleration.

Parameters
  • fixed_point_fun (Callable) –

  • history_size (int) –

  • mixing_frequency (int) –

  • beta (float) –

  • maxiter (int) –

  • tol (float) –

  • ridge (float) –

  • has_aux (bool) –

  • verbose (bool) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fixed_point_fun

a function fixed_point_fun(x, *args, **kwargs) returning a pytree with the same structure and type as x. See the reference below for conditions that the function must fulfill in order to guarantee convergence. In particular, if the Banach fixed point theorem conditions hold, Anderson acceleration will converge.

Type

Callable

history_size

size of the history buffer. Affects memory cost.

Type

int

mixing_frequency

frequency of Anderson updates (default: 1). Only one out of every mixing_frequency updates uses Anderson; the other updates use regular fixed point iterations.

Type

int

beta

momentum in Anderson updates (default: 1.0).

Type

float

maxiter

maximum number of iterations.

Type

int

tol

tolerance (stopping criterion).

Type

float

ridge

ridge regularization in the internal linear solver. Consider increasing this value if the solver returns NaN.

Type

float

has_aux

whether fixed_point_fun returns additional data (default: False). This additional data is not taken into account by the fixed point iteration. The solver returns the aux associated with the last iterate (i.e. the fixed point).

Type

bool

verbose

whether to print the error at every iteration. Warning: verbose=True automatically disables jit.

Type

bool

implicit_diff

whether to differentiate the solver via implicit differentiation (True) or autodiff of the unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver to use.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

References

Pollock, Sara, and Leo Rebholz. “Anderson acceleration for contractive and noncontractive operators.” arXiv preprint arXiv:1909.04638 (2019).

__init__(fixed_point_fun, history_size=5, mixing_frequency=1, beta=1.0, maxiter=100, tol=1e-05, ridge=1e-05, has_aux=False, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Parameters
  • fixed_point_fun (Callable) –

  • history_size (int) –

  • mixing_frequency (int) –

  • beta (float) –

  • maxiter (int) –

  • tol (float) –

  • ridge (float) –

  • has_aux (bool) –

  • verbose (bool) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(fixed_point_fun[, history_size, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

update(params, state, *args, **kwargs)

Performs one iteration of the Anderson acceleration.

Attributes

beta

has_aux

history_size

implicit_diff

implicit_diff_solve

jit

maxiter

mixing_frequency

ridge

tol

unroll

verbose

fixed_point_fun

init_state(init_params, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params – initial guess for the fixed point (pytree).

  • *args – additional positional arguments to be passed to fixed_point_fun.

  • **kwargs – additional keyword arguments to be passed to fixed_point_fun.

Return type

AndersonState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

update(params, state, *args, **kwargs)[source]

Performs one iteration of the Anderson acceleration.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (NamedTuple) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fixed_point_fun.

  • **kwargs – additional keyword arguments to be passed to fixed_point_fun.

Return type

OptStep

Returns

(params, state)