jaxopt.ProjectedGradient
- class jaxopt.ProjectedGradient(fun, projection, value_and_grad=False, has_aux=False, stepsize=0.0, maxiter=500, maxls=15, tol=0.001, acceleration=True, decrease_factor=0.5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Projected gradient solver.
This solver is a convenience wrapper around jaxopt.ProximalGradient in which the proximal operator is the projection onto the constraint set.
- Parameters
fun (Callable) –
projection (Callable) –
value_and_grad (bool) –
has_aux (bool) –
stepsize (Union[float, Callable]) –
maxiter (int) –
maxls (int) –
tol (float) –
acceleration (bool) –
decrease_factor (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- fun
a smooth function of the form fun(parameters, *args, **kwargs), where parameters are the model parameters with respect to which we minimize the function and the rest are fixed auxiliary parameters.
- Type
Callable
- projection
projection operator associated with the constraints. It should be of the form projection(params, hyperparams_proj). See jaxopt.projection for examples.
- Type
Callable
- value_and_grad
whether fun returns just the value (False) or both the value and the gradient (True).
- Type
bool
- has_aux
whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued.
If has_aux is True, there are two cases. If value_and_grad is False, the output should be value, aux = fun(...). If value_and_grad is True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.
- Type
bool
- stepsize
a stepsize to use (if <= 0, use backtracking line search), or a callable specifying the positive stepsize to use at each iteration.
- Type
Union[float, Callable]
- maxiter
maximum number of projected gradient descent iterations.
- Type
int
- maxls
maximum number of iterations to use in the line search.
- Type
int
- tol
tolerance of the stopping criterion: iterations stop once the error reported in state.error is at most tol.
- Type
float
- acceleration
whether to use acceleration (also known as FISTA) or not.
- Type
bool
- verbose
whether to print the error at every iteration. Warning: verbose=True automatically disables jit.
- Type
int
- implicit_diff
whether to differentiate the solution by implicit differentiation (True) or by autodiff through the unrolled iterations (False).
- Type
bool
- implicit_diff_solve
the linear system solver used for implicit differentiation.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
- __init__(fun, projection, value_and_grad=False, has_aux=False, stepsize=0.0, maxiter=500, maxls=15, tol=0.001, acceleration=True, decrease_factor=0.5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
- Parameters
fun (Callable) –
projection (Callable) –
value_and_grad (bool) –
has_aux (bool) –
stepsize (Union[float, Callable]) –
maxiter (int) –
maxls (int) –
tol (float) –
acceleration (bool) –
decrease_factor (float) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- Return type
None
Methods
__init__(fun, projection[, value_and_grad, ...])
attribute_names()
attribute_values()
init_state(init_params[, hyperparams_proj]) – Initialize the parameters and state.
l2_optimality_error(params, *args, **kwargs) – Computes the L2 optimality error.
optimality_fun(sol, hyperparams_proj, *args, ...) – Optimality function mapping compatible with @custom_root.
run(init_params[, hyperparams_proj]) – Runs the optimization loop.
update(params, state[, hyperparams_proj]) – Performs one iteration of projected gradient.
Attributes
decrease_factor
- init_state(init_params, hyperparams_proj=None, *args, **kwargs)
Initialize the parameters and state.
- Parameters
init_params (Any) – pytree containing the initial parameters.
hyperparams_proj (Optional[Any]) – pytree containing hyperparameters of projection.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
ProxGradState
- Returns
state
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- optimality_fun(sol, hyperparams_proj, *args, **kwargs)
Optimality function mapping compatible with @custom_root.
- run(init_params, hyperparams_proj=None, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
hyperparams_proj (Optional[Any]) – pytree containing hyperparameters of projection.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
- update(params, state, hyperparams_proj=None, *args, **kwargs)
Performs one iteration of projected gradient.
- Parameters
params (Any) – pytree containing the parameters.
state (NamedTuple) – named tuple containing the solver state.
hyperparams_proj (Optional[Any]) – pytree containing hyperparameters of projection.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)