jaxopt.NonlinearCG

class jaxopt.NonlinearCG(fun, value_and_grad=False, has_aux=False, maxiter=100, tol=0.001, method='polak-ribiere', linesearch='zoom', linesearch_init='increase', condition=None, maxls=15, decrease_factor=None, increase_factor=1.2, max_stepsize=1.0, min_stepsize=1e-06, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=0)

Nonlinear conjugate gradient solver.

Supports complex variables (see the second reference).
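As an illustration of the iteration behind this solver (Algorithm 5.4 in the first reference), here is a minimal NumPy sketch of Fletcher-Reeves CG on a convex quadratic f(x) = 0.5 x^T A x - b^T x. It is not jaxopt's implementation: the quadratic objective is an assumption chosen so that the line search can be replaced by the closed-form exact step, and the function name fr_cg_quadratic is hypothetical.

```python
import numpy as np


def fr_cg_quadratic(A, b, x0, tol=1e-10, maxiter=100):
    """Fletcher-Reeves CG on f(x) = 0.5 x^T A x - b^T x (A symmetric positive definite).

    Sketch only: the line search is replaced by the exact step, which has a
    closed form because f is quadratic. Returns (x, iterations used).
    """
    x = np.asarray(x0, dtype=float)
    g = A @ x - b                          # gradient of f at x
    d = -g                                 # first direction: steepest descent
    for k in range(maxiter):
        if np.linalg.norm(g) <= tol:       # stop once the gradient is small
            return x, k
        alpha = -(g @ d) / (d @ A @ d)     # exact minimizer of f(x + alpha * d)
        x = x + alpha * d
        g_new = A @ x - b
        beta = (g_new @ g_new) / (g @ g)   # Fletcher-Reeves formula
        d = -g_new + beta * d              # new conjugate direction
        g = g_new
    return x, maxiter
```

With exact steps on a quadratic this reduces to linear CG and, in exact arithmetic, terminates in at most n iterations; on a general smooth fun the exact step is replaced by one of the line searches selected with linesearch.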

Parameters
  • fun (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • maxiter (int) –

  • tol (float) –

  • method (str) –

  • linesearch (str) –

  • linesearch_init (str) –

  • condition (Any) –

  • maxls (int) –

  • decrease_factor (Any) –

  • increase_factor (float) –

  • max_stepsize (float) –

  • min_stepsize (float) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

  • verbose (int) –

fun

a smooth function of the form fun(x, *args, **kwargs).

Type

Callable
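A minimal sketch of a fun with this signature, using a hypothetical ridge-regression objective: the first argument is the optimization variable, the data arrive through *args and the regularization strength through **kwargs. Per the init_state, run and update entries below, extra positional and keyword arguments are forwarded to fun, so with this objective a call would look like solver.run(w0, X, y, l2reg=0.1). NumPy is used only so the snippet runs standalone; a function actually passed to the solver needs to be traceable by JAX (for example, written with jax.numpy).

```python
import numpy as np


def ridge_objective(w, X, y, l2reg=0.0):
    """Smooth objective of the form fun(x, *args, **kwargs).

    w is the optimization variable; X and y are passed positionally and
    l2reg as a keyword argument.
    """
    residual = X @ w - y
    return 0.5 * residual @ residual + 0.5 * l2reg * (w @ w)
```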

value_and_grad

whether fun just returns the value (False) or both the value and gradient (True).

Type

bool
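The two calling conventions, sketched in NumPy with an illustrative objective sum(exp(x)) - sum(x). With value_and_grad=False the gradient is obtained by automatic differentiation; with value_and_grad=True the solver relies on the gradient you return, which is easy to get wrong by hand, so the sketch also includes a central finite-difference check. All function names here are hypothetical.

```python
import numpy as np


def fun_value(x):
    # value_and_grad=False: return only the scalar objective
    return np.sum(np.exp(x)) - np.sum(x)


def fun_value_and_grad(x):
    # value_and_grad=True: return (value, gradient), sharing the exp computation
    e = np.exp(x)
    return np.sum(e) - np.sum(x), e - 1.0


def max_grad_error(vg_fun, x, eps=1e-6):
    """Largest gap between the supplied gradient and central finite differences."""
    _, grad = vg_fun(x)
    fd = np.array([
        (vg_fun(x + eps * e)[0] - vg_fun(x - eps * e)[0]) / (2 * eps)
        for e in np.eye(x.size)
    ])
    return np.max(np.abs(grad - fd))
```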

has_aux

whether fun outputs auxiliary data or not. If value_and_grad == False, the output should be value, aux = fun(...). If value_and_grad == True, the output should be (value, aux), grad = fun(...). The auxiliary outputs are stored in state.aux.

Type

bool
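The two output layouts described above, sketched with a hypothetical least-squares objective whose auxiliary output is the residual vector. The dict used as aux is an illustrative choice, and NumPy stands in for jax.numpy so the snippet runs standalone.

```python
import numpy as np


def fun_aux(x):
    # value_and_grad=False, has_aux=True  ->  value, aux = fun(x)
    residual = x - 1.0
    value = 0.5 * residual @ residual
    return value, {"residual": residual}


def fun_vg_aux(x):
    # value_and_grad=True, has_aux=True  ->  (value, aux), grad = fun(x)
    residual = x - 1.0
    value = 0.5 * residual @ residual
    return (value, {"residual": residual}), residual  # gradient of value is the residual
```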

maxiter

maximum number of nonlinear conjugate gradient iterations.

Type

int

tol

tolerance of the stopping criterion.

Type

float

method

formula used to compute the beta parameter of nonlinear CG: “polak-ribiere”, “fletcher-reeves” or “hestenes-stiefel” (default: “polak-ribiere”).

Type

str
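For reference, the textbook definitions of the three beta formulas (Nocedal and Wright, chapter 5), with g_old and g_new the successive gradients and d_old the previous search direction, so that the next direction is d_new = -g_new + beta * d_old. This is a NumPy sketch for real vectors; the clipping at zero in the Polak-Ribiere branch is the common "PR+" safeguard, and whether jaxopt applies it, or how it guards the denominators against zero, is an assumption to check against the source.

```python
import numpy as np


def cg_beta(method, g_new, g_old, d_old):
    """Beta for d_new = -g_new + beta * d_old (real-valued vectors)."""
    y = g_new - g_old                          # change in gradient
    if method == "fletcher-reeves":
        return (g_new @ g_new) / (g_old @ g_old)
    if method == "polak-ribiere":
        # PR+ safeguard: negative values are clipped to 0 (restart along -g_new)
        return max((g_new @ y) / (g_old @ g_old), 0.0)
    if method == "hestenes-stiefel":
        return (g_new @ y) / (d_old @ y)
    raise ValueError(f"unknown method: {method}")
```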

linesearch

the type of line search to use: “backtracking” for backtracking line search, “zoom” for zoom line search or “hager-zhang” for Hager-Zhang line search.

Type

str
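A minimal sketch of the simplest option, backtracking with the Armijo sufficient-decrease condition: start from a trial step and multiply it by decrease_factor until f decreases enough. The defaults 0.8 and 15 mirror decrease_factor and maxls as documented here; the constant c1 = 1e-4 and the function name are assumptions. The zoom and Hager-Zhang searches, which also test curvature (Wolfe-type) conditions, are not sketched.

```python
import numpy as np


def backtracking(f, x, fx, g, d, stepsize=1.0, decrease_factor=0.8, c1=1e-4, maxls=15):
    """Shrink stepsize until f(x + stepsize * d) <= f(x) + c1 * stepsize * g.d."""
    slope = g @ d                         # directional derivative, < 0 for a descent direction
    for _ in range(maxls):
        if f(x + stepsize * d) <= fx + c1 * stepsize * slope:
            return stepsize               # sufficient decrease reached
        stepsize *= decrease_factor
    return stepsize                       # give up after maxls trials
```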

linesearch_init

strategy for line-search initialization. By default, it uses “increase”, which increases the step-size by a factor of increase_factor at each iteration if the step-size is larger than min_stepsize, and sets it to max_stepsize otherwise. Other choices are “max”, which initializes the step-size to max_stepsize at every iteration, and “current”, which uses the step-size from the previous iteration.

Type

str
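The three initialization strategies described above, written as a small Python function. It follows the prose literally; the strict comparison against min_stepsize and the absence of any cap at max_stepsize in the “increase” branch are assumptions, and the function name is hypothetical.

```python
def initial_stepsize(strategy, prev_stepsize, increase_factor=1.2,
                     min_stepsize=1e-6, max_stepsize=1.0):
    """Initial trial step for the next line search."""
    if strategy == "increase":
        if prev_stepsize > min_stepsize:
            return prev_stepsize * increase_factor  # grow the previous step
        return max_stepsize                         # step collapsed: reset
    if strategy == "max":
        return max_stepsize
    if strategy == "current":
        return prev_stepsize
    raise ValueError(f"unknown strategy: {strategy}")
```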

condition

Deprecated. Condition used to select the stepsize when using backtracking linesearch.

Type

Any

maxls

maximum number of iterations to use in the line search.

Type

int

decrease_factor

Deprecated. Factor by which to decrease the stepsize during line search when using backtracking linesearch (default: 0.8).

Type

Any

increase_factor

factor by which to increase the stepsize during line search (default: 1.2).

Type

float

max_stepsize

upper bound on stepsize.

Type

float

min_stepsize

lower bound on stepsize.

Type

float

implicit_diff

whether to differentiate the solution by implicit differentiation (True) or by autodiff of the unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver used for implicit differentiation.

Type

Optional[Callable]
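To see where this linear solver enters, consider the hypothetical objective f(x, theta) = 0.5 x^T A x - theta^T x. Its minimizer satisfies the optimality condition A x*(theta) - theta = 0; differentiating this identity in theta gives A J = I for the Jacobian J = dx*/dtheta, so implicit differentiation amounts to linear solves with the Hessian A. The sketch uses a hand-written linear CG as the solver. The call convention solve(matvec, b) is an assumption modeled on matrix-free solvers; check jaxopt.linear_solve (for example solve_cg) for the signature implicit_diff_solve actually expects.

```python
import numpy as np


def solve_cg(matvec, b, tol=1e-12, maxiter=100):
    """Linear conjugate gradient for a symmetric positive-definite operator."""
    x = np.zeros_like(b)
    r = b - matvec(x)                     # residual
    p = r.copy()                          # search direction
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) <= tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


A = np.array([[4.0, 1.0], [1.0, 3.0]])   # Hessian of f in x
hess_matvec = lambda v: A @ v

# Each column of J = dx*/dtheta solves A J[:, i] = e_i.
J = np.column_stack([solve_cg(hess_matvec, e) for e in np.eye(2)])
```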

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

verbose

whether to print the error at every iteration. Warning: verbose=True automatically disables jit.

Type

int

References

Jorge Nocedal and Stephen Wright. Numerical Optimization, second edition. Algorithm 5.4 (page 121).

Laurent Sorber, Marc van Barel, and Lieven de Lathauwer. Unconstrained Optimization of Real Functions in Complex Variables. SIAM J. Optim., Vol. 22, No. 3, pp. 879-898, 2012.

__init__(fun, value_and_grad=False, has_aux=False, maxiter=100, tol=0.001, method='polak-ribiere', linesearch='zoom', linesearch_init='increase', condition=None, maxls=15, decrease_factor=None, increase_factor=1.2, max_stepsize=1.0, min_stepsize=1e-06, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto', verbose=0)
Parameters
  • fun (Callable) –

  • value_and_grad (bool) –

  • has_aux (bool) –

  • maxiter (int) –

  • tol (float) –

  • method (str) –

  • linesearch (str) –

  • linesearch_init (str) –

  • condition (Optional[Any]) –

  • maxls (int) –

  • decrease_factor (Optional[Any]) –

  • increase_factor (float) –

  • max_stepsize (float) –

  • min_stepsize (float) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

  • verbose (int) –

Return type

None

Methods

__init__(fun[, value_and_grad, has_aux, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

update(params, state, *args, **kwargs)

Performs one iteration of nonlinear conjugate gradient, using the beta formula selected by method.

Attributes

condition

decrease_factor

has_aux

implicit_diff

implicit_diff_solve

increase_factor

jit

linesearch

linesearch_init

max_stepsize

maxiter

maxls

method

min_stepsize

tol

unroll

value_and_grad

verbose

fun

init_state(init_params, *args, **kwargs)

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

NonlinearCGState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.
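For an unconstrained smooth problem the natural optimality condition is stationarity, grad fun(x) = 0. Assuming that optimality_fun returns the gradient of fun with respect to params, and that l2_optimality_error is the L2 norm of that residual (the quantity one would compare with tol), the two methods reduce to the following NumPy sketch on a hypothetical quadratic.

```python
import numpy as np

A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])


def fun(x):
    # Objective whose gradient is A x - b
    return 0.5 * x @ A @ x - b @ x


def optimality_fun(x):
    # Stationarity residual: gradient of fun, zero exactly at the minimizer
    return A @ x - b


def l2_optimality_error(x):
    # Euclidean norm of the optimality residual
    return np.linalg.norm(optimality_fun(x))
```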

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)
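The relationship between init_state, update and run can be pictured as: initialize the state once, then call update until a stopping test is met. The sketch below shows that control flow with hypothetical stand-ins (ToyState, ToyOptStep, ToyGradientSolver) and a fixed-step gradient step in place of the nonlinear CG update. The stopping rule (stop once error <= tol or iter_num reaches maxiter) is an assumption based on the maxiter and tol attributes, and the real solver may run this loop as a compiled JAX loop rather than in Python.

```python
from typing import Callable, NamedTuple

import numpy as np


class ToyState(NamedTuple):
    """Hypothetical stand-in for NonlinearCGState."""
    iter_num: int
    error: float


class ToyOptStep(NamedTuple):
    """Hypothetical stand-in for OptStep: a (params, state) pair."""
    params: np.ndarray
    state: ToyState


class ToyGradientSolver:
    """Hypothetical solver exposing init_state / update / run."""

    def __init__(self, grad_fun: Callable, stepsize=0.1, maxiter=100, tol=1e-3):
        self.grad_fun = grad_fun
        self.stepsize = stepsize
        self.maxiter = maxiter
        self.tol = tol

    def init_state(self, init_params, *args, **kwargs):
        error = np.linalg.norm(self.grad_fun(init_params, *args, **kwargs))
        return ToyState(iter_num=0, error=float(error))

    def update(self, params, state, *args, **kwargs):
        # One iteration; extra arguments are forwarded to grad_fun
        new_params = params - self.stepsize * self.grad_fun(params, *args, **kwargs)
        error = np.linalg.norm(self.grad_fun(new_params, *args, **kwargs))
        return ToyOptStep(new_params, ToyState(state.iter_num + 1, float(error)))

    def run(self, init_params, *args, **kwargs):
        params = init_params
        state = self.init_state(init_params, *args, **kwargs)
        while state.iter_num < self.maxiter and state.error > self.tol:
            params, state = self.update(params, state, *args, **kwargs)
        return ToyOptStep(params, state)
```

Usage: ToyGradientSolver(lambda x, target: x - target, stepsize=0.5).run(np.zeros(2), np.array([1.0, 2.0])) returns a (params, state) pair with params close to the target, mirroring how run forwards *args to update.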

update(params, state, *args, **kwargs)

Performs one iteration of nonlinear conjugate gradient, using the beta formula selected by method.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (NonlinearCGState) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)