jaxopt.LBFGS

class jaxopt.LBFGS(fun, value_and_grad=False, has_aux=False, maxiter=500, tol=0.001, stepsize=0.0, linesearch='zoom', linesearch_init='increase', stop_if_linesearch_fails=False, condition=None, maxls=15, decrease_factor=None, increase_factor=1.5, max_stepsize=1.0, min_stepsize=1e-06, history_size=10, use_gamma=True, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=False)

LBFGS solver.

Supports complex variables; see the second reference below.
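A minimal usage sketch; the least-squares objective and data here are made up for illustration:

    import jax.numpy as jnp
    import jaxopt

    # Made-up data for an illustrative least-squares problem.
    A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    b = jnp.array([1.0, 0.0, -1.0])

    def fun(x):
        # Smooth, scalar-valued objective of the form fun(x, *args, **kwargs).
        return 0.5 * jnp.sum((A @ x - b) ** 2)

    solver = jaxopt.LBFGS(fun=fun, maxiter=100, tol=1e-5)
    params, state = solver.run(jnp.zeros(2))
    print(params, state.error)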

Parameters

fun (Callable)

a smooth function of the form fun(x, *args, **kwargs).

value_and_grad (bool)

whether fun just returns the value (False) or both the value and gradient (True).

has_aux (bool)

whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to be scalar-valued. If has_aux is True, there are two cases: if value_and_grad is False, the output should be value, aux = fun(...); if value_and_grad is True, the output should be (value, aux), grad = fun(...). At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.
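As a sketch of the two return conventions above (the objectives are arbitrary placeholders):

    import jax
    import jax.numpy as jnp
    import jaxopt

    # has_aux=True with value_and_grad=False: fun returns (value, aux).
    def fun_aux(x):
        value = jnp.sum(x ** 2)
        return value, {"value": value}

    solver_aux = jaxopt.LBFGS(fun=fun_aux, has_aux=True)
    params, state = solver_aux.run(jnp.ones(3))
    print(state.aux)  # auxiliary output of the last evaluation

    # value_and_grad=True with has_aux=False: fun returns (value, grad).
    fun_vg = jax.value_and_grad(lambda x: jnp.sum(x ** 2))
    solver_vg = jaxopt.LBFGS(fun=fun_vg, value_and_grad=True)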

maxiter (int)

maximum number of LBFGS iterations.

tol (float)

tolerance of the stopping criterion.

stepsize (Union[float, Callable])

a fixed stepsize to use (if <= 0, the line search specified by linesearch is used instead), or a callable specifying the positive stepsize to use at each iteration.
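For instance, a decaying schedule can be passed as a callable. This is a sketch; the schedule itself is made up, and it assumes the callable is evaluated at the current iteration number, as with other jaxopt solvers:

    import jax.numpy as jnp
    import jaxopt

    def loss(x):
        return jnp.sum(x ** 2)

    # Made-up schedule: positive stepsize as a function of the iteration number.
    def schedule(iter_num):
        return 0.5 / (1.0 + iter_num)

    # Passing a callable stepsize disables the line search entirely.
    solver = jaxopt.LBFGS(fun=loss, stepsize=schedule)
    params, state = solver.run(jnp.ones(4))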

linesearch (str)

the type of line search to use: “backtracking” for backtracking line search, “zoom” for zoom line search, or “hager-zhang” for Hager-Zhang line search.

linesearch_init (str)

strategy for line-search initialization. The default, “increase”, multiplies the previous stepsize by increase_factor if it is larger than min_stepsize, and resets it to max_stepsize otherwise. “max” initializes the stepsize to max_stepsize at every iteration, and “current” reuses the stepsize from the previous iteration.

stop_if_linesearch_fails (bool)

whether to stop iterations if the line search fails. When True, this matches the behavior of core JAX.

condition (Any)

Deprecated. Condition used to select the stepsize when using backtracking line search.

maxls (int)

maximum number of iterations to use in the line search.

decrease_factor (Any)

Deprecated. Factor by which to decrease the stepsize during backtracking line search (default: 0.8).

increase_factor (float)

factor by which to increase the stepsize during line search (default: 1.5).

max_stepsize (float)

upper bound on stepsize.

min_stepsize (float)

lower bound on stepsize.

history_size (int)

size of the memory to use.

use_gamma (bool)

whether to initialize the inverse Hessian approximation with gamma * I, where gamma is chosen following equation (7.20) of ‘Numerical Optimization’ (see the references below). If use_gamma is set to False, the identity is used as initialization.
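For reference, equation (7.20) of Nocedal and Wright defines this scaling as

    \gamma_k = \frac{s_{k-1}^\top y_{k-1}}{y_{k-1}^\top y_{k-1}}

where s_{k-1} = x_k - x_{k-1} and y_{k-1} = \nabla f_k - \nabla f_{k-1}; the initial inverse Hessian approximation at iteration k is then H_k^0 = \gamma_k I.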

implicit_diff (bool)

whether to differentiate the solver via implicit differentiation (True) or via autodiff of the unrolled iterations (False).

implicit_diff_solve (Optional[Callable])

the linear system solver to use for implicit differentiation.
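With implicit_diff=True, the solution returned by run can be differentiated with respect to arguments passed to fun. A minimal sketch; the ridge objective and data are made up:

    import jax
    import jax.numpy as jnp
    import jaxopt

    # Made-up data for an illustrative ridge-regression problem.
    X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
    y = jnp.array([1.0, -1.0])

    def ridge_loss(w, lam):
        return 0.5 * jnp.sum((X @ w - y) ** 2) + 0.5 * lam * jnp.sum(w ** 2)

    solver = jaxopt.LBFGS(fun=ridge_loss, implicit_diff=True)

    def solution(lam):
        # Extra positional arguments to run are forwarded to fun.
        return solver.run(jnp.zeros(2), lam).params

    # Jacobian of the solution w.r.t. the regularization strength, obtained
    # by implicit differentiation of the optimality conditions.
    print(jax.jacobian(solution)(0.1))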

jit (Union[str, bool])

whether to JIT-compile the optimization loop (default: “auto”).

unroll (Union[str, bool])

whether to unroll the optimization loop (default: “auto”).

verbose (bool)

whether to print the error at every iteration. Warning: verbose=True automatically disables jit.

References

Jorge Nocedal and Stephen Wright. Numerical Optimization, second edition. Algorithm 7.5 (page 179).

Laurent Sorber, Marc Van Barel, and Lieven De Lathauwer. Unconstrained Optimization of Real Functions in Complex Variables. SIAM J. Optim., Vol. 22, No. 3, pp. 879-898.


Methods

__init__(fun[, value_and_grad, has_aux, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)
    Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)
    Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)
    Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)
    Runs the optimization loop.

update(params, state, *args, **kwargs)
    Performs one iteration of LBFGS.

Attributes

condition, decrease_factor, fun, has_aux, history_size, implicit_diff, implicit_diff_solve, increase_factor, jit, linesearch, linesearch_init, max_stepsize, maxiter, maxls, min_stepsize, stepsize, stop_if_linesearch_fails, tol, unroll, use_gamma, value_and_grad, verbose

init_state(init_params, *args, **kwargs)

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type: LbfgsState

Returns: state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type: OptStep

Returns: (params, state)

update(params, state, *args, **kwargs)

Performs one iteration of LBFGS.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (LbfgsState) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type: OptStep

Returns: (params, state)
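run is equivalent to iterating update from init_state; a hand-rolled sketch of that loop (the objective and stopping logic are simplified for illustration):

    import jax.numpy as jnp
    import jaxopt

    def loss(x):
        return jnp.sum((x - 1.0) ** 2)

    solver = jaxopt.LBFGS(fun=loss, tol=1e-6)
    params = jnp.zeros(3)
    state = solver.init_state(params)

    # Python-level loop for illustration; run() uses a (jittable) while-loop.
    for _ in range(solver.maxiter):
        params, state = solver.update(params, state)
        if state.error <= solver.tol:
            break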