jaxopt.LBFGSB
- class jaxopt.LBFGSB(fun, value_and_grad=False, has_aux=False, maxiter=50, tol=0.001, stepsize=0.0, linesearch='zoom', linesearch_init='increase', stop_if_linesearch_fails=False, condition=None, maxls=20, decrease_factor=None, increase_factor=1.5, max_stepsize=1.0, min_stepsize=1e-06, theta=1.0, history_size=10, use_gamma=True, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=False)[source]
L-BFGS-B solver.
L-BFGS-B is a version of L-BFGS that incorporates box constraints on variables.
- Parameters
fun (Callable) –
value_and_grad (Union[bool, Callable]) –
has_aux (bool) –
maxiter (int) –
tol (float) –
stepsize (Union[float, Callable[[Any], float]]) –
linesearch (str) –
linesearch_init (str) –
stop_if_linesearch_fails (bool) –
condition (Any) –
maxls (int) –
decrease_factor (Any) –
increase_factor (float) –
max_stepsize (float) –
min_stepsize (float) –
theta (float) –
history_size (int) –
use_gamma (bool) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable[[Any], Any]]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
verbose (bool) –
- fun
a smooth function of the form fun(x, *args, **kwargs).
- Type
Callable
- value_and_grad
whether fun just returns the value (False) or both the value and gradient (True). See base.make_funs_with_aux for details.
- Type
Union[bool, Callable]
- has_aux
whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, then we have one of the following two cases. If value_and_grad is False, the output should be value, aux = fun(...). If value_and_grad == True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.
- Type
bool
- maxiter
maximum number of solver iterations.
- Type
int
- tol
tolerance of the stopping criterion.
- Type
float
- stepsize
a stepsize to use (if <= 0, use backtracking line search), or a callable specifying the positive stepsize to use at each iteration.
- Type
Union[float, Callable[[Any], float]]
- linesearch_init
strategy for line-search initialization. By default, it will use “increase”, which will increase the stepsize by a factor of increase_factor at each iteration if the stepsize is larger than min_stepsize, and set it to max_stepsize otherwise. Other choices are “max”, which initializes the stepsize to max_stepsize at every iteration, and “current”, which uses the stepsize from the previous iteration.
- Type
str
- stop_if_linesearch_fails
whether to stop iterations if the line search fails. When True, this matches the behavior of core JAX.
- Type
bool
- condition
Deprecated. Condition used to select the stepsize when using backtracking line search.
- Type
Any
- maxls
maximum number of iterations to use in the line search.
- Type
int
- decrease_factor
Deprecated. Factor by which to decrease the stepsize during backtracking line search (default: 0.8).
- Type
Any
- increase_factor
factor by which to increase the stepsize during line search (default: 1.5).
- Type
float
- max_stepsize
upper bound on stepsize.
- Type
float
- min_stepsize
lower bound on stepsize.
- Type
float
- history_size
size of the memory to use.
- Type
int
- use_gamma
whether to initialize the Hessian approximation with gamma * theta, where gamma is chosen following equation (7.20) of ‘Numerical Optimization’ [2]. If use_gamma is set to False, theta is used as initialization.
- Type
bool
- implicit_diff
whether to enable implicit diff or autodiff of unrolled iterations.
- Type
bool
- implicit_diff_solve
the linear system solver to use.
- Type
Optional[Callable[[Any], Any]]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
- verbose
whether to print error on every iteration or not. Warning: verbose=True will automatically disable jit.
- Type
bool
- __init__(fun, value_and_grad=False, has_aux=False, maxiter=50, tol=0.001, stepsize=0.0, linesearch='zoom', linesearch_init='increase', stop_if_linesearch_fails=False, condition=None, maxls=20, decrease_factor=None, increase_factor=1.5, max_stepsize=1.0, min_stepsize=1e-06, theta=1.0, history_size=10, use_gamma=True, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=False)
- Parameters
fun (Callable) –
value_and_grad (Union[bool, Callable]) –
has_aux (bool) –
maxiter (int) –
tol (float) –
stepsize (Union[float, Callable[[Any], float]]) –
linesearch (str) –
linesearch_init (str) –
stop_if_linesearch_fails (bool) –
condition (Optional[Any]) –
maxls (int) –
decrease_factor (Optional[Any]) –
increase_factor (float) –
max_stepsize (float) –
min_stepsize (float) –
theta (float) –
history_size (int) –
use_gamma (bool) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable[[Any], Any]]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
verbose (bool) –
- Return type
None
Methods
__init__(fun[, value_and_grad, has_aux, ...])
attribute_names()
attribute_values()
init_state(init_params, bounds, *args, **kwargs): Initialize the solver state.
l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
optimality_fun(sol, bounds, *args, **kwargs): Optimality function mapping compatible with @custom_root.
run(init_params, *args, **kwargs): Runs the optimization loop.
update(params, state, bounds, *args, **kwargs): Performs one iteration of L-BFGS-B.
Attributes
linesearch
theta
- init_state(init_params, bounds, *args, **kwargs)[source]
Initialize the solver state.
- Parameters
init_params (Any) – pytree containing the initial parameters.
bounds (Optional[Any]) – an optional tuple (lb, ub) of pytrees with structure identical to init_params, representing box constraints.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
LbfgsbState
- Returns
state
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- optimality_fun(sol, bounds, *args, **kwargs)[source]
Optimality function mapping compatible with @custom_root.
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
- update(params, state, bounds, *args, **kwargs)[source]
Performs one iteration of L-BFGS-B.
- Parameters
params (Any) – pytree containing the parameters.
state (LbfgsbState) – named tuple containing the solver state.
bounds (Optional[Any]) – an optional tuple (lb, ub) of pytrees with structure identical to init_params, representing box constraints.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)