jaxopt.FixedPointIteration

class jaxopt.FixedPointIteration(fixed_point_fun, maxiter=100, tol=1e-05, has_aux=False, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')[source]

Fixed point iteration method.

fixed_point_fun

a function fixed_point_fun(x, *args, **kwargs) returning a pytree with the same structure and type as x. The function should satisfy the assumptions of the Banach fixed-point theorem; otherwise convergence is not guaranteed.

Type

Callable

Parameters
  • fixed_point_fun (Callable) –

  • maxiter (int) –

  • tol (float) –

  • has_aux (bool) –

  • verbose (bool) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

maxiter

maximum number of iterations.

Type

int

tol

tolerance (stopping criterion)

Type

float

has_aux

whether fixed_point_fun returns additional data (default: False). If True, the fixed point is computed only with respect to the first element of the returned sequence; the other elements are carried along during the computation.

Type

bool

verbose

whether to print the error at every iteration. Warning: verbose=True will automatically disable jit.

Type

bool

implicit_diff

whether to use implicit differentiation (True) or autodiff of unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”)

Type

Union[str, bool]

References

https://en.wikipedia.org/wiki/Fixed-point_iteration

__init__(fixed_point_fun, maxiter=100, tol=1e-05, has_aux=False, verbose=False, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Parameters
  • fixed_point_fun (Callable) –

  • maxiter (int) –

  • tol (float) –

  • has_aux (bool) –

  • verbose (bool) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(fixed_point_fun[, maxiter, tol, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

update(params, state, *args, **kwargs)

Performs one iteration of the fixed point iteration method.

Attributes

has_aux

implicit_diff

implicit_diff_solve

jit

maxiter

tol

unroll

verbose

fixed_point_fun

init_state(init_params, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params – initial guess of the fixed point, pytree

  • *args – additional positional arguments to be passed to optimality_fun.

  • **kwargs – additional keyword arguments to be passed to optimality_fun.

Return type

FixedPointState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

update(params, state, *args, **kwargs)[source]

Performs one iteration of the fixed point iteration method.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (NamedTuple) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fixed_point_fun.

  • **kwargs – additional keyword arguments to be passed to fixed_point_fun.

Return type

OptStep

Returns

(params, state)