jaxopt.ArmijoSGD

class jaxopt.ArmijoSGD(fun, value_and_grad=False, has_aux=False, aggressiveness=0.9, decrease_factor=0.8, increase_factor=1.5, reset_option='increase', momentum=0.0, max_stepsize=1.0, pre_update=None, maxiter=500, maxls=15, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, jit='auto', unroll='auto')[source]

SGD with Armijo line search.

This implementation assumes that the “interpolation property” holds; see, for example, Vaswani et al. (2019) (https://arxiv.org/abs/1905.09997):

the global optimum over D must also be a global optimum for any finite sample of D

This is typically achieved by overparametrized models (e.g. neural networks) in classification tasks with separable classes, or in regression tasks without noise. In practice, this algorithm also works well outside this setting.
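
A minimal usage sketch (the quadratic objective and data below are illustrative, not part of the API):

  import jax.numpy as jnp
  import jaxopt

  # Any scalar-valued fun(params, *args, **kwargs) works; least squares here.
  def loss(params, X, y):
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2)

  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  y = jnp.array([1.0, 2.0, 3.0])

  solver = jaxopt.ArmijoSGD(fun=loss, maxiter=100, tol=1e-3)
  params, state = solver.run(jnp.zeros(2), X, y)  # run returns an OptStep, i.e. (params, state)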

Parameters
  • fun (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • aggressiveness (float) –

  • decrease_factor (float) –

  • increase_factor (float) –

  • reset_option (str) –

  • momentum (float) –

  • max_stepsize (float) –

  • pre_update (Optional[Callable]) –

  • maxiter (int) –

  • maxls (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fun

a function of the form fun(params, *args, **kwargs), where params are the parameters of the model and *args and **kwargs are additional arguments.

Type

Callable

value_and_grad

whether fun just returns the value (False) or both the value and gradient (True).

Type

bool
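
As a sketch, a fun suitable for value_and_grad=True can be obtained with jax.value_and_grad (the loss below is illustrative):

  import jax
  import jax.numpy as jnp
  import jaxopt

  def loss(params, X, y):
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

  # jax.value_and_grad turns the scalar loss into a function returning (value, grad).
  solver = jaxopt.ArmijoSGD(fun=jax.value_and_grad(loss), value_and_grad=True)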

has_aux

whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, then we have one of the following two cases. If value_and_grad is False, the output should be value, aux = fun(...). If value_and_grad == True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.

Type

bool
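
A sketch of the has_aux=True convention with value_and_grad=False (the auxiliary metric is made up for illustration):

  import jax.numpy as jnp
  import jaxopt

  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  y = jnp.array([1.0, 2.0, 3.0])

  def loss_with_aux(params, X, y):
    residuals = jnp.dot(X, params) - y
    value = jnp.mean(residuals ** 2)
    aux = {"max_abs_residual": jnp.max(jnp.abs(residuals))}  # auxiliary output
    return value, aux  # value, aux = fun(...) since value_and_grad is False

  solver = jaxopt.ArmijoSGD(fun=loss_with_aux, has_aux=True, maxiter=50)
  params, state = solver.run(jnp.zeros(2), X, y)
  # The auxiliary outputs of the last iteration are available in state.aux.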

aggressiveness

controls the aggressiveness of the optimizer (default: 0.9). Larger values encourage larger stepsizes. Must belong to the open interval (0, 1). If aggressiveness > 0.5, the stepsize found by the line search is guaranteed to be at least min(1/L, max_stepsize), where L is the Lipschitz constant of the loss on the current batch.

Type

float
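
For intuition, the line search accepts a stepsize when a sufficient-decrease condition in the spirit of Vaswani et al. (2019) holds. The sketch below is illustrative only; in particular, the mapping c = 1 - aggressiveness is an assumption chosen to be consistent with the aggressiveness > 0.5 guarantee above, not a statement about jaxopt's internals:

  import jax
  import jax.numpy as jnp

  def armijo_accepts(fun, params, value, grad, stepsize, aggressiveness, *args):
    # Accept stepsize if fun(params - stepsize * grad) <= value - c * stepsize * ||grad||^2.
    coef = 1.0 - aggressiveness  # assumed mapping; larger aggressiveness => weaker decrease requirement
    grad_sqnorm = sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grad))
    next_params = jax.tree_util.tree_map(lambda p, g: p - stepsize * g, params, grad)
    return fun(next_params, *args) <= value - coef * stepsize * grad_sqnorm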

decrease_factor

factor by which to decrease the stepsize during line search (default: 0.8).

Type

float

increase_factor

factor by which to increase the stepsize during line search (default: 1.5).

Type

float

reset_option

strategy to use for resetting the stepsize at each iteration (default: “increase”); see the illustrative sketch below.

  • “conservative”: re-use the previous stepsize, producing a non-increasing sequence of stepsizes. Slow convergence.

  • “increase”: attempt to re-use the previous stepsize multiplied by increase_factor. Cheap and efficient heuristic.

  • “goldstein”: re-use the previous stepsize and increase it until the curvature condition is fulfilled. Higher runtime cost than “increase” but better theoretical guarantees.

Type

str
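
The following pseudocode is an illustrative reading of the three strategies described above, not the library's implementation; in particular, the goldstein branch only hints at the extra curvature test performed during the line search:

  def reset_stepsize_sketch(prev_stepsize, reset_option, increase_factor, max_stepsize):
    if reset_option == "conservative":
      return prev_stepsize  # never grows: non-increasing sequence of stepsizes
    if reset_option == "increase":
      return min(prev_stepsize * increase_factor, max_stepsize)
    if reset_option == "goldstein":
      # Start from the previous stepsize; the line search then also increases it
      # until a curvature (Goldstein-type) condition holds, on top of the Armijo
      # decrease condition.
      return prev_stepsize
    raise ValueError(f"Unknown reset_option: {reset_option}")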

momentum

momentum parameter, 0 corresponding to no momentum.

Type

float

max_stepsize

a maximum step size to use (default: 1.0).

Type

float

pre_update

a function to execute before the solver’s update. The function signature must be params, state = pre_update(params, state, *args, **kwargs).

Type

Optional[Callable]
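
A hypothetical pre_update hook matching the signature above, here projecting every parameter leaf onto the non-negative orthant before each update (the projection is purely illustrative):

  import jax
  import jax.numpy as jnp
  import jaxopt

  def loss(params, X, y):
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

  def keep_nonnegative(params, state, *args, **kwargs):
    # Clip every parameter leaf to be non-negative before the solver's update.
    params = jax.tree_util.tree_map(lambda p: jnp.maximum(p, 0.0), params)
    return params, state

  solver = jaxopt.ArmijoSGD(fun=loss, pre_update=keep_nonnegative)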

maxiter

maximum number of solver iterations.

Type

int

maxls

maximum number of steps in line search.

Type

int

tol

tolerance of the stopping criterion.

Type

float

verbose

whether to print the error on every iteration or not. Setting verbose=True automatically disables jit.

Type

int

implicit_diff

whether to enable implicit diff or autodiff of unrolled iterations.

Type

bool

implicit_diff_solve

the linear system solver to use.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

References

Vaswani, S., Mishkin, A., Laradji, I., Schmidt, M., Gidel, G. and Lacoste-Julien, S., 2019. Painless stochastic gradient: Interpolation, line-search, and convergence rates. Advances in Neural Information Processing Systems 32.

__init__(fun, value_and_grad=False, has_aux=False, aggressiveness=0.9, decrease_factor=0.8, increase_factor=1.5, reset_option='increase', momentum=0.0, max_stepsize=1.0, pre_update=None, maxiter=500, maxls=15, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, jit='auto', unroll='auto')
Parameters
  • fun (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • aggressiveness (float) –

  • decrease_factor (float) –

  • increase_factor (float) –

  • reset_option (str) –

  • momentum (float) –

  • max_stepsize (float) –

  • pre_update (Optional[Callable]) –

  • maxiter (int) –

  • maxls (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(fun[, value_and_grad, has_aux, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

reset_stepsize(stepsize)

Return new step size for current step, according to reset_option.

run(init_params, *args, **kwargs)

Runs the optimization loop.

run_iterator(init_params, iterator, *args, ...)

Runs the optimization loop over an iterator.

update(params, state, *args, **kwargs)

Performs one iteration of the solver.

Attributes

aggressiveness

decrease_factor

has_aux

implicit_diff

implicit_diff_solve

increase_factor

jit

max_stepsize

maxiter

maxls

momentum

pre_update

reset_option

tol

unroll

value_and_grad

verbose

fun

init_state(init_params, *args, **kwargs)[source]

Initialize the state.

Parameters
  • init_params – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

ArmijoState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

reset_stepsize(stepsize)[source]

Return new step size for current step, according to reset_option.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

run_iterator(init_params, iterator, *args, **kwargs)

Runs the optimization loop over an iterator.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • iterator – iterator generating data batches.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)
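
A sketch of run_iterator with a plain Python generator of batches. How each batch reaches fun follows jaxopt's stochastic-solver convention; the assumption here is that it arrives as an argument named data, so adapt the signature of loss to your jaxopt version if needed:

  import numpy as np
  import jax.numpy as jnp
  import jaxopt

  def loss(params, data):  # `data` receives the current batch (assumed convention)
    X, y = data
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

  def batch_iterator(X, y, batch_size=32, seed=0):
    rng = np.random.default_rng(seed)
    while True:  # run_iterator stops after maxiter batches
      idx = rng.integers(0, X.shape[0], size=batch_size)
      yield X[idx], y[idx]

  X = np.random.default_rng(1).normal(size=(256, 5))
  y = X @ np.arange(1.0, 6.0)

  solver = jaxopt.ArmijoSGD(fun=loss, maxiter=200)
  params, state = solver.run_iterator(jnp.zeros(5), batch_iterator(X, y))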

update(params, state, *args, **kwargs)[source]

Performs one iteration of the solver.

Parameters
  • params – pytree containing the parameters.

  • state – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)
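
For finer control, init_state and update can be driven manually; a minimal sketch, assuming (as for other jaxopt solvers) that the state exposes an error field usable against tol:

  import jax.numpy as jnp
  import jaxopt

  def loss(params, X, y):
    return jnp.mean((jnp.dot(X, params) - y) ** 2)

  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  y = jnp.array([1.0, 2.0, 3.0])

  solver = jaxopt.ArmijoSGD(fun=loss)
  params = jnp.zeros(2)
  state = solver.init_state(params, X, y)

  for _ in range(100):
    params, state = solver.update(params, state, X, y)
    if state.error <= solver.tol:  # stop once the optimality error reaches tol
      break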