jaxopt.BacktrackingLineSearch

class jaxopt.BacktrackingLineSearch(fun, value_and_grad=False, has_aux=False, maxiter=30, tol=0.0, condition='strong-wolfe', c1=0.0001, c2=0.9, decrease_factor=0.8, max_stepsize=1.0, verbose=0, jit='auto', unroll='auto')

Backtracking line search.

Supports complex variables.

Parameters
  • fun (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • maxiter (int) –

  • tol (float) –

  • condition (str) –

  • c1 (float) –

  • c2 (float) –

  • decrease_factor (float) –

  • max_stepsize (float) –

  • verbose (int) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fun

a function of the form fun(params, *args, **kwargs), where params are the parameters of the model and *args and **kwargs are additional arguments.

Type

Callable

value_and_grad

if False, fun should return the function value only. If True, fun should return both the function value and the gradient.

Type

bool
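
The two calling conventions can be illustrated with a small pure-Python sketch. The objective `quadratic` and its hand-written gradient are hypothetical examples, not part of jaxopt; in practice the gradient would typically come from automatic differentiation.

```python
# Hypothetical objective: f(params) = sum(p**2), with params a list of floats.

def quadratic(params):
    """Form expected when value_and_grad=False: returns the value only."""
    return sum(p * p for p in params)


def quadratic_value_and_grad(params):
    """Form expected when value_and_grad=True: returns (value, gradient)."""
    value = sum(p * p for p in params)
    grad = [2.0 * p for p in params]
    return value, grad


value_only = quadratic([1.0, -2.0])
value, grad = quadratic_value_and_grad([1.0, -2.0])
```

Supplying the gradient together with the value is mainly useful when both can be computed in one pass more cheaply than separately.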

has_aux

if False, fun should return the function value only. If True, fun should return a pair (value, aux) where aux is a pytree of auxiliary values.

Type

bool
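
As a minimal sketch of the has_aux=True convention, the hypothetical function below returns the objective value together with auxiliary outputs. The name `loss_with_aux`, its `targets` argument, and the contents of `aux` are illustrative assumptions.

```python
def loss_with_aux(params, targets):
    """Form expected when has_aux=True: returns the pair (value, aux)."""
    residuals = [p - t for p, t in zip(params, targets)]
    value = 0.5 * sum(r * r for r in residuals)
    # Auxiliary outputs, here a dict of lists (a valid pytree in JAX).
    aux = {"residuals": residuals}
    return value, aux


value, aux = loss_with_aux([1.0, 2.0], [0.0, 2.0])
```

Only `value` takes part in the line search; `aux` is carried along for the caller.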

condition

either “armijo”, “goldstein”, “strong-wolfe” or “wolfe”.

Type

str

c1

constant of the sufficient-decrease (Armijo) inequality, which is also the first part of the (strong) Wolfe condition.

Type

float

c2

constant strictly less than 1 used by the (strong) Wolfe condition.

Type

float
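
The roles of c1 and c2 can be made concrete with the textbook form of the four conditions. The sketch below is written in terms of the one-dimensional function phi(stepsize) = fun(params + stepsize * direction); the helper `check_condition` and the test function `phi` are hypothetical, and the exact tests performed inside jaxopt (for instance how tol enters) may differ.

```python
def check_condition(condition, phi0, dphi0, phi_a, dphi_a, stepsize,
                    c1=1e-4, c2=0.9):
    """Textbook line search conditions along a descent direction.

    phi0, dphi0: value and directional derivative at step size 0.
    phi_a, dphi_a: value and directional derivative at `stepsize`.
    """
    # Sufficient decrease, shared by all four conditions.
    armijo = phi_a <= phi0 + c1 * stepsize * dphi0
    if condition == "armijo":
        return armijo
    if condition == "goldstein":
        # Additionally rules out steps that are too short.
        return armijo and phi_a >= phi0 + (1.0 - c1) * stepsize * dphi0
    if condition == "wolfe":
        # Curvature condition: the slope must have increased enough.
        return armijo and dphi_a >= c2 * dphi0
    if condition == "strong-wolfe":
        # Curvature condition on the magnitude of the slope.
        return armijo and abs(dphi_a) <= c2 * abs(dphi0)
    raise ValueError(f"unknown condition: {condition}")


# Example: fun(x) = x**2 at x = 1 along direction d = -2.
def phi(a):
    return (1.0 - 2.0 * a) ** 2


def dphi(a):
    return -4.0 * (1.0 - 2.0 * a)


ok_at_08 = check_condition("strong-wolfe", phi(0.0), dphi(0.0),
                           phi(0.8), dphi(0.8), 0.8)
ok_at_1 = check_condition("armijo", phi(0.0), dphi(0.0),
                          phi(1.0), dphi(1.0), 1.0)
short_armijo = check_condition("armijo", phi(0.0), dphi(0.0),
                               phi(0.01), dphi(0.01), 0.01)
short_wolfe = check_condition("wolfe", phi(0.0), dphi(0.0),
                              phi(0.01), dphi(0.01), 0.01)
```

In this example a very short step (0.01) passes the Armijo test but fails the Wolfe curvature test, which is the reason for having c2 in addition to c1.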

decrease_factor

factor by which to decrease the stepsize during line search (default: 0.8).

Type

float
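
The effect of decrease_factor can be sketched with a plain Armijo backtracking loop. This is an illustrative pure-Python version under simplifying assumptions (Armijo condition only, no tolerance, no max_stepsize handling), not the jaxopt implementation; `backtrack` and `phi` are hypothetical names.

```python
def backtrack(phi, phi0, dphi0, init_stepsize=1.0, decrease_factor=0.8,
              c1=1e-4, maxiter=30):
    """Shrink the step size until the Armijo condition holds.

    Returns the step size and the number of shrinking steps performed.
    """
    stepsize = init_stepsize
    for num_shrinks in range(maxiter):
        if phi(stepsize) <= phi0 + c1 * stepsize * dphi0:
            return stepsize, num_shrinks
        stepsize *= decrease_factor
    # Budget exhausted: the returned step size may not satisfy the condition.
    return stepsize, maxiter


# Example: fun(x) = x**2 at x = 1 along direction d = -2,
# so phi(0) = 1 and the slope at 0 is -4.
def phi(a):
    return (1.0 - 2.0 * a) ** 2


step_default, shrinks_default = backtrack(phi, 1.0, -4.0)
step_half, shrinks_half = backtrack(phi, 1.0, -4.0, decrease_factor=0.5)
```

A factor close to 1 explores step sizes more finely at the cost of more function evaluations; a smaller factor shrinks faster but may settle on a shorter step than necessary.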

max_stepsize

upper bound on stepsize.

Type

float

maxiter

maximum number of line search iterations.

Type

int

tol

tolerance of the stopping criterion.

Type

float

verbose

whether to print the error at every iteration. Setting verbose=True automatically disables jit.

Type

int

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

__init__(fun, value_and_grad=False, has_aux=False, maxiter=30, tol=0.0, condition='strong-wolfe', c1=0.0001, c2=0.9, decrease_factor=0.8, max_stepsize=1.0, verbose=0, jit='auto', unroll='auto')
Return type

None

Methods

__init__(fun[, value_and_grad, has_aux, ...])

attribute_names()

attribute_values()

init_state(init_stepsize, params[, value, ...])

Initialize the line search state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_stepsize, params[, value, grad, ...])

Runs the optimization loop.

update(stepsize, state, params[, value, ...])

Performs one iteration of backtracking line search.

Attributes

c1

c2

condition

decrease_factor

has_aux

jit

max_stepsize

maxiter

tol

unroll

value_and_grad

verbose

fun

init_state(init_stepsize, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})

Initialize the line search state.

Parameters
  • init_stepsize (float) – initial step size value.

  • params (Any) – current parameters.

  • value (Optional[float]) – current function value (recomputed if None).

  • grad (Optional[Any]) – current gradient (recomputed if None).

  • descent_direction (Optional[Any]) – ignored.

  • fun_args (list) – additional positional arguments to be passed to fun.

  • fun_kwargs (dict) – additional keyword arguments to be passed to fun.

Return type

BacktrackingLineSearchState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_stepsize, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})

Runs the optimization loop.

Parameters
  • init_stepsize (float) – initial step size value.

  • params (Any) – current parameters.

  • value (Optional[float]) – current function value (recomputed if None).

  • grad (Optional[Any]) – current gradient (recomputed if None).

  • descent_direction (Optional[Any]) – descent direction (negative gradient if None).

  • fun_args (list) – additional positional arguments to be passed to fun.

  • fun_kwargs (dict) – additional keyword arguments to be passed to fun.

Return type

LineSearchStep

Returns

(stepsize, state)

update(stepsize, state, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})

Performs one iteration of backtracking line search.

Parameters
  • stepsize (float) – current estimate of the step size.

  • state (NamedTuple) – named tuple containing the line search state.

  • params (Any) – current parameters.

  • value (Optional[float]) – current function value (recomputed if None).

  • grad (Optional[Any]) – current gradient (recomputed if None).

  • descent_direction (Optional[Any]) – descent direction (negative gradient if None).

  • fun_args (list) – additional positional arguments to be passed to fun.

  • fun_kwargs (dict) – additional keyword arguments to be passed to fun.

Return type

LineSearchStep

Returns

(stepsize, state)
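
The relationship between init_state, update and run can be sketched in pure Python: run builds an initial state and then applies update until the state reports success or maxiter is reached. Everything below (`SketchState`, `sketch_update`, `sketch_run`, the Armijo-only test, and returning the step size together with the state) is a hypothetical simplification for illustration, not the jaxopt code.

```python
from typing import NamedTuple


class SketchState(NamedTuple):
    """Hypothetical stand-in for a line search state."""
    iter_num: int
    done: bool


def sketch_update(stepsize, state, fun, params, value, grad,
                  c1=1e-4, decrease_factor=0.8):
    """One iteration: test the current step size, shrink it on failure."""
    direction = [-g for g in grad]  # negative gradient as descent direction
    slope = sum(g * d for g, d in zip(grad, direction))
    trial = [p + stepsize * d for p, d in zip(params, direction)]
    done = fun(trial) <= value + c1 * stepsize * slope
    new_stepsize = stepsize if done else stepsize * decrease_factor
    return new_stepsize, SketchState(state.iter_num + 1, done)


def sketch_run(init_stepsize, fun, params, value, grad, maxiter=30):
    """Initialize the state, then call the update until done or maxiter."""
    stepsize, state = init_stepsize, SketchState(iter_num=0, done=False)
    while not state.done and state.iter_num < maxiter:
        stepsize, state = sketch_update(stepsize, state, fun, params,
                                        value, grad)
    return stepsize, state


def fun(params):
    return sum(p * p for p in params)


params = [1.0, -1.0]
value, grad = fun(params), [2.0, -2.0]
stepsize, state = sketch_run(1.0, fun, params, value, grad)
```

Passing value and grad explicitly, as the methods above allow, avoids recomputing quantities the caller usually already has from the outer solver.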