jaxopt.MirrorDescent

class jaxopt.MirrorDescent(fun, projection_grad, stepsize, maxiter=500, tol=0.01, verbose=0, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')[source]

Mirror descent solver.

This solver minimizes:

argmin_x fun(x, *args, **kwargs),

where fun is a smooth function with a convex domain.

The stopping criterion is:

||x - projection_grad(x, g, 1.0, hyperparams_proj)||_2 <= tol,

where g = grad(fun)(x, *args, **kwargs).
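As a usage sketch (not part of the reference itself): the snippet below minimizes a quadratic over the probability simplex with entropic mirror descent. The objective fun, the matrix Q, and the exponentiated-gradient update are illustrative choices, not fixed by the API.

    import jax.numpy as jnp
    from jaxopt import MirrorDescent

    def fun(x, Q):
        # Smooth objective whose (convex) domain contains the simplex.
        return jnp.dot(x, jnp.dot(Q, x))

    def projection_grad(x, g, stepsize, hyperparams_proj):
        # Entropic (exponentiated-gradient) update: the next iterate is
        # proportional to x * exp(-stepsize * g), renormalized to the simplex.
        y = x * jnp.exp(-stepsize * g)
        return y / jnp.sum(y)

    Q = jnp.array([[2.0, 0.5], [0.5, 1.0]])
    solver = MirrorDescent(fun=fun, projection_grad=projection_grad, stepsize=0.1)
    params, state = solver.run(jnp.array([0.5, 0.5]), None, Q)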

Parameters
  • fun (Callable) –

  • projection_grad (Optional[Callable]) –

  • stepsize (Union[float, Callable]) –

  • maxiter (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fun

a smooth function of the form fun(x, *args, **kwargs).

Type

Callable

projection_grad

a function of the form projection_grad(x, g, stepsize, hyperparams_proj) representing the mirror descent update for iterate x and gradient g. Optionally, it can be instantiated from a projection and a mapping function (mirror map) using the static method make_projection_grad.

Type

Optional[Callable]

stepsize

a stepsize to use, or a callable specifying the stepsize to use at each iteration.

Type

Union[float, Callable]
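For instance, a callable can implement the classical 1/sqrt(t) mirror descent schedule; the sketch below assumes, as for other jaxopt solvers, that the callable receives the iteration number:

    import jax.numpy as jnp

    # Decaying schedule: called with the iteration number at each step.
    stepsize = lambda t: 1.0 / jnp.sqrt(t + 1)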

maxiter

maximum number of mirror descent iterations.

Type

int

tol

tolerance of the stopping criterion.

Type

float

verbose

whether to print the error at each iteration. Setting verbose=True automatically disables jit.

Type

int

implicit_diff

whether to enable implicit diff or autodiff of unrolled iterations.

Type

bool
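With implicit_diff=True (the default), the solution returned by run can be differentiated with standard JAX transforms. A sketch, reusing solver and Q from the earlier simplex example:

    import jax

    def solution(Q):
        # Gradients flow through the fixed point via implicit differentiation.
        return solver.run(jnp.array([0.5, 0.5]), None, Q).params

    dsol_dQ = jax.jacobian(solution)(Q)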

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

has_aux

whether fun outputs a single value (False) or multiple values (True). When True, fun(...)[0] is assumed by default to be the objective.

Type

bool
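A sketch of the has_aux=True convention; the auxiliary dict and its key are made up for the example, and projection_grad is as in the earlier sketch:

    def fun_with_aux(x, Q):
        value = jnp.dot(x, jnp.dot(Q, x))
        # First output is the objective; the remainder is auxiliary output.
        return value, {"objective": value}

    solver_aux = MirrorDescent(fun=fun_with_aux, projection_grad=projection_grad,
                               stepsize=0.1, has_aux=True)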

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

References

Nemirovskij, Arkadij Semenovič, and David Borisovich Yudin. "Problem complexity and method efficiency in optimization." J. Wiley & Sons, New York (1983).

__init__(fun, projection_grad, stepsize, maxiter=500, tol=0.01, verbose=0, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')
Parameters
  • fun (Callable) –

  • projection_grad (Optional[Callable]) –

  • stepsize (Union[float, Callable]) –

  • maxiter (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(fun, projection_grad, stepsize[, ...])

attribute_names()

attribute_values()

init_state(init_params, hyperparams_proj, ...)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

make_projection_grad(projection, mapping_fun)

Instantiates projection_grad argument from projection and mirror map.

optimality_fun(sol, hyperparams_proj, *args, ...)

Optimality function mapping compatible with @custom_root.

run(init_params[, hyperparams_proj])

Runs the optimization loop.

update(params, state, hyperparams_proj, ...)

Performs one iteration of mirror descent.

Attributes

has_aux

implicit_diff

implicit_diff_solve

jit

maxiter

tol

unroll

verbose

fun

projection_grad

stepsize

init_state(init_params, hyperparams_proj, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • hyperparams_proj (Any) – pytree containing hyperparameters of projection.

Return type

NamedTuple

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.
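Since optimality_fun takes hyperparams_proj before the remaining arguments, the extra arguments here must follow the same order. A sketch reusing solver, params, and Q from the earlier example:

    error = solver.l2_optimality_error(params, None, Q)  # scalar L2 error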

static make_projection_grad(projection, mapping_fun)[source]

Instantiates projection_grad argument from projection and mirror map.

Parameters
  • projection (Callable) – projection operator of the form projection(x, hyperparams_proj), typically argmin_z D_{gen_fun}(z, mapping_fun^{-1}(y)).

  • mapping_fun (Callable) – the mirror map, typically of the form mapping_fun = grad(gen_fun), where gen_fun is the generating function of the Bregman divergence.

Return type

Callable

Returns

A function projection_grad(x, g, stepsize, hyperparams_proj) representing the mirror descent update for iterate x and gradient g.
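A sketch for the negative-entropy setup on the simplex: with gen_fun(x) = sum_i x_i log(x_i), the mirror map is mapping_fun = log (up to an additive constant), its inverse is exp, and the KL projection of exp(y) onto the simplex is softmax(y). The helper name kl_projection is illustrative:

    import jax
    import jax.numpy as jnp
    from jaxopt import MirrorDescent

    def kl_projection(y, hyperparams_proj):
        # Bregman (KL) projection of mapping_fun^{-1}(y) = exp(y) onto the simplex.
        return jax.nn.softmax(y)

    projection_grad = MirrorDescent.make_projection_grad(kl_projection, jnp.log)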

optimality_fun(sol, hyperparams_proj, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

run(init_params, hyperparams_proj=None, *args, **kwargs)[source]

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • hyperparams_proj (Optional[Any]) – pytree containing hyperparameters of projection.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

update(params, state, hyperparams_proj, *args, **kwargs)[source]

Performs one iteration of mirror descent.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (NamedTuple) – named tuple containing the solver state.

  • hyperparams_proj (Any) – pytree containing hyperparameters of projection.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)
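run is equivalent to calling init_state once and then iterating update. A manual-loop sketch, reusing fun, projection_grad, and Q from the first example and assuming the solver state exposes an error field as other jaxopt solver states do:

    params = jnp.array([0.5, 0.5])
    state = solver.init_state(params, None, Q)
    for _ in range(solver.maxiter):
        params, state = solver.update(params, state, None, Q)
        if state.error <= solver.tol:
            break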