jaxopt.MirrorDescent
- class jaxopt.MirrorDescent(fun, projection_grad, stepsize, maxiter=500, tol=0.01, verbose=0, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')[source]
Mirror descent solver.
This solver minimizes fun(x, *args, **kwargs) with respect to x, where fun is smooth with a convex domain.
- The stopping criterion is:
||x - projection_grad(x, g, 1.0, hyperparams_proj)||_2 <= tol,
where g = grad(fun)(x, *args, **kwargs).
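For intuition, the error used by this criterion can be computed with a small helper like the one below (an illustrative sketch for array-valued x; the solver's l2_optimality_error method provides this for general pytrees):

import jax
import jax.numpy as jnp

def optimality_error(x, fun, projection_grad, hyperparams_proj, *args, **kwargs):
    # Distance between x and the mirror descent update taken with unit stepsize.
    g = jax.grad(fun)(x, *args, **kwargs)
    return jnp.linalg.norm(x - projection_grad(x, g, 1.0, hyperparams_proj))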
- Parameters
fun (Callable) –
projection_grad (Optional[Callable]) –
stepsize (Union[float, Callable]) –
maxiter (int) –
tol (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
has_aux (bool) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- fun
a smooth function of the form
fun(x, *args, **kwargs)
- Type
Callable
- projection_grad
a function of the form
projection_grad(x, g, stepsize, hyperparams_proj)
representing the mirror descent update for iterate x and gradient g. Optionally, it can be instantiated from a projection and mapping function (mirror map) using the method make_projection_grad.
- Type
Optional[Callable]
- stepsize
a stepsize to use, or a callable specifying the stepsize to use at each iteration.
- Type
Union[float, Callable]
- maxiter
maximum number of mirror descent iterations.
- Type
int
- tol
tolerance to use.
- Type
float
- verbose
whether to print the error at every iteration. Setting verbose=True automatically disables jit.
- Type
int
- implicit_diff
whether to enable implicit diff or autodiff of unrolled iterations.
- Type
bool
- implicit_diff_solve
the linear system solver to use.
- Type
Optional[Callable]
- has_aux
whether function fun outputs one (False) or more values (True). When True it will be assumed by default that fun(…)[0] is the objective.
- Type
bool
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
References
Nemirovskij, Arkadij Semenovič, and David Borisovich Yudin. “Problem complexity and method efficiency in optimization.” J. Wiley & Sons, New York (1983).
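Example (a minimal, illustrative sketch: the objective fun, the mirror map mapping_fun, the projection, and the target values below are made up for this example and are not part of jaxopt; only MirrorDescent and make_projection_grad are library API). It runs entropic mirror descent (exponentiated gradient) on the probability simplex, where the negative-entropy generating function gives an elementwise-log mirror map and a softmax Bregman projection:

import jax.numpy as jnp
from jax.nn import softmax
from jaxopt import MirrorDescent

def mapping_fun(x):
    # Mirror map: gradient of the negative-entropy generating function (up to a constant).
    return jnp.log(x)

def projection(y, hyperparams_proj):
    # Bregman projection onto the simplex for the entropic geometry.
    del hyperparams_proj  # no projection hyperparameters in this sketch
    return softmax(y)

projection_grad = MirrorDescent.make_projection_grad(projection, mapping_fun)

def fun(x, target):
    # Illustrative smooth objective: squared distance to a target point.
    return jnp.sum((x - target) ** 2)

solver = MirrorDescent(fun=fun, projection_grad=projection_grad, stepsize=1.0)
target = jnp.array([0.7, 0.2, 0.1])
init = jnp.ones(3) / 3.0
params, state = solver.run(init, None, target)  # hyperparams_proj=None; target is passed to fun

Here make_projection_grad is only a convenience; any function with the projection_grad(x, g, stepsize, hyperparams_proj) signature can be passed directly.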
- __init__(fun, projection_grad, stepsize, maxiter=500, tol=0.01, verbose=0, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')
- Parameters
fun (Callable) –
projection_grad (Optional[Callable]) –
stepsize (Union[float, Callable]) –
maxiter (int) –
tol (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
has_aux (bool) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- Return type
None
Methods
- __init__(fun, projection_grad, stepsize[, ...])
- attribute_names()
- attribute_values()
- init_state(init_params, hyperparams_proj, ...): Initialize the solver state.
- l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
- make_projection_grad(projection, mapping_fun): Instantiates projection_grad argument from projection and mirror map.
- optimality_fun(sol, hyperparams_proj, *args, ...): Optimality function mapping compatible with @custom_root.
- run(init_params[, hyperparams_proj]): Runs the optimization loop.
- update(params, state, hyperparams_proj, ...): Performs one iteration of mirror descent.
- init_state(init_params, hyperparams_proj, *args, **kwargs)[source]
Initialize the solver state.
- Parameters
init_params (Any) – pytree containing the initial parameters.
hyperparams_proj (Any) –
- Return type
OptStep
- Returns
state
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- static make_projection_grad(projection, mapping_fun)[source]
Instantiates projection_grad argument from projection and mirror map.
- Parameters
projection (Callable) – projection operator of the form projection(x, hyperparams_proj), typically argmin_z D_{gen_fun}(z, mapping_fun^{-1}(y)).
mapping_fun (Callable) – the mirror map, typically of the form mapping_fun = grad(gen_fun), where gen_fun is the generating function of the Bregman divergence.
- Return type
Callable
- Returns
A function projection_grad(x, g, stepsize, hyperparams_proj) representing the mirror descent update for iterate x and gradient g.
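Conceptually, the returned function maps the iterate to the dual space with mapping_fun, takes a gradient step there, and maps back with the Bregman projection. Below is a sketch under that assumption (not the library's exact implementation), again using the negative-entropy setup on the simplex; the names mapping_fun and projection are illustrative, not part of jaxopt:

import jax.numpy as jnp

def mapping_fun(x):
    # Illustrative mirror map for the negative-entropy generating function.
    return jnp.log(x)

def projection(y, hyperparams_proj):
    # Illustrative Bregman projection onto the simplex (a softmax).
    del hyperparams_proj
    return jnp.exp(y) / jnp.sum(jnp.exp(y))

def projection_grad(x, g, stepsize, hyperparams_proj):
    # Dual-space gradient step followed by the Bregman projection.
    y = mapping_fun(x) - stepsize * g
    return projection(y, hyperparams_proj)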
- optimality_fun(sol, hyperparams_proj, *args, **kwargs)[source]
Optimality function mapping compatible with @custom_root.
- run(init_params, hyperparams_proj=None, *args, **kwargs)[source]
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
hyperparams_proj (Optional[Any]) –
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
- update(params, state, hyperparams_proj, *args, **kwargs)[source]
Performs one iteration of mirror descent.
- Parameters
params (Any) – pytree containing the parameters.
state (NamedTuple) – named tuple containing the solver state.
hyperparams_proj (Any) – pytree containing hyperparameters of projection.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)
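run is equivalent to repeatedly calling update after init_state. A self-contained sketch of the manual loop is below; the entropic setup, the quadratic objective, and the assumption that the solver state exposes an error field (as jaxopt solver states typically do) are illustrative choices, not prescribed by the API:

import jax.numpy as jnp
from jax.nn import softmax
from jaxopt import MirrorDescent

projection_grad = MirrorDescent.make_projection_grad(
    projection=lambda y, hyperparams_proj: softmax(y),  # illustrative Bregman projection
    mapping_fun=jnp.log)                                 # illustrative mirror map

solver = MirrorDescent(lambda x, t: jnp.sum((x - t) ** 2), projection_grad, stepsize=1.0)
target = jnp.array([0.7, 0.2, 0.1])
params = jnp.ones(3) / 3.0

state = solver.init_state(params, None, target)
for _ in range(solver.maxiter):
    params, state = solver.update(params, state, None, target)  # one mirror descent step
    if state.error <= solver.tol:  # same stopping criterion as run (assumes an error field)
        break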