jaxopt.Broyden

class jaxopt.Broyden(fun, has_aux=False, maxiter=500, tol=0.001, stepsize=0.0, linesearch='backtracking', stop_if_linesearch_fails=False, condition='wolfe', maxls=15, decrease_factor=0.8, increase_factor=1.5, max_stepsize=1.0, min_stepsize=1e-06, history_size=None, gamma=1.0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=False)[source]

Limited-memory Broyden solver.

This method is a quasi-Newton approach to root finding. It is similar in spirit to L-BFGS but applies to different problems. The function whose root we seek is not necessarily a gradient, so its Jacobian (the Hessian, in the optimization case) is not necessarily symmetric. As a consequence, symmetry cannot be imposed in the secant conditions that define the Broyden updates, and the resulting Jacobian approximation is not symmetric, whereas the L-BFGS approximation is. Another consequence is that each Broyden update has rank one, whereas each L-BFGS update has rank two.
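The dense NumPy sketch below illustrates the rank-one, non-symmetric update described above. It uses Broyden's "good" method in inverse form: after a step s and a residual change y, the inverse Jacobian approximation H becomes H + (s - H y)(s^T H)/(s^T H y), which is a single rank-one term satisfying the secant condition H y = s. Whether this is exactly the rule jaxopt applies internally is an assumption. The test system F is made up for illustration, and the sketch takes unit steps without a line search.

```python
import numpy as np

def broyden_good_inverse(F, x0, gamma=1.0, tol=1e-10, maxiter=100):
    """Dense sketch of Broyden's 'good' method in inverse form.

    H approximates the inverse Jacobian of F and starts at gamma * I. Each
    update is a rank-1 correction that enforces the secant condition H y = s;
    no symmetry is imposed, so H is generally non-symmetric.
    """
    x = np.asarray(x0, dtype=float)
    H = gamma * np.eye(x.size)
    fx = F(x)
    for _ in range(maxiter):
        if np.linalg.norm(fx) <= tol:
            break
        s = -H @ fx                 # quasi-Newton step with unit stepsize
        x_new = x + s
        f_new = F(x_new)
        y = f_new - fx
        Hy = H @ y
        # Rank-1 update: H <- H + (s - H y) (s^T H) / (s^T H y)
        H = H + np.outer(s - Hy, s @ H) / (s @ Hy)
        x, fx = x_new, f_new
    return x, H

# A mildly nonlinear system whose Jacobian A + 0.1 * diag(cos x) is not symmetric.
A = np.array([[1.0, 0.2], [-0.1, 1.0]])
b = np.array([1.0, 0.5])

def F(x):
    return A @ x + 0.1 * np.sin(x) - b

x_star, H = broyden_good_inverse(F, np.zeros(2))
print("residual norm:", np.linalg.norm(F(x_star)))
print("H - H^T:\n", H - H.T)  # nonzero in general: no symmetry is enforced
```

L-BFGS would instead add two terms per update and keep its approximation symmetric. This is appropriate for Hessians but not for general Jacobians.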

Parameters
  • fun (Callable) –

  • has_aux (bool) –

  • maxiter (int) –

  • tol (float) –

  • stepsize (Union[float, Callable]) –

  • linesearch (str) –

  • stop_if_linesearch_fails (bool) –

  • condition (str) –

  • maxls (int) –

  • decrease_factor (float) –

  • increase_factor (float) –

  • max_stepsize (float) –

  • min_stepsize (float) –

  • history_size (Optional[int]) –

  • gamma (float) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

  • verbose (bool) –

fun

the function whose root is sought, of the form fun(x, *args, **kwargs). Its output (the residual) must have the same pytree structure as x.

Type

Callable

has_aux

whether fun outputs auxiliary data or not. If has_aux is False, fun is expected to return only the residual, a pytree with the same structure as x. If has_aux is True, fun is expected to return a pair (value, aux), where value is the residual and aux is auxiliary data. At each iteration of the algorithm, the auxiliary outputs are stored in state.aux.

Type

bool
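The sketch below shows the two conventions for fun. The first is a residual with the same shape as its input x. The second is a variant that returns (value, aux). The value_with_aux helper normalizes both to a (value, aux) pair. It is hypothetical and illustrates a common way of handling such a flag; it is not part of jaxopt's public API.

```python
import numpy as np

def residual(x, scale):
    # has_aux=False: return only the residual, with the same shape as x.
    return scale * x - np.array([1.0, 2.0])

def residual_with_aux(x, scale):
    # has_aux=True: return a pair (value, aux); aux can hold any diagnostics.
    r = residual(x, scale)
    return r, {"residual_norm": float(np.linalg.norm(r))}

def value_with_aux(fun, has_aux):
    """Hypothetical helper mapping both conventions to (value, aux)."""
    if has_aux:
        return fun
    return lambda x, *args, **kwargs: (fun(x, *args, **kwargs), None)

x = np.array([0.5, 0.5])
v1, aux1 = value_with_aux(residual, has_aux=False)(x, 2.0)
v2, aux2 = value_with_aux(residual_with_aux, has_aux=True)(x, 2.0)
print(v1, aux1)
print(v2, aux2)
```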

maxiter

maximum number of Broyden iterations.

Type

int

tol

tolerance of the stopping criterion.

Type

float

stepsize

a stepsize to use (if <= 0, use backtracking line search), or a callable specifying the positive stepsize to use at each iteration.

Type

Union[float, Callable]
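The three accepted forms of stepsize can be summarized with the hypothetical helper below. A non-positive float selects the line search, a positive float is used as-is, and a callable yields a per-iteration value. The assumption that the callable receives the iteration number is labeled in the code, and resolve_stepsize is not a jaxopt function.

```python
def resolve_stepsize(stepsize, iter_num):
    """Hypothetical helper: fixed stepsize, schedule value, or None for line search."""
    if callable(stepsize):
        # Assumed convention: the callable receives the iteration number.
        return stepsize(iter_num)
    if stepsize <= 0:
        return None  # non-positive float: fall back to backtracking line search
    return stepsize  # positive float: used as-is at every iteration

def decaying(iter_num):
    # Example schedule: 1, 1/2, 1/3, ... (always positive, as required).
    return 1.0 / (1.0 + iter_num)

for choice in (0.0, 0.5, decaying):
    print([resolve_stepsize(choice, k) for k in range(3)])
```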

linesearch

the type of line search to use. Currently only “backtracking” (backtracking line search) is available.

Type

str

stop_if_linesearch_fails

whether to stop iterations if the line search fails. When True, this matches the behavior of core JAX.

Type

bool

maxls

maximum number of iterations to use in the line search.

Type

int

decrease_factor

factor by which to decrease the stepsize during line search (default: 0.8).

Type

float

increase_factor

factor by which to increase the stepsize during line search (default: 1.5).

Type

float

max_stepsize

upper bound on stepsize.

Type

float

min_stepsize

lower bound on stepsize.

Type

float
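The sketch below shows how maxls, decrease_factor, increase_factor, max_stepsize and min_stepsize typically interact in a backtracking search. It is illustrative only. jaxopt's actual acceptance test is set by the condition parameter (default "wolfe"), which is not documented here. This sketch instead uses a simple decrease in the residual norm, and starting each search from increase_factor times the previous accepted step is an assumed policy.

```python
import numpy as np

def backtracking(F, x, direction, init_stepsize, decrease_factor=0.8,
                 maxls=15, min_stepsize=1e-6):
    """Illustrative backtracking: shrink t until ||F(x + t d)|| < ||F(x)||."""
    f0 = np.linalg.norm(F(x))
    t = init_stepsize
    for _ in range(maxls):
        if np.linalg.norm(F(x + t * direction)) < f0:
            return t, True
        # Shrink the trial step, never going below min_stepsize.
        t = max(t * decrease_factor, min_stepsize)
    return t, False  # no acceptable step within maxls trials

def next_init_stepsize(accepted, increase_factor=1.5, max_stepsize=1.0):
    # Start the next search from a larger trial step, capped at max_stepsize.
    return min(accepted * increase_factor, max_stepsize)

def F(x):
    return x

# Overshooting direction on F(x) = x: t = 1 and t = 0.8 are rejected, t = 0.64 is accepted.
t, ok = backtracking(F, np.array([1.0]), np.array([-3.0]), init_stepsize=1.0)
print(t, ok, next_init_stepsize(t))
```

A failed search (ok is False) is the situation that stop_if_linesearch_fails governs.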

history_size

size of the memory to use.

Type

int

gamma

the initial inverse Jacobian approximation is gamma * I.

Type

float
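One common way to obtain a limited-memory Broyden method is to store H = gamma * I + sum_i u_i d_i^T as a list of rank-one pairs. Products with H and H^T then cost O(m n) for m stored pairs, and history_size bounds m. The class below is a hypothetical NumPy sketch of this idea. Discarding the oldest pair when the memory is full is an assumed policy, not documented jaxopt behavior.

```python
import numpy as np

class LimitedMemoryInverseJacobian:
    """Hypothetical sketch: H = gamma * I + sum_i u_i d_i^T with at most history_size terms."""

    def __init__(self, gamma=1.0, history_size=5):
        self.gamma = gamma
        self.history_size = history_size
        self.us, self.ds = [], []

    def matvec(self, v):
        """Return H v without forming H."""
        out = self.gamma * v
        for u, d in zip(self.us, self.ds):
            out = out + u * (d @ v)
        return out

    def rmatvec(self, v):
        """Return H^T v without forming H."""
        out = self.gamma * v
        for u, d in zip(self.us, self.ds):
            out = out + d * (u @ v)
        return out

    def update(self, s, y):
        """Add a rank-1 'good Broyden' term so that the new H satisfies H y = s."""
        if len(self.us) == self.history_size:
            # Memory full: discard the oldest pair first (an assumed policy).
            self.us.pop(0)
            self.ds.pop(0)
        Hy = self.matvec(y)
        d = self.rmatvec(s)           # H^T s
        u = (s - Hy) / (s @ Hy)       # s @ Hy equals d @ y
        self.us.append(u)
        self.ds.append(d)

H = LimitedMemoryInverseJacobian(gamma=1.0, history_size=2)
s, y = np.array([1.0, 0.0, 0.0]), np.array([2.0, 0.5, 0.0])
H.update(s, y)
print(H.matvec(y))  # secant condition: equals s
```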

implicit_diff

whether to use implicit differentiation (True) or automatic differentiation through the unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]
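With implicit differentiation, the derivative of a root with respect to a parameter comes from the optimality condition F(x*, theta) = 0 rather than from the solver's iterations. By the implicit function theorem, dx*/dtheta = -(dF/dx)^{-1} dF/dtheta. The scalar, pure-Python sketch below checks this against a finite difference. The problem is made up, and Newton's method stands in for any root solver. In the multivariate case the division becomes a linear solve with the Jacobian, which is the kind of system a solver passed as implicit_diff_solve would handle.

```python
def solve_root(theta, x0=1.0, iters=50):
    # Newton's method on F(x, theta) = x**3 + x - theta; any root solver would do.
    x = x0
    for _ in range(iters):
        x = x - (x**3 + x - theta) / (3 * x**2 + 1)
    return x

def implicit_grad(theta):
    # Implicit function theorem at the root x*:
    #   dx*/dtheta = -(dF/dx)^{-1} dF/dtheta, with dF/dx = 3 x*^2 + 1 and dF/dtheta = -1.
    x = solve_root(theta)
    dF_dx = 3 * x**2 + 1
    dF_dtheta = -1.0
    return -dF_dtheta / dF_dx

theta, eps = 2.0, 1e-6
finite_diff = (solve_root(theta + eps) - solve_root(theta - eps)) / (2 * eps)
print(implicit_grad(theta), finite_diff)  # both about 0.25, since x*(2) = 1
```

The implicit gradient never differentiates through the 50 Newton steps. This decoupling is why implicit_diff=True avoids storing or unrolling the solver's iterations.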

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

verbose

whether to print the error at every iteration. Warning: verbose=True automatically disables jit.

Type

bool

Reference:

Charles G. Broyden. A Class of Methods for Solving Nonlinear Simultaneous Equations. Mathematics of Computation, 19(92):577-593, 1965. Equation (4.5) (page 581).

__init__(fun, has_aux=False, maxiter=500, tol=0.001, stepsize=0.0, linesearch='backtracking', stop_if_linesearch_fails=False, condition='wolfe', maxls=15, decrease_factor=0.8, increase_factor=1.5, max_stepsize=1.0, min_stepsize=1e-06, history_size=None, gamma=1.0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=False)
Parameters
  • fun (Callable) –

  • has_aux (bool) –

  • maxiter (int) –

  • tol (float) –

  • stepsize (Union[float, Callable]) –

  • linesearch (str) –

  • stop_if_linesearch_fails (bool) –

  • condition (str) –

  • maxls (int) –

  • decrease_factor (float) –

  • increase_factor (float) –

  • max_stepsize (float) –

  • min_stepsize (float) –

  • history_size (Optional[int]) –

  • gamma (float) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

  • verbose (bool) –

Return type

None

Methods

  • __init__(fun[, has_aux, maxiter, tol, ...])

  • attribute_names()

  • attribute_values()

  • init_state(init_params, *args, **kwargs) – Initialize the solver state.

  • l2_optimality_error(params, *args, **kwargs) – Computes the L2 optimality error.

  • optimality_fun(params, *args, **kwargs) – Optimality function mapping compatible with @custom_root.

  • run(init_params, *args, **kwargs) – Runs the optimization loop.

  • update(params, state, *args, **kwargs) – Performs one iteration of Broyden.

Attributes

condition

decrease_factor

gamma

has_aux

history_size

implicit_diff

implicit_diff_solve

increase_factor

jit

linesearch

max_stepsize

maxiter

maxls

min_stepsize

stepsize

stop_if_linesearch_fails

tol

unroll

verbose

fun

init_state(init_params, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

BroydenState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.
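For a root-finding solver, a natural optimality condition is fun(params) = 0 itself, and the L2 optimality error is then the Euclidean norm of that residual. The sketch below illustrates this with hypothetical stand-in functions; that jaxopt's Broyden uses exactly this mapping is an assumption. The sample residual is made up.

```python
import numpy as np

def fun(x, b):
    # Residual whose root solves x + 0.1 * x**3 = b componentwise.
    return x + 0.1 * x**3 - b

def optimality_fun(params, *args, **kwargs):
    # Assumed: for a root-finding solver the optimality condition is fun(params) = 0.
    return fun(params, *args, **kwargs)

def l2_optimality_error(params, *args, **kwargs):
    # Euclidean norm of the optimality residual; zero exactly at a root.
    return float(np.linalg.norm(optimality_fun(params, *args, **kwargs)))

b = np.array([1.1, 0.0])
print(l2_optimality_error(np.array([1.0, 0.0]), b))  # at a root: zero up to rounding
print(l2_optimality_error(np.zeros(2), b))           # away from it: about 1.1
```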

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

update(params, state, *args, **kwargs)[source]

Performs one iteration of Broyden.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (BroydenState) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)
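run drives the same init_state and update methods documented above. The toy class below is a hypothetical one-dimensional stand-in (not jaxopt) with that interface. In one dimension the Broyden secant condition fully determines the inverse Jacobian approximation, so the update reduces to the secant method. The stopping rule (error <= tol or maxiter reached) is an assumption about how the loop terminates. jaxopt's real run also handles jit, unroll and implicit differentiation, which this sketch omits.

```python
from typing import Any, NamedTuple

class OptStep(NamedTuple):  # stand-in for the (params, state) pair returned by run/update
    params: Any
    state: Any

class ToyState(NamedTuple):
    iter_num: int
    error: float    # |fun(params)|
    value: float    # fun(params)
    inv_jac: float  # 1-D inverse Jacobian approximation

class ToyScalarBroyden:
    """Hypothetical 1-D solver mirroring the init_state / update / run interface."""

    def __init__(self, fun, maxiter=50, tol=1e-10, gamma=1.0):
        self.fun, self.maxiter, self.tol, self.gamma = fun, maxiter, tol, gamma

    def init_state(self, init_params, *args, **kwargs):
        value = self.fun(init_params, *args, **kwargs)
        return ToyState(0, abs(value), value, self.gamma)  # inverse Jacobian starts at gamma

    def update(self, params, state, *args, **kwargs):
        step = -state.inv_jac * state.value
        new_params = params + step
        new_value = self.fun(new_params, *args, **kwargs)
        dy = new_value - state.value
        # Secant condition inv_jac * dy = step; keep the old value if dy == 0.
        inv_jac = step / dy if dy != 0 else state.inv_jac
        new_state = ToyState(state.iter_num + 1, abs(new_value), new_value, inv_jac)
        return OptStep(new_params, new_state)

    def run(self, init_params, *args, **kwargs):
        params, state = init_params, self.init_state(init_params, *args, **kwargs)
        while state.iter_num < self.maxiter and state.error > self.tol:
            params, state = self.update(params, state, *args, **kwargs)
        return OptStep(params, state)

def fun(x, c):
    return x**2 - c  # root at sqrt(c)

solver = ToyScalarBroyden(fun)
params, state = solver.run(1.0, 2.0)   # extra args are forwarded to fun

# Equivalent manual loop using init_state and update directly.
p, s = 1.0, solver.init_state(1.0, 2.0)
while s.iter_num < solver.maxiter and s.error > solver.tol:
    p, s = solver.update(p, s, 2.0)
```

Calling update manually is useful when you want to log or inspect the state between iterations. run is the convenient one-call form.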