jaxopt.LevenbergMarquardt
- class jaxopt.LevenbergMarquardt(residual_fun, maxiter=30, damping_parameter=1e-06, stop_criterion='grad-l2-norm', tol=0.001, xtol=0.001, gtol=0.001, solver=<function solve_cg>, geodesic=False, verbose=False, jac_fun=None, materialize_jac=False, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')[source]
Levenberg-Marquardt nonlinear least-squares solver.
Given the residual function func(x): R^n -> R^m, the solver finds a local minimum of the cost function F(x):
`argmin_x F(x) = 0.5 * sum(f_i(x)**2), i = 0, ..., m - 1`, where `f(x) = func(x, *args)`.
If stop_criterion is ‘madsen-nielsen’, convergence is achieved once the coefficient update satisfies `||dcoeffs||_2 <= xtol * (||coeffs||_2 + xtol)` or the gradient satisfies `||grad(f)||_inf <= gtol`.
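A minimal usage sketch, assuming a hypothetical exponential-decay fitting problem (the residual function, data, and initial guess below are illustrative and not part of the API):

```python
import jax.numpy as jnp
import jaxopt

# Hypothetical residual function: model minus observations, shape (m,).
def residual_fun(params, x, y):
    amplitude, rate = params
    return amplitude * jnp.exp(-rate * x) - y

# Synthetic data, for illustration only.
x = jnp.linspace(0.0, 1.0, 20)
y = 2.0 * jnp.exp(-3.0 * x)

lm = jaxopt.LevenbergMarquardt(residual_fun=residual_fun, maxiter=30)
params, state = lm.run(jnp.array([1.0, 1.0]), x, y)
```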
- Parameters
residual_fun (Callable) –
maxiter (int) –
damping_parameter (float) –
stop_criterion (Literal['grad-l2-norm', 'madsen-nielsen']) –
tol (float) –
xtol (float) –
gtol (float) –
solver (Union[Literal['cholesky', 'inv'], ~typing.Callable]) –
geodesic (bool) –
verbose (bool) –
jac_fun (Optional[Callable[[...], Array]]) –
materialize_jac (bool) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
has_aux (bool) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- residual_fun
a smooth function of the form
residual_fun(x, *args, **kwargs)
- Type
Callable
- maxiter
maximum number of iterations.
- Type
int
- damping_parameter
The parameter which adds a correction to the equation derived for updating the coefficients using the Gauss-Newton method. See section 3.2 of K. Madsen et al., “Methods for Nonlinear Least Squares Problems”, for more information.
- Type
float
- stop_criterion
The criterion used for convergence of the while loop. For ‘madsen-nielsen’, convergence requires the two conditions on delta_params and the gradient given above. For ‘grad-l2-norm’, convergence is achieved once the L2 norm of the gradient is smaller than or equal to tol. (See the configuration sketch after this attribute list.)
- Type
Literal[‘grad-l2-norm’, ‘madsen-nielsen’]
- tol
tolerance on the L2 norm of the gradient used by the ‘grad-l2-norm’ stopping criterion.
- Type
float
- xtol
The convergence tolerance for the L2 norm of the coefficient update.
- Type
float
- gtol
The convergence tolerance for the infinity norm of the residual gradient.
- Type
float
- solver
The solver to use when finding delta_params, the update to the params in each iteration, which is obtained by solving a system of linear equations Ax = b. ‘cholesky’ uses Cholesky factorization and ‘inv’ multiplies by the explicit matrix inverse. A custom solver can also be provided, for example jaxopt.linear_solve.solve_cg, which scales better at runtime but takes longer to compile. ‘cholesky’ is faster than ‘inv’ since it exploits the symmetry of A. (See the configuration sketch after this attribute list.)
- Type
Union[Literal[‘cholesky’, ‘inv’], Callable]
- geodesic
whether to include geodesic acceleration when solving for delta_params at every iteration.
- Type
bool
- contribution_ratio_threshold
the threshold for the acceleration/velocity ratio. The parameters are updated only if this ratio is smaller than the threshold value.
- Type
float
- implicit_diff
whether to enable implicit differentiation or autodiff of unrolled iterations.
- Type
bool
- implicit_diff_solve
the linear system solver to use.
- Type
Optional[Callable]
- verbose
whether to print the error at every iteration. Warning: verbose=True will automatically disable jit.
- Type
bool
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
- Reference: This algorithm finds the best-fit parameters following algorithm 6.18 of K. Madsen & H. B. Nielsen, “Introduction to Optimization and Data Fitting”.
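As a configuration sketch (reusing the hypothetical residual_fun, x, and y from the example above), the stopping criterion and the linear solver can be selected at construction time; ‘cholesky’ and jaxopt.linear_solve.solve_cg are the options mentioned above:

```python
from jaxopt import LevenbergMarquardt, linear_solve

# Stop once the Madsen-Nielsen update and gradient conditions are met.
lm = LevenbergMarquardt(
    residual_fun=residual_fun,       # hypothetical residual from the sketch above
    stop_criterion='madsen-nielsen',
    xtol=1e-6,
    gtol=1e-6,
    solver='cholesky',               # or a callable such as linear_solve.solve_cg
)
params, state = lm.run(jnp.array([1.0, 1.0]), x, y)
```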
- __init__(residual_fun, maxiter=30, damping_parameter=1e-06, stop_criterion='grad-l2-norm', tol=0.001, xtol=0.001, gtol=0.001, solver=<function solve_cg>, geodesic=False, verbose=False, jac_fun=None, materialize_jac=False, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')
- Parameters
residual_fun (Callable) –
maxiter (int) –
damping_parameter (float) –
stop_criterion (Literal['grad-l2-norm', 'madsen-nielsen']) –
tol (float) –
xtol (float) –
gtol (float) –
solver (Union[Literal['cholesky', 'inv'], ~typing.Callable]) –
geodesic (bool) –
verbose (bool) –
jac_fun (Optional[Callable[[...], Array]]) –
materialize_jac (bool) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
has_aux (bool) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- Return type
None
Methods
- __init__(residual_fun[, maxiter, ...])
- attribute_names()
- attribute_values()
- init_state(init_params, *args, **kwargs): Initialize the solver state.
- l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
- optimality_fun(params, *args, **kwargs): Optimality function mapping compatible with @custom_root.
- run(init_params, *args, **kwargs): Runs the optimization loop.
- solver(b[, ridge, init]): Solves A x = b using conjugate gradient.
- update(params, state, *args, **kwargs): Performs one iteration of the least-squares solver.
- update_state_using_delta_params(loss_curr, ...): The function to return state variables based on delta_params.
- update_state_using_gain_ratio(gain_ratio, ...): The function to return state variables based on gain ratio.
Attributes
has_aux
jac_fun
materialize_jac
- init_state(init_params, *args, **kwargs)[source]
Initialize the solver state.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to residual_fun.
**kwargs – additional keyword arguments to be passed to residual_fun.
- Return type
LevenbergMarquardtState
- Returns
state
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- optimality_fun(params, *args, **kwargs)[source]
Optimality function mapping compatible with @custom_root.
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
- solver(b, ridge=None, init=None, **kwargs)
Solves A x = b using conjugate gradient. It assumes that A is a Hermitian, positive definite matrix.
- Parameters
matvec (Callable) – product between A and a vector.
b (Any) – pytree.
ridge (Optional[float]) – optional ridge regularization.
init (Optional[Any]) – optional initialization to be used by conjugate gradient.
**kwargs – additional keyword arguments for the solver.
- Return type
Any
- Returns
pytree with same structure as b.
- update(params, state, *args, **kwargs)[source]
Performs one iteration of the least-squares solver.
- Parameters
params – pytree containing the parameters.
state (NamedTuple) – named tuple containing the solver state.
- Return type
OptStep
- Returns
(params, state)
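For step-by-step control, init_state and update can be called directly instead of run. A short sketch, reusing the hypothetical residual_fun, x, and y from the first example:

```python
lm = jaxopt.LevenbergMarquardt(residual_fun=residual_fun)

params = jnp.array([1.0, 1.0])
state = lm.init_state(params, x, y)

for _ in range(10):
    # One Levenberg-Marquardt iteration: returns updated (params, state).
    params, state = lm.update(params, state, x, y)
```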
- update_state_using_delta_params(loss_curr, params, delta_params, contribution_ratio_diff, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux, *args, **kwargs)[source]
The function to return state variables based on delta_params.
Defines the functions required for the major conditional of the algorithm, which checks whether the magnitude of delta_params is small enough.
- update_state_using_gain_ratio(gain_ratio, contribution_ratio_diff, gain_ratio_test_init_state, *args, **kwargs)[source]
The function to return state variables based on gain ratio. See pages 120-121 of the book “Introduction to Optimization and Data Fitting” by K. Madsen & H. B. Nielsen for details.