jaxopt.LBFGS
- class jaxopt.LBFGS(fun, value_and_grad=False, has_aux=False, maxiter=500, tol=0.001, stepsize=0.0, linesearch='zoom', linesearch_init='increase', stop_if_linesearch_fails=False, condition=None, maxls=15, decrease_factor=None, increase_factor=1.5, max_stepsize=1.0, min_stepsize=1e-06, history_size=10, use_gamma=True, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=False)[source]
LBFGS solver.
Supports complex variables, see second reference.
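Example (a minimal sketch; the least-squares objective and toy data are illustrative, not part of the API):
>>> import jax.numpy as jnp
>>> from jaxopt import LBFGS
>>>
>>> def fun(w, X, y):  # any smooth scalar-valued objective works here
...     return jnp.mean((X @ w - y) ** 2)
>>>
>>> X = jnp.ones((5, 3)); y = jnp.ones(5)  # toy data
>>> solver = LBFGS(fun=fun, maxiter=100, tol=1e-5)
>>> params, state = solver.run(jnp.zeros(3), X, y)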
- Parameters
fun (Callable) –
value_and_grad (bool) –
has_aux (bool) –
maxiter (int) –
tol (float) –
stepsize (Union[float, Callable]) –
linesearch (str) –
linesearch_init (str) –
stop_if_linesearch_fails (bool) –
condition (Any) –
maxls (int) –
decrease_factor (Any) –
increase_factor (float) –
max_stepsize (float) –
min_stepsize (float) –
history_size (int) –
use_gamma (bool) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
verbose (bool) –
- fun
a smooth function of the form fun(x, *args, **kwargs).
- Type
Callable
- value_and_grad
whether fun just returns the value (False) or both the value and gradient (True).
- Type
bool
- has_aux
whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, there are two cases. If value_and_grad is False, the output should be value, aux = fun(...). If value_and_grad == True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux. A sketch of both conventions follows this attribute list.
- Type
bool
- maxiter
maximum number of LBFGS iterations.
- Type
int
- tol
tolerance of the stopping criterion.
- Type
float
- stepsize
a stepsize to use (if <= 0, the line search specified by linesearch is used), or a callable specifying the positive stepsize to use at each iteration.
- Type
Union[float, Callable]
- linesearch
the type of line search to use: “backtracking” for backtracking line search, “zoom” for zoom line search, or “hager-zhang” for Hager-Zhang line search.
- Type
str
- linesearch_init
strategy for line-search initialization. By default, it will use “increase”, which will increase the step-size by a factor of increase_factor at each iteration if the step-size is larger than min_stepsize, and set it to max_stepsize otherwise. Other choices are “max”, that initializes the step-size to max_stepsize at every iteration, and “current”, that uses the step-size from the previous iteration.
- Type
str
- stop_if_linesearch_fails
whether to stop iterations if the line search fails. When True, this matches the behavior of core JAX.
- Type
bool
- condition
Deprecated. Condition used to select the stepsize when using backtracking linesearch.
- Type
Any
- maxls
maximum number of iterations to use in the line search.
- Type
int
- decrease_factor
Deprecated. Factor by which to decrease the stepsize during line search when using backtracking linesearch (default: 0.8).
- Type
Any
- increase_factor
factor by which to increase the stepsize during line search (default: 1.5).
- Type
float
- max_stepsize
upper bound on stepsize.
- Type
float
- min_stepsize
lower bound on stepsize.
- Type
float
- history_size
size of the memory to use.
- Type
int
- use_gamma
whether to initialize the inverse Hessian approximation with gamma * I, where gamma = s^T y / (y^T y) is computed from the most recent correction pair, following equation (7.20) of ‘Numerical Optimization’ (reference below). If use_gamma is set to False, the identity is used as initialization.
- Type
bool
- implicit_diff
whether to enable implicit diff or autodiff of unrolled iterations.
- Type
bool
- implicit_diff_solve
the linear system solver to use.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
- verbose
whether to print error on every iteration or not. Warning: verbose=True will automatically disable jit.
- Type
bool
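A sketch of the has_aux and value_and_grad conventions described above (objective and data are illustrative):
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxopt import LBFGS
>>>
>>> X = jnp.ones((5, 3)); y = jnp.ones(5)  # toy data
>>>
>>> # has_aux=True with value_and_grad=False: fun returns (value, aux).
>>> def loss_with_aux(w, X, y):
...     residuals = X @ w - y
...     return jnp.mean(residuals ** 2), {'res_norm': jnp.linalg.norm(residuals)}
>>> solver = LBFGS(fun=loss_with_aux, has_aux=True)
>>> params, state = solver.run(jnp.zeros(3), X, y)
>>> aux = state.aux  # auxiliary output of the last iteration
>>>
>>> # value_and_grad=True: fun returns both the value and the gradient.
>>> vg = jax.value_and_grad(lambda w, X, y: jnp.mean((X @ w - y) ** 2))
>>> solver = LBFGS(fun=vg, value_and_grad=True)
>>> params, state = solver.run(jnp.zeros(3), X, y)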
References
Jorge Nocedal and Stephen Wright. Numerical Optimization, second edition. Algorithm 7.5 (page 179).
Laurent Sorber, Marc van Barel, and Lieven de Lathauwer. Unconstrained Optimization of Real Functions in Complex Variables. SIAM J. Optim., Vol. 22, No. 3, pp. 879-898.
- __init__(fun, value_and_grad=False, has_aux=False, maxiter=500, tol=0.001, stepsize=0.0, linesearch='zoom', linesearch_init='increase', stop_if_linesearch_fails=False, condition=None, maxls=15, decrease_factor=None, increase_factor=1.5, max_stepsize=1.0, min_stepsize=1e-06, history_size=10, use_gamma=True, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=False)
- Parameters
fun (Callable) –
value_and_grad (bool) –
has_aux (bool) –
maxiter (int) –
tol (float) –
stepsize (Union[float, Callable]) –
linesearch (str) –
linesearch_init (str) –
stop_if_linesearch_fails (bool) –
condition (Optional[Any]) –
maxls (int) –
decrease_factor (Optional[Any]) –
increase_factor (float) –
max_stepsize (float) –
min_stepsize (float) –
history_size (int) –
use_gamma (bool) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
verbose (bool) –
- Return type
None
Methods
__init__(fun[, value_and_grad, has_aux, ...])
attribute_names()
attribute_values()
init_state(init_params, *args, **kwargs): Initialize the solver state.
l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
optimality_fun(params, *args, **kwargs): Optimality function mapping compatible with @custom_root.
run(init_params, *args, **kwargs): Runs the optimization loop.
update(params, state, *args, **kwargs): Performs one iteration of LBFGS.
Attributes
- init_state(init_params, *args, **kwargs)[source]
Initialize the solver state.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
LbfgsState
- Returns
state
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
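For example, continuing the least-squares sketch above (where the optimality mapping is the gradient of fun):
>>> solver.l2_optimality_error(params, X, y)  # L2 norm of the gradient at params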
- optimality_fun(params, *args, **kwargs)[source]
Optimality function mapping compatible with @custom_root.
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
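Since implicit_diff is enabled by default, run can be differentiated with respect to the extra arguments passed to fun. A sketch (the ridge objective is illustrative):
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxopt import LBFGS
>>>
>>> X = jnp.ones((5, 3)); y = jnp.ones(5)  # toy data
>>> def ridge(w, lam, X, y):
...     return jnp.mean((X @ w - y) ** 2) + lam * jnp.sum(w ** 2)
>>> solver = LBFGS(fun=ridge)
>>> solution = lambda lam: solver.run(jnp.zeros(3), lam, X, y).params
>>> jax.jacobian(solution)(0.1)  # sensitivity of the solution to lam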
- update(params, state, *args, **kwargs)[source]
Performs one iteration of LBFGS.
- Parameters
params (Any) – pytree containing the parameters.
state (LbfgsState) – named tuple containing the solver state.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)
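init_state and update allow stepping the solver manually instead of calling run, which performs the same loop together with the stopping test. A sketch (objective, data, and step count are illustrative):
>>> import jax.numpy as jnp
>>> from jaxopt import LBFGS
>>>
>>> X = jnp.ones((5, 3)); y = jnp.ones(5)  # toy data
>>> solver = LBFGS(fun=lambda w, X, y: jnp.mean((X @ w - y) ** 2))
>>> params = jnp.zeros(3)
>>> state = solver.init_state(params, X, y)
>>> for _ in range(10):
...     params, state = solver.update(params, state, X, y)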