jaxopt.GradientDescent
- class jaxopt.GradientDescent(fun, prox=<function prox_none>, value_and_grad=False, has_aux=False, stepsize=0.0, maxiter=500, maxls=15, tol=0.001, acceleration=True, decrease_factor=0.5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')[source]
Gradient Descent solver.
- Parameters
fun (Callable) –
prox (Callable) –
value_and_grad (bool) –
has_aux (bool) –
stepsize (Union[float, Callable]) –
maxiter (int) –
maxls (int) –
tol (float) –
acceleration (bool) –
decrease_factor (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- fun
a smooth function of the form fun(parameters, *args, **kwargs), where parameters are the model parameters w.r.t. which we minimize the function and the rest are fixed auxiliary parameters (see the usage sketch after this attribute list).
- Type
Callable
- value_and_grad
whether fun just returns the value (False) or both the value and gradient (True).
- Type
bool
- has_aux
whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, then we have one of the following two cases. If value_and_grad is False, the output should be value, aux = fun(...). If value_and_grad == True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.
- Type
bool
- stepsize
a stepsize to use (if <= 0, use backtracking line search), or a callable specifying the positive stepsize to use at each iteration.
- Type
Union[float, Callable]
- maxiter
maximum number of proximal gradient descent iterations.
- Type
int
- maxls
maximum number of iterations to use in the line search.
- Type
int
- tol
tolerance to use.
- Type
float
- acceleration
whether to use acceleration (also known as FISTA) or not.
- Type
bool
- verbose
whether to print error on every iteration or not. Warning: verbose=True will automatically disable jit.
- Type
int
- implicit_diff
whether to enable implicit diff or autodiff of unrolled iterations.
- Type
bool
- implicit_diff_solve
the linear system solver to use.
- Type
Optional[Callable]
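Example
The following is a minimal usage sketch, not part of the generated reference: the objective ridge_loss, the arrays X and y, and the regularization weight are illustrative placeholders. It shows the basic construct-and-run pattern and the has_aux=True variant described above.

import jax.numpy as jnp
from jaxopt import GradientDescent

def ridge_loss(params, X, y, lam):
    # Smooth objective: mean squared error plus an l2 penalty.
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + lam * jnp.sum(params ** 2)

X = jnp.ones((10, 3))
y = jnp.zeros(10)
init_params = jnp.zeros(3)

solver = GradientDescent(fun=ridge_loss, maxiter=500, tol=1e-3, acceleration=True)
params, state = solver.run(init_params, X, y, 0.1)

# With has_aux=True, fun returns (value, aux); the auxiliary output of the
# last iteration is then available as state.aux.
def loss_with_aux(params, X, y, lam):
    residuals = jnp.dot(X, params) - y
    value = jnp.mean(residuals ** 2) + lam * jnp.sum(params ** 2)
    return value, residuals

aux_solver = GradientDescent(fun=loss_with_aux, has_aux=True)
params, state = aux_solver.run(init_params, X, y, 0.1)
residuals = state.aux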
- __init__(fun, prox=<function prox_none>, value_and_grad=False, has_aux=False, stepsize=0.0, maxiter=500, maxls=15, tol=0.001, acceleration=True, decrease_factor=0.5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
- Parameters
fun (Callable) –
prox (Callable) –
value_and_grad (bool) –
has_aux (bool) –
stepsize (Union[float, Callable]) –
maxiter (int) –
maxls (int) –
tol (float) –
acceleration (bool) –
decrease_factor (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- Return type
None
Methods
- __init__(fun[, prox, value_and_grad, ...])
- attribute_names()
- attribute_values()
- init_state(init_params, *args, **kwargs): Initialize the solver state.
- l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
- optimality_fun(params, *args, **kwargs): Optimality function mapping compatible with @custom_root.
- prox([hyperparams, scaling]): Proximal operator for \(g(x) = 0\), i.e., the identity function.
- run(init_params, *args, **kwargs): Runs the optimization loop.
- update(params, state, *args, **kwargs): Performs one iteration of gradient descent.
Attributes
- decrease_factor
- jit
- unroll
- init_state(init_params, *args, **kwargs)[source]
Initialize the solver state.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
ProxGradState
- Returns
state
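A small sketch of inspecting the freshly initialized state (the quadratic objective is illustrative; iter_num and error are the ProxGradState fields assumed here):

import jax.numpy as jnp
from jaxopt import GradientDescent

def quad(params):
    return jnp.sum(params ** 2)

solver = GradientDescent(fun=quad, maxiter=100)
state = solver.init_state(jnp.ones(3))
# iter_num starts at 0 and error at its initial value (field names assumed).
print(state.iter_num, state.error)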
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- optimality_fun(params, *args, **kwargs)[source]
Optimality function mapping compatible with @custom_root.
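The sketch below (with an illustrative quadratic objective) shows how this method and l2_optimality_error relate: optimality_fun returns the residual that is zero at a solution, and l2_optimality_error reports its l2 norm.

import jax.numpy as jnp
from jaxopt import GradientDescent

def quad(params):
    return jnp.sum((params - 1.0) ** 2)

solver = GradientDescent(fun=quad, maxiter=200, tol=1e-4)
params, state = solver.run(jnp.zeros(3))

residual = solver.optimality_fun(params)   # pytree, near zero at a solution
err = solver.l2_optimality_error(params)   # scalar l2 norm of that residual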
- prox(x, hyperparams=None, scaling=1.0)
Proximal operator for \(g(x) = 0\), i.e., the identity function.
Since \(g(x) = 0\), the output is:
\[\underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2 = x\]
- Parameters
x (Any) – input pytree.
hyperparams (Optional[Any]) – ignored.
scaling (float) – ignored.
- Return type
Any
- Returns
output pytree, with the same structure as x.
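Since this prox is the identity, a quick sketch (with an illustrative pytree) simply returns its input unchanged:

import jax.numpy as jnp
from jaxopt import GradientDescent

solver = GradientDescent(fun=lambda p: jnp.sum(p ** 2))
x = {"w": jnp.arange(3.0), "b": jnp.array(0.5)}
y = solver.prox(x)   # identity: same structure and values as x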
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
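run returns an OptStep, a (params, state) named tuple, so both unpacking styles in the sketch below work; the objective and arrays are illustrative.

import jax.numpy as jnp
from jaxopt import GradientDescent

def dist(params, target):
    return jnp.sum((params - target) ** 2)

solver = GradientDescent(fun=dist)
opt = solver.run(jnp.zeros(3), jnp.ones(3))   # extra args are forwarded to fun
params, state = opt                           # tuple-style unpacking
params, state = opt.params, opt.state         # or by field name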
- update(params, state, *args, **kwargs)[source]
Performs one iteration of gradient descent.
- Parameters
params (Any) – pytree containing the parameters.
state (NamedTuple) – named tuple containing the solver state.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)
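For finer-grained control, init_state and update can be combined into a manual loop that mirrors what run does internally (minus the tol-based stopping test); the objective and iteration count below are illustrative.

import jax.numpy as jnp
from jaxopt import GradientDescent

def dist(params, target):
    return jnp.sum((params - target) ** 2)

solver = GradientDescent(fun=dist, maxiter=100)
target = jnp.ones(3)

params = jnp.zeros(3)
state = solver.init_state(params, target)
for _ in range(100):
    # each call performs one (possibly accelerated) gradient step
    params, state = solver.update(params, state, target)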