jaxopt.ArmijoSGD
- class jaxopt.ArmijoSGD(fun, value_and_grad=False, has_aux=False, aggressiveness=0.9, decrease_factor=0.8, increase_factor=1.5, reset_option='increase', momentum=0.0, max_stepsize=1.0, pre_update=None, maxiter=500, maxls=15, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, jit='auto', unroll='auto')[source]
SGD with Armijo line search.
- This implementation assumes that the “interpolation property” holds; see for example Vaswani et al. 2019 (https://arxiv.org/abs/1905.09997): the global optimum over D must also be a global optimum for any finite sample of D. This is typically satisfied by overparameterized models (e.g., neural networks) on classification tasks with separable classes, or on regression tasks without noise. In practice, the algorithm also works well outside this setting.
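Below is a minimal usage sketch on a noiseless least-squares problem, a setting where the interpolation property holds exactly (a zero-loss solution fits every sample). The loss, data, and hyperparameters here are illustrative, not prescriptive.

```python
import jax.numpy as jnp
import jaxopt

def loss(params, X, y):
    # mean squared error of a linear model
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
w_true = jnp.array([1.0, -1.0])
y = X @ w_true  # noiseless targets: a zero-loss optimum exists

solver = jaxopt.ArmijoSGD(fun=loss, maxiter=500, tol=1e-3)
params, state = solver.run(jnp.zeros(2), X, y)
```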
- Parameters
fun (Callable) –
value_and_grad (bool) –
has_aux (bool) –
aggressiveness (float) –
decrease_factor (float) –
increase_factor (float) –
reset_option (str) –
momentum (float) –
max_stepsize (float) –
pre_update (Optional[Callable]) –
maxiter (int) –
maxls (int) –
tol (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- fun
a function of the form fun(params, *args, **kwargs), where params are parameters of the model, and *args and **kwargs are additional arguments.
- Type
Callable
- value_and_grad
whether fun just returns the value (False) or both the value and gradient (True).
- Type
bool
- has_aux
whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, then we have one of the following two cases: if value_and_grad is False, the output should be value, aux = fun(...); if value_and_grad == True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux (see the sketch after this attribute list).
- Type
bool
- aggressiveness
controls the “aggressiveness” of the optimizer (default: 0.9). Bigger values encourage bigger stepsizes. Must belong to the open interval (0, 1). If aggressiveness > 0.5, the learning rate is guaranteed to be at least as big as min(1/L, max_stepsize), where L is the Lipschitz constant of the loss on the current batch.
- Type
float
- decrease_factor
factor by which to decrease the stepsize during line search (default: 0.8).
- Type
float
- increase_factor
factor by which to increase the stepsize during line search (default: 1.5).
- Type
float
- reset_option
strategy to use for resetting the stepsize at each iteration (default: “increase”).
“conservative”: re-use the previous stepsize, producing a non-increasing sequence of stepsizes. Slow convergence.
“increase”: attempt to re-use the previous stepsize multiplied by increase_factor. Cheap and efficient heuristic.
“goldstein”: re-use the previous stepsize and increase it until the curvature condition is fulfilled. Higher runtime cost than “increase” but better theoretical guarantees.
- Type
str
- momentum
momentum parameter, 0 corresponding to no momentum.
- Type
float
- max_stepsize
maximum stepsize to use (default: 1.0).
- Type
float
- pre_update
a function to execute before the solver’s update. The function signature must be params, state = pre_update(params, state, *args, **kwargs).
- Type
Optional[Callable]
- maxiter
maximum number of solver iterations.
- Type
int
- maxls
maximum number of steps in line search.
- Type
int
- tol
tolerance for the stopping criterion.
- Type
float
- verbose
whether to print the error at every iteration. verbose=True automatically disables jit.
- Type
int
- implicit_diff
whether to enable implicit differentiation (True) or autodiff of unrolled iterations (False).
- Type
bool
- implicit_diff_solve
the linear system solver to use.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
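As referenced from the has_aux entry above, here is a hedged sketch of a loss returning auxiliary outputs, with value_and_grad left at its default of False; the rmse metric is an illustrative choice of auxiliary data.

```python
import jax.numpy as jnp
import jaxopt

def loss_with_aux(params, X, y):
    value = jnp.mean((jnp.dot(X, params) - y) ** 2)
    aux = {"rmse": jnp.sqrt(value)}  # illustrative auxiliary metric
    return value, aux                # the value, aux = fun(...) form

X = jnp.ones((4, 2))
y = jnp.ones(4)

solver = jaxopt.ArmijoSGD(fun=loss_with_aux, has_aux=True)
params, state = solver.run(jnp.zeros(2), X, y)
# the auxiliary outputs of the last iteration are stored in state.aux
```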
References
Vaswani, S., Mishkin, A., Laradji, I., Schmidt, M., Gidel, G. and Lacoste-Julien, S., 2019. Painless stochastic gradient: Interpolation, line-search, and convergence rates. Advances in Neural Information Processing Systems 32.
Methods

__init__(fun[, value_and_grad, has_aux, ...])
attribute_names()
attribute_values()
init_state(init_params, *args, **kwargs) – Initialize the state.
l2_optimality_error(params, *args, **kwargs) – Computes the L2 optimality error.
optimality_fun(params, *args, **kwargs) – Optimality function mapping compatible with @custom_root.
reset_stepsize(stepsize) – Return new step size for the current step, according to reset_option.
run(init_params, *args, **kwargs) – Runs the optimization loop.
run_iterator(init_params, iterator, *args, ...) – Runs the optimization loop over an iterator.
update(params, state, *args, **kwargs) – Performs one iteration of the solver.
- init_state(init_params, *args, **kwargs)[source]
Initialize the state.
- Parameters
init_params – pytree containing the initial parameters.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
ArmijoState
- Returns
state
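A brief sketch, with illustrative loss and data, of creating the initial state explicitly; this is mainly useful when stepping the solver manually with update (see the sketch under update below).

```python
import jax.numpy as jnp
import jaxopt

def loss(params, X, y):
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

X = jnp.ones((4, 2))
y = jnp.ones(4)

solver = jaxopt.ArmijoSGD(fun=loss)
state = solver.init_state(jnp.zeros(2), X, y)  # an ArmijoState
```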
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- optimality_fun(params, *args, **kwargs)[source]
Optimality function mapping compatible with @custom_root.
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
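A sketch of run with illustrative data: the returned OptStep is a (params, state) namedtuple, so it can be unpacked directly, and l2_optimality_error (documented above) gives a simple convergence check.

```python
import jax.numpy as jnp
import jaxopt

def loss(params, X, y):
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

X = jnp.ones((4, 2))
y = jnp.ones(4)

solver = jaxopt.ArmijoSGD(fun=loss, maxiter=200)
opt_step = solver.run(jnp.zeros(2), X, y)
params, state = opt_step  # OptStep unpacks to (params, state)
error = solver.l2_optimality_error(params, X, y)  # L2 norm of the optimality conditions
```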
- run_iterator(init_params, iterator, *args, **kwargs)
Runs the optimization loop over an iterator.
- Parameters
init_params (Any) – pytree containing the initial parameters.
iterator – iterator generating data batches.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)
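A sketch of streaming batches through run_iterator. It assumes, in line with jaxopt’s stochastic solvers, that each batch yielded by the iterator is forwarded to fun as the keyword argument data; the loss and the generator are illustrative.

```python
import jax.numpy as jnp
import jaxopt

def loss(params, data):  # assumed signature: each batch arrives as `data`
    X, y = data
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

def batches():
    X = jnp.ones((8, 2))
    y = jnp.ones(8)
    while True:      # a real pipeline would yield fresh batches here
        yield (X, y)

solver = jaxopt.ArmijoSGD(fun=loss, maxiter=100)
params, state = solver.run_iterator(jnp.zeros(2), batches())
```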
- update(params, state, *args, **kwargs)[source]
Performs one iteration of the solver.
- Parameters
params – pytree containing the parameters.
state – named tuple containing the solver state.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)
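A sketch of the manual stepping pattern with update, paired with init_state from above; the data and the fixed iteration count are illustrative, and a real loop would typically monitor the solver’s error against tol.

```python
import jax.numpy as jnp
import jaxopt

def loss(params, X, y):
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

X = jnp.ones((4, 2))
y = jnp.ones(4)

solver = jaxopt.ArmijoSGD(fun=loss)
params = jnp.zeros(2)
state = solver.init_state(params, X, y)
for _ in range(100):
    # one Armijo line-search step per call
    params, state = solver.update(params, state, X, y)
```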