jaxopt.BacktrackingLineSearch
- class jaxopt.BacktrackingLineSearch(fun, value_and_grad=False, has_aux=False, maxiter=30, tol=0.0, condition='strong-wolfe', c1=0.0001, c2=0.9, decrease_factor=0.8, max_stepsize=1.0, verbose=0, jit='auto', unroll='auto')[source]
Backtracking line search.
Supports complex variables.
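A minimal usage sketch (the quadratic objective and the values below are illustrative, not part of the API):

```python
import jax.numpy as jnp
from jaxopt import BacktrackingLineSearch

# Illustrative objective: a simple quadratic.
def fun(params):
    return jnp.sum(params ** 2)

ls = BacktrackingLineSearch(fun=fun, condition="strong-wolfe", maxiter=20)

params = jnp.array([1.0, -2.0])
# run() returns a named tuple; its first field is the accepted stepsize.
stepsize, state = ls.run(init_stepsize=1.0, params=params)
```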
- Parameters
fun (Callable) –
value_and_grad (bool) –
has_aux (bool) –
maxiter (int) –
tol (float) –
condition (str) –
c1 (float) –
c2 (float) –
decrease_factor (float) –
max_stepsize (float) –
verbose (int) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- fun
a function of the form fun(params, *args, **kwargs), where params are parameters of the model and *args and **kwargs are additional arguments.
- Type
Callable
- value_and_grad
if False, fun should return the function value only. If True, fun should return both the function value and the gradient.
- Type
bool
- has_aux
if False, fun should return the function value only. If True, fun should return a pair (value, aux) where aux is a pytree of auxiliary values.
- Type
bool
- condition
either “armijo”, “goldstein”, “strong-wolfe” or “wolfe”.
- Type
str
- c1
sufficient decrease constant used by the Armijo and (strong) Wolfe conditions (spelled out after this attribute list).
- Type
float
- c2
curvature condition constant, strictly less than 1, used by the (strong) Wolfe condition.
- Type
float
- decrease_factor
factor by which to decrease the stepsize during line search (default: 0.8).
- Type
float
- max_stepsize
upper bound on stepsize.
- Type
float
- maxiter
maximum number of line search iterations.
- Type
int
- tol
tolerance of the stopping criterion.
- Type
float
- verbose
whether to print the error on every iteration. A non-zero value automatically disables jit.
- Type
int
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
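For reference, with current point x, descent direction d, candidate step size α, and objective f, the conditions that c1 and c2 parameterize are the standard ones from the line search literature:

```latex
% Sufficient decrease (Armijo), controlled by c1:
f(x + \alpha d) \le f(x) + c_1 \alpha \nabla f(x)^\top d
% Curvature (Wolfe), controlled by c2:
\nabla f(x + \alpha d)^\top d \ge c_2 \nabla f(x)^\top d
% Strong Wolfe: the curvature condition is tightened to
|\nabla f(x + \alpha d)^\top d| \le c_2 |\nabla f(x)^\top d|
```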
- __init__(fun, value_and_grad=False, has_aux=False, maxiter=30, tol=0.0, condition='strong-wolfe', c1=0.0001, c2=0.9, decrease_factor=0.8, max_stepsize=1.0, verbose=0, jit='auto', unroll='auto')
- Parameters
fun (Callable) –
value_and_grad (bool) –
has_aux (bool) –
maxiter (int) –
tol (float) –
condition (str) –
c1 (float) –
c2 (float) –
decrease_factor (float) –
max_stepsize (float) –
verbose (int) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- Return type
None
Methods
- __init__(fun[, value_and_grad, has_aux, ...])
- attribute_names()
- attribute_values()
- init_state(init_stepsize, params[, value, ...]): Initialize the line search state.
- l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
- run(init_stepsize, params[, value, grad, ...]): Runs the optimization loop.
- update(stepsize, state, params[, value, ...]): Performs one iteration of backtracking line search.
- init_state(init_stepsize, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})[source]
Initialize the line search state.
- Parameters
init_stepsize (float) – initial step size value.
params (Any) – current parameters.
value (Optional[float]) – current function value (recomputed if None).
grad (Optional[Any]) – current gradient (recomputed if None).
descent_direction (Optional[Any]) – ignored.
fun_args (list) – additional positional arguments to be passed to fun.
fun_kwargs (dict) – additional keyword arguments to be passed to fun.
- Return type
BacktrackingLineSearchState
- Returns
state
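As a sketch of the fun_args / fun_kwargs mechanism, extra data can be routed to fun at initialization (the loss, X, and y below are hypothetical):

```python
import jax.numpy as jnp
from jaxopt import BacktrackingLineSearch

# Hypothetical least-squares loss; X and y are illustration-only data.
def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

X = jnp.ones((4, 2))
y = jnp.zeros(4)
w = jnp.zeros(2)

ls = BacktrackingLineSearch(fun=loss)
# value and grad are omitted here, so init_state recomputes them from fun.
state = ls.init_state(init_stepsize=1.0, params=w, fun_args=[X, y])
```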
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- run(init_stepsize, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})
Runs the optimization loop.
- Parameters
init_stepsize (float) – initial step size value.
params (Any) – current parameters.
value (Optional[float]) – current function value (recomputed if None).
grad (Optional[Any]) – current gradient (recomputed if None).
descent_direction (Optional[Any]) – descent direction (negative gradient if None).
fun_args (list) – additional positional arguments to be passed to fun.
fun_kwargs (dict) – additional keyword arguments to be passed to fun.
- Return type
LineSearchStep
- Returns
(params, state)
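A sketch of calling run with precomputed quantities, which avoids re-evaluating fun and its gradient inside the line search (the objective is illustrative):

```python
import jax
import jax.numpy as jnp
from jaxopt import BacktrackingLineSearch

def fun(w):  # illustrative objective
    return jnp.sum((w - 1.0) ** 2)

ls = BacktrackingLineSearch(fun=fun)
w = jnp.zeros(3)

value, grad = jax.value_and_grad(fun)(w)
# Supplying value and grad skips their recomputation; -grad is the
# steepest-descent direction, which is also the default when None.
stepsize, state = ls.run(init_stepsize=1.0, params=w, value=value,
                         grad=grad, descent_direction=-grad)
```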
- update(stepsize, state, params, value=None, grad=None, descent_direction=None, fun_args=[], fun_kwargs={})[source]
Performs one iteration of backtracking line search.
- Parameters
stepsize (float) – current estimate of the step size.
state (NamedTuple) – named tuple containing the line search state.
params (Any) – current parameters.
value (Optional[float]) – current function value (recomputed if None).
grad (Optional[Any]) – current gradient (recomputed if None).
descent_direction (Optional[Any]) – descent direction (negative gradient if None).
fun_args (list) – additional positional arguments to be passed to fun.
fun_kwargs (dict) – additional keyword arguments to be passed to fun.
- Return type
LineSearchStep
- Returns
(params, state)
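For illustration, a manual loop in the spirit of run, stepping the search by hand (the done flag is assumed from BacktrackingLineSearchState; field names may differ across jaxopt versions):

```python
import jax
import jax.numpy as jnp
from jaxopt import BacktrackingLineSearch

def fun(w):  # illustrative quadratic objective
    return jnp.sum(w ** 2)

ls = BacktrackingLineSearch(fun=fun, maxiter=20)
w = jnp.array([3.0, -1.0])
value, grad = jax.value_and_grad(fun)(w)

stepsize = 1.0
state = ls.init_state(stepsize, w, value=value, grad=grad)
for _ in range(ls.maxiter):
    stepsize, state = ls.update(stepsize, state, w, value=value, grad=grad)
    if state.done:  # assumed stopping flag on the state named tuple
        break
```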