jaxopt.GradientDescent

class jaxopt.GradientDescent(fun, prox=<function prox_none>, value_and_grad=False, has_aux=False, stepsize=0.0, maxiter=500, maxls=15, tol=0.001, acceleration=True, decrease_factor=0.5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')[source]

Gradient Descent solver.

Parameters
  • fun (Callable) –

  • prox (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • stepsize (Union[float, Callable]) –

  • maxiter (int) –

  • maxls (int) –

  • tol (float) –

  • acceleration (bool) –

  • decrease_factor (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fun

a smooth function of the form fun(parameters, *args, **kwargs), where parameters are the model parameters w.r.t. which we minimize the function and the rest are fixed auxiliary parameters.

Type

Callable
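
A minimal sketch of this calling convention (the least-squares objective, data, and shapes below are hypothetical, chosen only for illustration): fun takes the parameters first, and any extra positional or keyword arguments passed to run are forwarded to it.

    import jax.numpy as jnp
    import jaxopt

    # Hypothetical objective: w is the parameter pytree,
    # (X, y) are fixed auxiliary arguments forwarded by the solver.
    def least_squares(w, X, y):
        return jnp.mean((X @ w - y) ** 2)

    X, y = jnp.ones((4, 3)), jnp.zeros(4)
    solver = jaxopt.GradientDescent(fun=least_squares, maxiter=100)
    params, state = solver.run(jnp.zeros(3), X, y)  # X, y are passed through to fun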

value_and_grad

whether fun just returns the value (False) or both the value and gradient (True).

Type

bool
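
As a sketch (reusing the hypothetical least-squares objective from above), a function wrapped with jax.value_and_grad already returns both outputs, so the solver should be told not to differentiate it again:

    import jax
    import jax.numpy as jnp
    import jaxopt

    def least_squares(w, X, y):
        return jnp.mean((X @ w - y) ** 2)

    # fun now returns (value, grad), so set value_and_grad=True.
    solver = jaxopt.GradientDescent(fun=jax.value_and_grad(least_squares),
                                    value_and_grad=True)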

has_aux

whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, then we have one of the following two cases: if value_and_grad is False, the output should be value, aux = fun(...); if value_and_grad is True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.

Type

bool
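
A sketch of the has_aux=True, value_and_grad=False case (the objective and the contents of aux are hypothetical); the auxiliary pytree from the last evaluation is available in state.aux:

    import jax.numpy as jnp
    import jaxopt

    def loss_with_aux(w, X, y):
        residual = X @ w - y
        # Returns value, aux = fun(...)
        return jnp.mean(residual ** 2), {"residual_norm": jnp.linalg.norm(residual)}

    solver = jaxopt.GradientDescent(fun=loss_with_aux, has_aux=True)
    params, state = solver.run(jnp.zeros(3), jnp.ones((4, 3)), jnp.zeros(4))
    print(state.aux["residual_norm"])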

stepsize

a stepsize to use (if <= 0, use backtracking line search), or a callable specifying the positive stepsize to use at each iteration.

Type

Union[float, Callable]
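
For example (a sketch assuming the callable form receives the iteration number, and reusing a hypothetical objective):

    import jax.numpy as jnp
    import jaxopt

    fun = lambda w, X, y: jnp.mean((X @ w - y) ** 2)

    # Fixed positive stepsize: no line search is performed.
    solver_fixed = jaxopt.GradientDescent(fun=fun, stepsize=0.1)
    # Callable stepsize: queried at each iteration to produce a schedule.
    solver_sched = jaxopt.GradientDescent(fun=fun, stepsize=lambda it: 1.0 / (it + 1.0))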

maxiter

maximum number of proximal gradient descent iterations.

Type

int

maxls

maximum number of iterations to use in the line search.

Type

int

tol

tolerance of the stopping criterion.

Type

float

acceleration

whether to use acceleration (also known as FISTA) or not.

Type

bool

verbose

whether to print error on every iteration or not. Warning: verbose=True will automatically disable jit.

Type

int

implicit_diff

whether to enable implicit differentiation (True) or autodiff of unrolled iterations (False).

Type

bool
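
A sketch of what implicit differentiation enables (the ridge objective and hyperparameter name are hypothetical): the solution returned by run can itself be differentiated with respect to arguments of fun.

    import jax
    import jax.numpy as jnp
    import jaxopt

    def ridge_loss(w, lam, X, y):
        return jnp.mean((X @ w - y) ** 2) + lam * jnp.sum(w ** 2)

    X, y = jnp.ones((4, 3)), jnp.zeros(4)
    solver = jaxopt.GradientDescent(fun=ridge_loss, implicit_diff=True)

    def solution(lam):
        return solver.run(jnp.zeros(3), lam, X, y).params

    # Sensitivity of the argmin with respect to the hyperparameter lam.
    dsol_dlam = jax.jacobian(solution)(0.1)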

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

__init__(fun, prox=<function prox_none>, value_and_grad=False, has_aux=False, stepsize=0.0, maxiter=500, maxls=15, tol=0.001, acceleration=True, decrease_factor=0.5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Parameters
  • fun (Callable) –

  • prox (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • stepsize (Union[float, Callable]) –

  • maxiter (int) –

  • maxls (int) –

  • tol (float) –

  • acceleration (bool) –

  • decrease_factor (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(fun[, prox, value_and_grad, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

prox(x[, hyperparams, scaling])

Proximal operator for \(g(x) = 0\), i.e., the identity function.

run(init_params, *args, **kwargs)

Runs the optimization loop.

update(params, state, *args, **kwargs)

Performs one iteration of gradient descent.

Attributes

acceleration

decrease_factor

has_aux

implicit_diff

implicit_diff_solve

jit

maxiter

maxls

stepsize

tol

unroll

value_and_grad

verbose

init_state(init_params, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

ProxGradState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.
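
A usage sketch (hypothetical objective and data): this returns the Euclidean norm of the optimality mapping at params, which for plain gradient descent tracks the gradient norm, the kind of quantity the tolerance tol is compared against.

    import jax.numpy as jnp
    import jaxopt

    fun = lambda w, X, y: jnp.mean((X @ w - y) ** 2)
    X, y = jnp.ones((4, 3)), jnp.zeros(4)
    solver = jaxopt.GradientDescent(fun=fun)
    params, state = solver.run(jnp.zeros(3), X, y)
    err = solver.l2_optimality_error(params, X, y)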

optimality_fun(params, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

prox(x, hyperparams=None, scaling=1.0)

Proximal operator for \(g(x) = 0\), i.e., the identity function.

Since \(g(x) = 0\), the output is:

\[\underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2 = x\]
Parameters
  • x (Any) – input pytree.

  • hyperparams (Optional[Any]) – ignored.

  • scaling (float) – ignored.

Return type

Any

Returns

output pytree, with the same structure as x.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)
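
For instance (a sketch with hypothetical data), the returned OptStep is a (params, state) named tuple whose state carries diagnostics such as the iteration count and the last error:

    import jax.numpy as jnp
    import jaxopt

    fun = lambda w, X, y: jnp.mean((X @ w - y) ** 2)
    X, y = jnp.ones((4, 3)), jnp.zeros(4)
    solver = jaxopt.GradientDescent(fun=fun)

    opt = solver.run(jnp.zeros(3), X, y)
    # opt.params is the solution; opt.state holds fields such as iter_num and error.
    print(opt.params, opt.state.iter_num, opt.state.error)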

update(params, state, *args, **kwargs)[source]

Performs one iteration of gradient descent.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (NamedTuple) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)
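
A sketch of driving the solver manually with init_state and update instead of run (hypothetical objective and data); each call performs one gradient step and returns an updated (params, state) pair:

    import jax.numpy as jnp
    import jaxopt

    fun = lambda w, X, y: jnp.mean((X @ w - y) ** 2)
    X, y = jnp.ones((4, 3)), jnp.zeros(4)
    solver = jaxopt.GradientDescent(fun=fun, maxiter=100)

    params = jnp.zeros(3)
    state = solver.init_state(params, X, y)
    for _ in range(100):
        params, state = solver.update(params, state, X, y)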