jaxopt.ProximalGradient

class jaxopt.ProximalGradient(fun, prox=<function prox_none>, value_and_grad=False, has_aux=False, stepsize=0.0, maxiter=500, maxls=15, tol=0.001, acceleration=True, decrease_factor=0.5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')[source]

Proximal gradient solver.

This solver minimizes:

objective(params, hyperparams_prox, *args, **kwargs) =
  fun(params, *args, **kwargs) + non_smooth(params, hyperparams_prox)
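To make the split concrete, here is a minimal NumPy sketch of the composite objective this solver targets (independent of jaxopt itself; the least-squares loss, l1 penalty, soft-thresholding prox, and synthetic data are illustrative assumptions):

```python
import numpy as np

def fun(w, X, y):
    # Smooth part: least-squares loss.
    r = X @ w - y
    return 0.5 * np.dot(r, r)

def grad_fun(w, X, y):
    return X.T @ (X @ w - y)

def prox_l1(w, lam, scaling=1.0):
    # Proximity operator of scaling * lam * ||w||_1 (soft-thresholding);
    # this handles the non-smooth part without ever differentiating it.
    return np.sign(w) * np.maximum(np.abs(w) - scaling * lam, 0.0)

def proximal_gradient(w, X, y, lam, stepsize, maxiter=200):
    # Each iteration: gradient step on fun, then prox step on non_smooth.
    for _ in range(maxiter):
        w = prox_l1(w - stepsize * grad_fun(w, X, y), lam, scaling=stepsize)
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 10))
w_true = np.zeros(10)
w_true[:3] = 1.0
y = X @ w_true
stepsize = 1.0 / np.linalg.norm(X.T @ X, 2)  # 1/L, L = Lipschitz const. of grad_fun
w_hat = proximal_gradient(np.zeros(10), X, y, lam=0.1, stepsize=stepsize)
```

With a stepsize of 1/L, each iteration is guaranteed not to increase the composite objective.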
Parameters
  • fun (Callable) –

  • prox (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • stepsize (Union[float, Callable]) –

  • maxiter (int) –

  • maxls (int) –

  • tol (float) –

  • acceleration (bool) –

  • decrease_factor (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fun

a smooth function of the form fun(x, *args, **kwargs).

Type

Callable

prox

proximity operator associated with the function non_smooth. It should be of the form prox(params, hyperparams_prox, scaling=1.0). See jaxopt.prox for examples.

Type

Callable

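A custom operator following this calling convention might look like the following (NumPy sketch; the nonnegativity constraint is an illustrative choice, not part of jaxopt):

```python
import numpy as np

def prox_nonnegative(params, hyperparams=None, scaling=1.0):
    # Prox of the indicator function of {params >= 0}: projection onto
    # the nonnegative orthant. hyperparams and scaling are accepted to
    # match the expected signature but ignored, since projections are
    # invariant to positive scaling of the indicator.
    return np.maximum(params, 0.0)
```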

value_and_grad

whether fun just returns the value (False) or both the value and gradient (True).

Type

bool

has_aux

whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, there are two cases: if value_and_grad is False, the output should be value, aux = fun(...); if value_and_grad is True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.

Type

bool
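For instance, under the convention just described, the two return shapes look as follows (plain NumPy sketch; the loss and the contents of aux are illustrative):

```python
import numpy as np

def fun_with_aux(w, X, y):
    # has_aux=True, value_and_grad=False: return (value, aux).
    r = X @ w - y
    value = 0.5 * np.dot(r, r)
    aux = {"residual_norm": np.linalg.norm(r)}
    return value, aux

def fun_with_aux_and_grad(w, X, y):
    # has_aux=True, value_and_grad=True: return ((value, aux), grad).
    value, aux = fun_with_aux(w, X, y)
    grad = X.T @ (X @ w - y)
    return (value, aux), grad
```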

stepsize

a stepsize to use (if <= 0, use backtracking line search), or a callable specifying the positive stepsize to use at each iteration.

Type

Union[float, Callable]
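When a callable is supplied, it is invoked with the iteration number and must return a positive stepsize. A hypothetical inverse-decay schedule (the name and decay law are illustrative):

```python
def inverse_decay_stepsize(iter_num, init_stepsize=1.0):
    # Always positive, shrinking as 1/(k+1) with the iteration count k.
    return init_stepsize / (iter_num + 1)
```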

maxiter

maximum number of proximal gradient descent iterations.

Type

int

maxls

maximum number of iterations to use in the line search.

Type

int

tol

tolerance of the stopping criterion.

Type

float

acceleration

whether to use acceleration (also known as FISTA) or not.

Type

bool
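Acceleration follows the FISTA scheme of Beck and Teboulle (see References): each proximal step is taken from an extrapolated point, with momentum weights built from the sequence t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2. A NumPy sketch of the accelerated loop (illustrative, not jaxopt's internal code):

```python
import numpy as np

def fista(grad_fun, prox, x0, stepsize, maxiter=100):
    # Accelerated proximal gradient (FISTA): the prox-gradient step is
    # taken from an extrapolated point y rather than the current iterate.
    x_prev, y, t = x0, x0, 1.0
    for _ in range(maxiter):
        x = prox(y - stepsize * grad_fun(y), scaling=stepsize)
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
        y = x + ((t - 1.0) / t_next) * (x - x_prev)
        x_prev, t = x, t_next
    return x_prev
```

With acceleration the objective gap shrinks as O(1/k^2) instead of O(1/k) for the plain method.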

decrease_factor

factor by which to reduce the stepsize during line search.

Type

float

verbose

whether to print error on every iteration or not. Warning: verbose=True will automatically disable jit.

Type

int

implicit_diff

whether to use implicit differentiation (True) or autodiff of unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver to use.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

References

Beck, Amir, and Marc Teboulle. “A fast iterative shrinkage-thresholding algorithm for linear inverse problems.” SIAM Journal on Imaging Sciences (2009).

Nesterov, Yu. “Gradient methods for minimizing composite functions.” Mathematical Programming (2013).

__init__(fun, prox=<function prox_none>, value_and_grad=False, has_aux=False, stepsize=0.0, maxiter=500, maxls=15, tol=0.001, acceleration=True, decrease_factor=0.5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Parameters
  • fun (Callable) –

  • prox (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • stepsize (Union[float, Callable]) –

  • maxiter (int) –

  • maxls (int) –

  • tol (float) –

  • acceleration (bool) –

  • decrease_factor (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(fun[, prox, value_and_grad, ...])

attribute_names()

attribute_values()

init_state(init_params, hyperparams_prox, ...)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(sol, hyperparams_prox, *args, ...)

Optimality function mapping compatible with @custom_root.

prox(x[, hyperparams, scaling])

Proximal operator for \(g(x) = 0\), i.e., the identity function.

run(init_params, *args, **kwargs)

Runs the optimization loop.

update(params, state, hyperparams_prox, ...)

Performs one iteration of proximal gradient.

Attributes

acceleration

decrease_factor

has_aux

implicit_diff

implicit_diff_solve

jit

maxiter

maxls

stepsize

tol

unroll

value_and_grad

verbose

fun

init_state(init_params, hyperparams_prox, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • hyperparams_prox (Any) – pytree containing hyperparameters of prox.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

ProxGradState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(sol, hyperparams_prox, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

prox(x, hyperparams=None, scaling=1.0)

Proximal operator for \(g(x) = 0\), i.e., the identity function.

Since \(g(x) = 0\), the output is:

\[\underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2 = x\]
Parameters
  • x (Any) – input pytree.

  • hyperparams (Optional[Any]) – ignored.

  • scaling (float) – ignored.

Return type

Any

Returns

output pytree, with the same structure as x.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

update(params, state, hyperparams_prox, *args, **kwargs)[source]

Performs one iteration of proximal gradient.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (NamedTuple) – named tuple containing the solver state.

  • hyperparams_prox (Any) – pytree containing hyperparameters of prox.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)