jaxopt.PolyakSGD
- class jaxopt.PolyakSGD(fun, value_and_grad=False, has_aux=False, max_stepsize=1.0, delta=0.0, momentum=0.0, pre_update=None, maxiter=500, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, jit='auto', unroll='auto')[source]
SGD with Polyak step size.
This solver computes step sizes adaptively. If the step size computed at a given iteration is smaller than max_stepsize, it is accepted; otherwise, max_stepsize is used. This ensures that the solver does not take over-confident steps, which is why max_stepsize is the most important hyper-parameter.
- This implementation assumes that the interpolation property holds: the global optimum over the data distribution D must also be a global optimum for any finite sample of D. This is typically achieved by overparameterized models (e.g. neural networks) in classification tasks with separable classes, or in regression tasks without noise.
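A minimal usage sketch, assuming a noiseless least-squares problem so that the interpolation property holds; the data, shapes and hyper-parameter values below are illustrative only:

import jax
import jax.numpy as jnp
from jaxopt import PolyakSGD

# Synthetic noiseless regression: the model can fit the data exactly,
# so the interpolation property holds.
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
w_true = jnp.arange(1.0, 6.0)
y = X @ w_true  # targets without noise

def loss(params, X, y):
    residuals = X @ params - y
    return jnp.mean(residuals ** 2)  # scalar-valued, as required when has_aux=False

solver = PolyakSGD(fun=loss, max_stepsize=1.0, momentum=0.0, maxiter=100)
params, state = solver.run(jnp.zeros(5), X, y)  # run returns an OptStep (params, state)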
- Parameters
fun (Callable) –
value_and_grad (bool) –
has_aux (bool) –
max_stepsize (float) –
delta (float) –
momentum (float) –
pre_update (Optional[Callable]) –
maxiter (int) –
tol (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- fun
a function of the form fun(params, *args, **kwargs), where params are parameters of the model, *args and **kwargs are additional arguments.
- Type
Callable
- has_aux
whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, there are two cases: if value_and_grad is False, the output should be value, aux = fun(...); if value_and_grad == True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux (see the sketch after this attribute list).
- Type
bool
- max_stepsize
a maximum step size to use.
- Type
float
- delta
a value to add in the denominator of the update (default: 0).
- Type
float
- momentum
momentum parameter, 0 corresponding to no momentum.
- Type
float
- pre_update
a function to execute before the solver's update. The function signature must be params, state = pre_update(params, state, *args, **kwargs).
- Type
Optional[Callable]
- maxiter
maximum number of solver iterations.
- Type
int
- tol
tolerance to use.
- Type
float
- verbose
whether to print the error at every iteration. verbose=True automatically disables jit.
- Type
int
- implicit_diff
whether to enable implicit diff or autodiff of unrolled iterations.
- Type
bool
- implicit_diff_solve
the linear system solver to use.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
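A minimal sketch of the has_aux convention (assuming value_and_grad=False and an illustrative least-squares fun); the auxiliary output of the last iteration is then available in state.aux:

import jax
import jax.numpy as jnp
from jaxopt import PolyakSGD

X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
y = X @ jnp.arange(1.0, 6.0)  # noiseless targets

def loss_with_aux(params, X, y):
    residuals = X @ params - y
    # With has_aux=True and value_and_grad=False, fun returns (value, aux);
    # aux can be any pytree of diagnostics.
    aux = {"max_abs_residual": jnp.max(jnp.abs(residuals))}
    return jnp.mean(residuals ** 2), aux

solver = PolyakSGD(fun=loss_with_aux, has_aux=True, max_stepsize=1.0)
params, state = solver.run(jnp.zeros(5), X, y)
print(state.aux["max_abs_residual"])  # auxiliary output of the last update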
References
Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan. “Training neural networks for and by interpolation”. International Conference on Machine Learning, 2020. https://arxiv.org/abs/1906.05661
Loizou, Nicolas and Vaswani, Sharan and Laradji, Issam Hadj and Lacoste-Julien, Simon. “Stochastic Polyak step-size for SGD: An adaptive learning rate for fast convergence”. International Conference on Artificial Intelligence and Statistics, 2021. https://arxiv.org/abs/2002.10542
- __init__(fun, value_and_grad=False, has_aux=False, max_stepsize=1.0, delta=0.0, momentum=0.0, pre_update=None, maxiter=500, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, jit='auto', unroll='auto')
- Parameters
fun (Callable) –
value_and_grad (bool) –
has_aux (bool) –
max_stepsize (float) –
delta (float) –
momentum (float) –
pre_update (Optional[Callable]) –
maxiter (int) –
tol (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- Return type
None
Methods
__init__(fun[, value_and_grad, has_aux, ...])
attribute_names()
attribute_values()
init_state(init_params, *args, **kwargs): Initialize the solver state.
l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
optimality_fun(params, *args, **kwargs): Optimality function mapping compatible with @custom_root.
run(init_params, *args, **kwargs): Runs the optimization loop.
run_iterator(init_params, iterator, *args, ...): Runs the optimization loop over an iterator.
update(params, state, *args, **kwargs): Performs one iteration of the solver.
Attributes
value_and_grad
- init_state(init_params, *args, **kwargs)[source]
Initialize the solver state.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
PolyakSGDState
- Returns
state
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- optimality_fun(params, *args, **kwargs)[source]
Optimality function mapping compatible with @custom_root.
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
- run_iterator(init_params, iterator, *args, **kwargs)
Runs the optimization loop over an iterator.
- Parameters
init_params (Any) – pytree containing the initial parameters.
iterator – iterator generating data batches.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)
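A minimal sketch of run_iterator, assuming each batch yielded by the iterator is forwarded to fun as its data argument (as in jaxopt's stochastic solver examples); the batching scheme and hyper-parameters below are illustrative only:

import itertools
import jax
import jax.numpy as jnp
from jaxopt import PolyakSGD

X = jax.random.normal(jax.random.PRNGKey(0), (1000, 5))
y = X @ jnp.arange(1.0, 6.0)  # noiseless targets

def batch_iterator(batch_size=32):
    # Cycle through the data in fixed-size mini-batches.
    for start in itertools.cycle(range(0, X.shape[0], batch_size)):
        yield (X[start:start + batch_size], y[start:start + batch_size])

def loss(params, data):
    X_batch, y_batch = data
    return jnp.mean((X_batch @ params - y_batch) ** 2)

solver = PolyakSGD(fun=loss, max_stepsize=1.0, maxiter=200)
params, state = solver.run_iterator(jnp.zeros(5), batch_iterator())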
- update(params, state, *args, **kwargs)[source]
Performs one iteration of the solver.
- Parameters
params (Any) – pytree containing the parameters.
state (PolyakSGDState) – named tuple containing the solver state.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)
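A minimal sketch of the manual optimization loop enabled by init_state and update, as an alternative to run; the mini-batching below is illustrative, and positional arguments are forwarded to fun as documented above:

import jax
import jax.numpy as jnp
from jaxopt import PolyakSGD

X = jax.random.normal(jax.random.PRNGKey(0), (1000, 5))
y = X @ jnp.arange(1.0, 6.0)  # noiseless targets

def loss(params, X_batch, y_batch):
    return jnp.mean((X_batch @ params - y_batch) ** 2)

solver = PolyakSGD(fun=loss, max_stepsize=1.0)
params = jnp.zeros(5)
state = solver.init_state(params, X[:32], y[:32])

for step in range(200):
    start = (step * 32) % X.shape[0]
    X_batch, y_batch = X[start:start + 32], y[start:start + 32]
    # One Polyak-SGD iteration on the current mini-batch.
    params, state = solver.update(params, state, X_batch, y_batch)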