jaxopt.OptaxSolver

class jaxopt.OptaxSolver(fun, opt, value_and_grad=False, pre_update=None, maxiter=500, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')[source]

Optax solver: wraps an Optax optimizer (an optax.GradientTransformation) in the jaxopt iterative solver API.
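
A minimal usage sketch (the least-squares objective, data, and learning rate below are illustrative, not part of the API):

    import jax.numpy as jnp
    import optax
    from jaxopt import OptaxSolver

    # Illustrative objective: mean squared error of a linear model.
    def least_squares(params, X, y):
        residuals = X @ params - y
        return jnp.mean(residuals ** 2)

    X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    y = jnp.array([1.0, 2.0, 3.0])

    solver = OptaxSolver(fun=least_squares, opt=optax.adam(1e-2), maxiter=100)
    params, state = solver.run(jnp.zeros(2), X=X, y=y)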

Parameters
  • fun (Callable) –

  • opt (NamedTuple) –

  • value_and_grad (bool) –

  • pre_update (Optional[Callable]) –

  • maxiter (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fun

a function of the form fun(params, *args, **kwargs), where params are the parameters of the model, and *args and **kwargs are additional arguments.

Type

Callable

opt

the optimizer to use, an optax.GradientTransformation, which is just a NamedTuple with init and update functions.

Type

NamedTuple

value_and_grad

whether fun just returns the value (False) or both the value and gradient (True).

Type

bool

has_aux

whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, then we have one of the following two cases. If value_and_grad is False, the output should be value, aux = fun(...). If value_and_grad == True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.

Type

bool
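
A sketch of the two accepted return conventions (the quadratic loss and the aux dict here are illustrative):

    import jax
    import jax.numpy as jnp

    # Case 1: has_aux=True, value_and_grad=False.
    # fun returns (value, aux); aux can be any pytree and is stored in state.aux.
    def fun(params, x):
        loss = jnp.sum((params - x) ** 2)
        return loss, {"loss": loss}

    # Case 2: has_aux=True, value_and_grad=True.
    # fun returns ((value, aux), grad).
    def fun_with_grad(params, x):
        return jax.value_and_grad(fun, has_aux=True)(params, x)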

pre_update

a function to execute before Optax’s update. The function signature must be params, state = pre_update(params, state, *args, **kwargs).

Type

Optional[Callable]
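
A sketch of a hypothetical pre_update hook that projects parameters onto the non-negative orthant before each Optax update (the projection itself is illustrative, not part of OptaxSolver):

    import jax
    import jax.numpy as jnp
    import optax
    from jaxopt import OptaxSolver

    def fun(params, x):
        return jnp.sum((params - x) ** 2)

    # Hypothetical hook: clip every leaf of the params pytree at zero
    # before the Optax update is applied.
    def pre_update(params, state, *args, **kwargs):
        params = jax.tree_util.tree_map(lambda p: jnp.maximum(p, 0.0), params)
        return params, state

    solver = OptaxSolver(fun=fun, opt=optax.sgd(1e-1), pre_update=pre_update)
    params, state = solver.run(jnp.zeros(3), x=jnp.array([1.0, -2.0, 3.0]))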

maxiter

maximum number of solver iterations.

Type

int

tol

tolerance of the stopping criterion: iterations stop once the error (the L2 norm of the gradient) falls below tol.

Type

float

verbose

whether to print the error at each iteration. verbose=True automatically disables jit.

Type

int

implicit_diff

whether to differentiate the solution via implicit differentiation (True) or via autodiff of the unrolled iterations (False).

Type

bool
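
A sketch of differentiating through the solver with implicit_diff=True; the ridge objective and the hyperparameter lam are illustrative:

    import jax
    import jax.numpy as jnp
    import optax
    from jaxopt import OptaxSolver

    def ridge_loss(params, lam, X, y):
        residuals = X @ params - y
        return jnp.mean(residuals ** 2) + lam * jnp.sum(params ** 2)

    X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    y = jnp.array([1.0, 2.0, 3.0])
    solver = OptaxSolver(fun=ridge_loss, opt=optax.adam(1e-2),
                         maxiter=500, implicit_diff=True)

    # Gradient of the solution w.r.t. lam, obtained by implicit
    # differentiation of the optimality conditions rather than by
    # unrolling the 500 iterations.
    def solution(lam):
        return solver.run(jnp.zeros(2), lam, X, y).params

    dsol_dlam = jax.jacobian(solution)(0.1)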

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

__init__(fun, opt, value_and_grad=False, pre_update=None, maxiter=500, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')
Parameters
  • fun (Callable) –

  • opt (NamedTuple) –

  • value_and_grad (bool) –

  • pre_update (Optional[Callable]) –

  • maxiter (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(fun, opt[, value_and_grad, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

run_iterator(init_params, iterator, *args, ...)

Runs the optimization loop over an iterator.

update(params, state, *args, **kwargs)

Performs one iteration of the optax solver.

Attributes

has_aux

implicit_diff

implicit_diff_solve

jit

maxiter

pre_update

tol

unroll

value_and_grad

verbose

fun

opt

init_state(init_params, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptaxState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

run_iterator(init_params, iterator, *args, **kwargs)

Runs the optimization loop over an iterator.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • iterator – iterator generating data batches.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)
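
A minibatch training sketch, assuming (as in the jaxopt examples) that each element yielded by the iterator is forwarded to fun as the keyword argument data; the synthetic batch generator is illustrative:

    import jax.numpy as jnp
    import numpy as np
    import optax
    from jaxopt import OptaxSolver

    # Assumed convention: fun receives each batch via the `data` keyword.
    def loss_fun(params, data):
        X, y = data
        return jnp.mean((X @ params - y) ** 2)

    def batches(rng, batch_size=8):
        # Illustrative infinite stream of synthetic regression batches.
        while True:
            X = rng.standard_normal((batch_size, 2)).astype(np.float32)
            y = X @ np.array([1.0, -1.0], dtype=np.float32)
            yield jnp.asarray(X), jnp.asarray(y)

    solver = OptaxSolver(fun=loss_fun, opt=optax.adam(1e-2), maxiter=100)
    params, state = solver.run_iterator(
        jnp.zeros(2), iterator=batches(np.random.default_rng(0)))

Here maxiter bounds how many batches are drawn from the iterator.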

update(params, state, *args, **kwargs)[source]

Performs one iteration of the optax solver.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (NamedTuple) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)
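
For finer control than run, init_state and update can be combined into a manual loop. A sketch, assuming state.error holds the current optimality error as in other jaxopt solver states (the quadratic objective and stopping logic are illustrative):

    import jax.numpy as jnp
    import optax
    from jaxopt import OptaxSolver

    def fun(params, x):
        return jnp.sum((params - x) ** 2)

    solver = OptaxSolver(fun=fun, opt=optax.sgd(1e-1), maxiter=50)
    x = jnp.array([1.0, 2.0])

    params = jnp.zeros(2)
    state = solver.init_state(params, x)
    for _ in range(solver.maxiter):
        # One Optax step; state.error tracks the current optimality error.
        params, state = solver.update(params, state, x)
        if state.error <= solver.tol:
            break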