jaxopt.LevenbergMarquardt

class jaxopt.LevenbergMarquardt(residual_fun, maxiter=30, damping_parameter=1e-06, stop_criterion='grad-l2-norm', tol=0.001, xtol=0.001, gtol=0.001, solver=<function solve_cg>, geodesic=False, verbose=False, jac_fun=None, materialize_jac=False, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')[source]

Levenberg-Marquardt nonlinear least-squares solver.

Given the residual function func(x): R^n -> R^m, the solver finds a local minimum of the cost function F(x):

`argmin_x F(x) = 0.5 * sum(f_i(x)**2), i = 0, ..., m - 1`
`f(x) = func(x, *args)`

If stop_criterion is ‘madsen-nielsen’, convergence is achieved once the coefficient update satisfies `||dcoeffs||_2 <= xtol * (||coeffs||_2 + xtol)` or the gradient satisfies `||grad(f)||_inf <= gtol`.
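
A minimal usage sketch (the exponential model, data, and initial guess below are illustrative choices, not part of the API):

```python
import jax.numpy as jnp
from jaxopt import LevenbergMarquardt

# Residuals of an exponential model y ~ a * exp(b * x); params = [a, b].
def residual_fun(params, x, y):
    a, b = params
    return a * jnp.exp(b * x) - y

x = jnp.linspace(0.0, 1.0, 20)
y = 2.0 * jnp.exp(-1.5 * x)  # noiseless synthetic data

lm = LevenbergMarquardt(residual_fun=residual_fun)
params, state = lm.run(jnp.array([1.0, 0.0]), x, y)  # extra args are forwarded to residual_fun
```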

Parameters
  • residual_fun (Callable) –

  • maxiter (int) –

  • damping_parameter (float) –

  • stop_criterion (Literal['grad-l2-norm', 'madsen-nielsen']) –

  • tol (float) –

  • xtol (float) –

  • gtol (float) –

  • solver (Union[Literal['cholesky', 'inv'], ~typing.Callable]) –

  • geodesic (bool) –

  • verbose (bool) –

  • jac_fun (Optional[Callable[[...], Array]]) –

  • materialize_jac (bool) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

residual_fun

a smooth function of the form residual_fun(x, *args, **kwargs).

Type

Callable

maxiter

maximum number of iterations.

Type

int

damping_parameter

The parameter that adds a damping correction to the Gauss-Newton update equation for the coefficients. See section 3.2 of K. Madsen et al., “Methods for Nonlinear Least Squares Problems”, for more information.

Type

float
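
Conceptually, the damping turns the Gauss-Newton normal equations into a regularized linear system. A sketch of the damped step on dense arrays (not the library's internal, pytree-based implementation):

```python
import jax.numpy as jnp

# Damped Gauss-Newton step: solve (J^T J + mu * I) delta = -J^T r.
def damped_gauss_newton_step(J, r, mu):
    JtJ = J.T @ J
    return jnp.linalg.solve(JtJ + mu * jnp.eye(JtJ.shape[0]), -(J.T @ r))
```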

stop_criterion

The criterion used for the convergence of the while loop. For ‘madsen-nielsen’, convergence requires satisfying the conditions on the coefficient update and the gradient stated above. If ‘grad-l2-norm’ is selected, convergence is achieved once the l2 norm of the gradient is smaller than or equal to tol.

Type

Literal[‘grad-l2-norm’, ‘madsen-nielsen’]

tol

the convergence tolerance for the l2 norm of the gradient when stop_criterion is ‘grad-l2-norm’.

Type

float

xtol

The convergence tolerance for the l2 norm of the coefficient update, used by the ‘madsen-nielsen’ stop criterion.

Type

float

gtol

The convergence tolerance for the inf norm of the gradient, used by the ‘madsen-nielsen’ stop criterion.

Type

float
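
As an illustrative sketch of how the two ‘madsen-nielsen’ tests combine (function and argument names are hypothetical, not jaxopt internals):

```python
import jax.numpy as jnp

def madsen_nielsen_converged(coeffs, dcoeffs, grad, xtol, gtol):
    small_step = jnp.linalg.norm(dcoeffs) <= xtol * (jnp.linalg.norm(coeffs) + xtol)
    small_grad = jnp.max(jnp.abs(grad)) <= gtol
    return jnp.logical_or(small_step, small_grad)
```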

solver

The solver used to find delta_params, the update to the params at each iteration, by solving a system of linear equations Ax = b. ‘cholesky’ (Cholesky factorization) and ‘inv’ (explicit multiplication by the matrix inverse) are built in; ‘cholesky’ is faster than ‘inv’ since it exploits the symmetry of A. The user can also provide a custom solver, for example jaxopt.linear_solve.solve_cg, which scales better at runtime but takes longer to compile, as shown in the sketch below.

Type

Union[Literal[‘cholesky’, ‘inv’], Callable]

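
A sketch of passing a preconfigured conjugate-gradient solver, assuming residual_fun from the earlier example (the maxiter value is an arbitrary illustration):

```python
from functools import partial
from jaxopt import LevenbergMarquardt
from jaxopt.linear_solve import solve_cg

lm = LevenbergMarquardt(residual_fun, solver=partial(solve_cg, maxiter=100))
```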

geodesic

whether to include geodesic acceleration when solving for delta_params at each iteration.

Type

bool

contribution_ratio_threshold

the threshold for the acceleration/velocity ratio. The parameters are updated only if this ratio is smaller than the threshold value.

Type

float

implicit_diff

whether to enable implicit differentiation or, if False, autodiff of unrolled iterations.

Type

bool
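
With implicit differentiation enabled (the default), gradients can be taken through the solution itself. A sketch reusing lm, x, and y from the first example:

```python
import jax
import jax.numpy as jnp

def fitted_amplitude(y):
    params, _ = lm.run(jnp.array([1.0, 0.0]), x, y)
    return params[0]

da_dy = jax.grad(fitted_amplitude)(y)  # gradient of the fitted a w.r.t. the data
```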

implicit_diff_solve

the linear system solver to use.

Type

Optional[Callable]

verbose

whether to print the error at every iteration. Warning: verbose=True will automatically disable jit.

Type

bool

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

Reference: This algorithm for finding the best-fit parameters is based on Algorithm 6.18 provided by K. Madsen & H. B. Nielsen in the book “Introduction to Optimization and Data Fitting”.


Methods

__init__(residual_fun[, maxiter, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

solver(b[, ridge, init])

Solves A x = b using conjugate gradient.

update(params, state, *args, **kwargs)

Performs one iteration of the least-squares solver.

update_state_using_delta_params(loss_curr, ...)

The function to return state variables based on delta_params.

update_state_using_gain_ratio(gain_ratio, ...)

The function to return state variables based on gain ratio.

Attributes

contribution_ratio_threshold

damping_parameter

geodesic

gtol

has_aux

implicit_diff

implicit_diff_solve

jac_fun

jit

materialize_jac

maxiter

stop_criterion

tol

unroll

verbose

xtol

residual_fun

init_state(init_params, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to residual_fun.

  • **kwargs – additional keyword arguments to be passed to residual_fun.

Return type

LevenbergMarquardtState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, *args, **kwargs)[source]

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

solver(matvec, b, ridge=None, init=None, **kwargs)

Solves A x = b using conjugate gradient.

It assumes that A is a Hermitian, positive definite matrix.

Parameters
  • matvec (Callable) – product between A and a vector.

  • b (Any) – pytree.

  • ridge (Optional[float]) – optional ridge regularization.

  • init (Optional[Any]) – optional initialization to be used by conjugate gradient.

  • **kwargs – additional keyword arguments for solver.

Return type

Any

Returns

pytree with same structure as b.
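
A minimal sketch of calling the default conjugate-gradient solver directly (the matrix below is an arbitrary symmetric positive definite example):

```python
import jax.numpy as jnp
from jaxopt.linear_solve import solve_cg

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])  # symmetric positive definite
b = jnp.array([1.0, 2.0])

x = solve_cg(matvec=lambda v: A @ v, b=b)  # pytree with same structure as b
```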

update(params, state, *args, **kwargs)[source]

Performs one iteration of the least-squares solver.

Parameters
  • params – pytree containing the parameters.

  • state (NamedTuple) – named tuple containing the solver state.

Return type

OptStep

Returns

(params, state)
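
The loop can also be stepped manually instead of calling run. A sketch reusing lm, x, and y from the first example:

```python
import jax.numpy as jnp

params = jnp.array([1.0, 0.0])
state = lm.init_state(params, x, y)
for _ in range(lm.maxiter):
    params, state = lm.update(params, state, x, y)
```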

update_state_using_delta_params(loss_curr, params, delta_params, contribution_ratio_diff, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux, *args, **kwargs)[source]

The function to return state variables based on delta_params.

Defines the functions required for the major conditional of the algorithm, which checks whether the magnitude of delta_params is small enough.

update_state_using_gain_ratio(gain_ratio, contribution_ratio_diff, gain_ratio_test_init_state, *args, **kwargs)[source]

The function to return state variables based on the gain ratio. Please see pages 120-121 of the book “Introduction to Optimization and Data Fitting” by K. Madsen & H. B. Nielsen for details.