jaxopt.HagerZhangLineSearch

class jaxopt.HagerZhangLineSearch(fun, value_and_grad=False, has_aux=False, maxiter=30, tol=0.0, c1=0.1, c2=0.9, expansion_factor=5.0, shrinkage_factor=0.66, max_stepsize=1.0, verbose=0, jit='auto', unroll='auto')[source]

Hager-Zhang line search.

Supports complex variables.

Parameters
  • fun (Callable) – a function of the form fun(params, *args, **kwargs), where params are parameters of the model and *args, **kwargs are additional arguments.

  • value_and_grad (bool) – if False, fun should return the function value only. If True, fun should return both the function value and the gradient.

  • has_aux (bool) – if False, fun should return the function value only. If True, fun should return a pair (value, aux) where aux is a pytree of auxiliary values.

  • maxiter (int) – maximum number of line search iterations.

  • tol (float) – tolerance of the stopping criterion.

  • c1 (float) – constant used by the Wolfe and approximate Wolfe conditions.

  • c2 (float) – constant strictly less than 1 used by the Wolfe and approximate Wolfe conditions.

  • expansion_factor (float) – factor by which the bracketing interval is expanded during the initial bracketing phase.

  • shrinkage_factor (float) – factor by which the bracketing interval must shrink per iteration before a bisection step is taken.

  • max_stepsize (float) – upper bound on the step size.

  • verbose (int) – whether to print the error on every iteration. verbose=True automatically disables jit.

  • jit (Union[str, bool]) – whether to JIT-compile the optimization loop (default: “auto”).

  • unroll (Union[str, bool]) – whether to unroll the optimization loop (default: “auto”).

fun

a function of the form fun(params, *args, **kwargs), where params are parameters of the model, *args and **kwargs are additional arguments.

Type

Callable

value_and_grad

if False, fun should return the function value only. If True, fun should return both the function value and the gradient.

Type

bool

has_aux

if False, fun should return the function value only. If True, fun should return a pair (value, aux) where aux is a pytree of auxiliary values.

Type

bool

c1

constant used by the Wolfe and approximate Wolfe conditions.

Type

float

c2

constant strictly less than 1 used by the Wolfe and approximate Wolfe conditions.

Type

float

max_stepsize

upper bound on stepsize.

Type

float

maxiter

maximum number of line search iterations.

Type

int

tol

tolerance of the stopping criterion.

Type

float

verbose

whether to print the error on every iteration. verbose=True automatically disables jit.

Type

int

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

__init__(fun, value_and_grad=False, has_aux=False, maxiter=30, tol=0.0, c1=0.1, c2=0.9, expansion_factor=5.0, shrinkage_factor=0.66, max_stepsize=1.0, verbose=0, jit='auto', unroll='auto')

Return type

None

Methods

__init__(fun[, value_and_grad, has_aux, ...])

attribute_names()

attribute_values()

init_state(init_stepsize, params[, value, ...])

Initialize the line search state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_stepsize, params[, value, grad, ...])

Runs the optimization loop.

update(stepsize, state, params[, value, ...])

Performs one iteration of Hager-Zhang line search.

Attributes

approximate_wolfe_threshold

c1

c2

expansion_factor

has_aux

jit

max_stepsize

maxiter

shrinkage_factor

tol

unroll

value_and_grad

verbose

fun

init_state(init_stepsize, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})[source]

Initialize the line search state.

Parameters
  • init_stepsize (float) – initial step size value. This is ignored by the linesearch.

  • params (Any) – current parameters.

  • value (Optional[float]) – current function value (recomputed if None).

  • grad (Optional[Any]) – current gradient (recomputed if None).

  • descent_direction (Optional[Any]) – ignored.

  • fun_args (list) – additional positional arguments to be passed to fun.

  • fun_kwargs (dict) – additional keyword arguments to be passed to fun.

Return type

HagerZhangLineSearchState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_stepsize, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})

Runs the optimization loop.

Parameters
  • init_stepsize (float) – initial step size value.

  • params (Any) – current parameters.

  • value (Optional[float]) – current function value (recomputed if None).

  • grad (Optional[Any]) – current gradient (recomputed if None).

  • descent_direction (Optional[Any]) – descent direction (negative gradient if None).

  • fun_args (list) – additional positional arguments to be passed to fun.

  • fun_kwargs (dict) – additional keyword arguments to be passed to fun.

Return type

LineSearchStep

Returns

(params, state)

update(stepsize, state, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})[source]

Performs one iteration of Hager-Zhang line search.

Parameters
  • stepsize (float) – current estimate of the step size.

  • state (NamedTuple) – named tuple containing the line search state.

  • params (Any) – current parameters.

  • value (Optional[float]) – current function value (recomputed if None).

  • grad (Optional[Any]) – current gradient (recomputed if None).

  • descent_direction (Optional[Any]) – descent direction (negative gradient if None).

  • fun_args (list) – additional positional arguments to be passed to fun.

  • fun_kwargs (dict) – additional keyword arguments to be passed to fun.

Return type

LineSearchStep

Returns

(params, state)