jaxopt.PolyakSGD

class jaxopt.PolyakSGD(fun, value_and_grad=False, has_aux=False, max_stepsize=1.0, delta=0.0, momentum=0.0, pre_update=None, maxiter=500, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, jit='auto', unroll='auto')[source]

SGD with Polyak step size.

This solver computes step sizes adaptively. If the step size computed at a given iteration is smaller than max_stepsize, it is accepted; otherwise, max_stepsize is used instead. This cap prevents the solver from taking overconfident steps, which is why max_stepsize is the most important hyperparameter.

This implementation assumes that the interpolation property holds:

the global optimum over D must also be a global optimum for any finite sample of D.

This is typically achieved by overparametrized models (e.g., neural networks) in classification tasks with separable classes, or in regression tasks without noise.
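As a rough sketch of the mechanics (not necessarily the library's exact implementation), the step size follows the stochastic Polyak rule from the references below: the current loss value divided by the squared gradient norm (plus delta), capped at max_stepsize. On a noiseless least-squares problem, where the interpolation property holds, the rule can be demonstrated with plain NumPy:

```python
import numpy as np

def polyak_step(params, value, grad, max_stepsize=1.0, delta=0.0):
    """One gradient step with a capped Polyak step size.

    stepsize = min(max_stepsize, value / (||grad||^2 + delta)),
    which implicitly assumes the optimal loss value is (close to) zero.
    """
    sq_norm = np.sum(grad ** 2)
    stepsize = min(max_stepsize, value / (sq_norm + delta))
    return params - stepsize * grad

# Noiseless least squares: the interpolation property holds (optimal loss is 0).
rng = np.random.default_rng(0)
A = rng.normal(size=(100, 5))
x_true = rng.normal(size=5)
b = A @ x_true

x = np.zeros(5)
for _ in range(500):
    residual = A @ x - b
    value = 0.5 * np.sum(residual ** 2)  # current loss
    grad = A.T @ residual                # its gradient
    x = polyak_step(x, value, grad, max_stepsize=1.0, delta=1e-10)

print(np.allclose(x, x_true, atol=1e-4))
```

Because the problem is noiseless and well conditioned, the adaptive step size stays below the cap and the iterates converge to the interpolating solution.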

Parameters
  • fun (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • max_stepsize (float) –

  • delta (float) –

  • momentum (float) –

  • pre_update (Optional[Callable]) –

  • maxiter (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fun

a function of the form fun(params, *args, **kwargs), where params are parameters of the model, and *args and **kwargs are additional arguments.

Type

Callable

has_aux

whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, then we have one of the following two cases. If value_and_grad is False, the output should be value, aux = fun(...). If value_and_grad == True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.

Type

bool
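For concreteness, here is a sketch of the two accepted shapes of fun when has_aux is True. The data argument and the contents of aux are illustrative, not prescribed by the API:

```python
import numpy as np

# Case 1: value_and_grad=False, has_aux=True -> value, aux = fun(...).
def loss_with_aux(params, data):
    X, y = data
    residual = X @ params - y
    value = 0.5 * np.mean(residual ** 2)
    aux = {"residual_norm": np.linalg.norm(residual)}  # later found in state.aux
    return value, aux

# Case 2: value_and_grad=True, has_aux=True -> (value, aux), grad = fun(...).
def loss_with_aux_and_grad(params, data):
    X, y = data
    residual = X @ params - y
    value = 0.5 * np.mean(residual ** 2)
    aux = {"residual_norm": np.linalg.norm(residual)}
    grad = X.T @ residual / len(y)  # gradient of the mean squared residual
    return (value, aux), grad
```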

max_stepsize

a maximum step size to use.

Type

float

delta

a value added to the denominator of the step-size computation, for numerical stability (default: 0).

Type

float

momentum

momentum parameter, 0 corresponding to no momentum.

Type

float

pre_update

a function to execute before the solver’s update. The function signature must be params, state = pre_update(params, state, *args, **kwargs).

Type

Optional[Callable]
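As an illustration, a hypothetical pre_update hook that projects the parameters onto a box before each solver update could look like this (the clipping bound is arbitrary; with pytree-structured params one would map the clip over the leaves, e.g. with jax.tree_util.tree_map):

```python
import numpy as np

def clip_pre_update(params, state, *args, **kwargs):
    # Project parameters onto the box [-1, 1] before the solver's update;
    # the solver state is passed through unchanged.
    params = np.clip(params, -1.0, 1.0)
    return params, state
```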

maxiter

maximum number of solver iterations.

Type

int

tol

tolerance of the stopping criterion.

Type

float

verbose

whether to print the error on every iteration. Note that verbose=True automatically disables jit.

Type

int

implicit_diff

whether to enable implicit diff or autodiff of unrolled iterations.

Type

bool

implicit_diff_solve

the linear system solver to use.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

References

Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan. “Training neural networks for and by interpolation”. International Conference on Machine Learning, 2020. https://arxiv.org/abs/1906.05661

Loizou, Nicolas and Vaswani, Sharan and Laradji, Issam Hadj and Lacoste-Julien, Simon. “Stochastic Polyak step-size for SGD: An adaptive learning rate for fast convergence”. International Conference on Artificial Intelligence and Statistics, 2021. https://arxiv.org/abs/2002.10542

__init__(fun, value_and_grad=False, has_aux=False, max_stepsize=1.0, delta=0.0, momentum=0.0, pre_update=None, maxiter=500, tol=0.001, verbose=0, implicit_diff=False, implicit_diff_solve=None, jit='auto', unroll='auto')

Return type

None

Methods

__init__(fun[, value_and_grad, has_aux, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

run_iterator(init_params, iterator, *args, ...)

Runs the optimization loop over an iterator.

update(params, state, *args, **kwargs)

Performs one iteration of the solver.

Attributes

delta

has_aux

implicit_diff

implicit_diff_solve

jit

max_stepsize

maxiter

momentum

pre_update

tol

unroll

value_and_grad

verbose

fun

init_state(init_params, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

PolyakSGDState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

run_iterator(init_params, iterator, *args, **kwargs)

Runs the optimization loop over an iterator.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • iterator – iterator generating data batches.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)

update(params, state, *args, **kwargs)[source]

Performs one iteration of the solver.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (PolyakSGDState) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)