# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Levenberg-Marquardt algorithm in JAX."""
from typing import Any
from typing import Callable
from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import Union
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jaxopt._src import base
from jaxopt._src.linear_solve import solve_cg
from jaxopt._src.linear_solve import solve_cholesky
from jaxopt._src.linear_solve import solve_inv
from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm, tree_sub, tree_add, tree_mul
class LevenbergMarquardtState(NamedTuple):
"""Named tuple containing state information."""
iter_num: int
damping_factor: float
increase_factor: float
residual: Any
loss: Any
delta: Any
error: float
gradient: Any
jt: Any
jtj: Any
hess_res: Any
aux: Optional[Any] = None
@dataclass(eq=False)
class LevenbergMarquardt(base.IterativeSolver):
"""Levenberg-Marquardt nonlinear least-squares solver.
  Given the residual function ``func(x): R^n -> R^m``, ``LevenbergMarquardt``
  finds a local minimum of the cost function F(x):
```
argmin_x F(x) = 0.5 * sum(f_i(x)**2), i = 0, ..., m - 1
f(x) = func(x, *args)
```
  If ``stop_criterion`` is 'madsen-nielsen', convergence is achieved once the
  coefficient update satisfies ``||dcoeffs||_2 <= xtol * (||coeffs||_2 + xtol)``
  or the gradient satisfies ``||grad(f)||_inf <= gtol``.
Attributes:
residual_fun: a smooth function of the form ``residual_fun(x, *args,
**kwargs)``.
    maxiter: maximum number of iterations.
damping_parameter: The parameter which adds a correction to the equation
derived for updating the coefficients using Gauss-Newton method. Please
see section 3.2. of K. Madsen et al. in the book "Methods for nonlinear
least squares problems" for more information.
    stop_criterion: The criterion to use for the convergence of the while loop,
      e.g., for 'madsen-nielsen' the criterion is to satisfy the two conditions
      on delta_params and the gradient mentioned above. If 'grad-l2-norm' is
      selected, convergence is achieved when the l2 norm of the gradient is
      smaller than or equal to tol.
tol: tolerance.
    xtol: float, optional. The convergence tolerance for the l2 norm of the
      coefficient update.
    gtol: float, optional. The convergence tolerance for the inf norm of the
      residual gradient.
    solver: str or Callable, optional. The solver to use when finding
      delta_params, the update to the params in each iteration. This is done by
      solving a system of linear equations Ax=b. 'cholesky' (Cholesky
      factorization) and 'inv' (explicit multiplication with the matrix
      inverse) are supported. The user can also provide a custom solver, for
      example jaxopt.linear_solve.solve_cg, which scales better at runtime but
      takes longer to compile. 'cholesky' is faster than 'inv' since it
      exploits the symmetry of A.
    geodesic: bool, whether to include geodesic acceleration when solving for
      delta_params in every iteration.
contribution_ratio_threshold: float, the threshold for acceleration/velocity
ratio. We update the parameters in the algorithm only if the ratio is
smaller than this threshold value.
implicit_diff: bool, whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
verbose: bool, whether to print error on every iteration or not.
Warning: verbose=True will automatically disable jit.
    jit: whether to JIT-compile the solver loop (default: "auto").
    unroll: whether to unroll the solver loop (default: "auto").
Reference: This algorithm is for finding the best fit parameters based on the
algorithm 6.18 provided by K. Madsen & H. B. Nielsen in the book
"Introduction to Optimization and Data Fitting".
"""
residual_fun: Callable
maxiter: int = 30
damping_parameter: float = 1e-6
stop_criterion: Literal['grad-l2-norm', 'madsen-nielsen'] = 'grad-l2-norm'
tol: float = 1e-3
xtol: float = 1e-3
gtol: float = 1e-3
solver: Union[Literal['cholesky', 'inv'], Callable] = solve_cg
geodesic: bool = False
  contribution_ratio_threshold: float = 0.75
verbose: bool = False
jac_fun: Optional[Callable[..., jnp.ndarray]] = None
materialize_jac: bool = False
implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
has_aux: bool = False
jit: base.AutoOrBoolean = 'auto'
unroll: base.AutoOrBoolean = 'auto'
# We are overriding the _cond_fun of the base solver to enable stopping based
# on gradient or delta_params
def _cond_fun(self, inputs):
params, state = inputs[0]
if self.verbose:
print_iteration(state)
if self.stop_criterion == 'madsen-nielsen':
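      # Continue while both the gradient inf-norm exceeds gtol and the step
      # l2-norm exceeds xtol * (||params||_2 + xtol), per the class docstring.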
      tree_mul_term = self.xtol * (tree_l2_norm(params) + self.xtol)
return jnp.all(jnp.array([
tree_inf_norm(state.gradient) > self.gtol,
tree_l2_norm(state.delta) > tree_mul_term
]))
elif self.stop_criterion == 'grad-l2-norm':
return state.error > self.tol
else:
raise NotImplementedError
  def init_state(self, init_params: Any, *args,
**kwargs) -> LevenbergMarquardtState:
"""Initialize the solver state.
Args:
init_params: pytree containing the initial parameters.
*args: additional positional arguments to be passed to ``residual_fun``.
**kwargs: additional keyword arguments to be passed to ``residual_fun``.
Returns:
state
"""
    # Compute actual values of state variables at init_params.
residual, aux = self._fun_with_aux(init_params, *args, **kwargs)
if self.materialize_jac:
jac = self._jac_fun(init_params, *args, **kwargs)
jt = jac.T
jtj = jt @ jac
gradient = jt @ residual
damping_factor = self.damping_parameter * jnp.max(jnp.diag(jtj))
if self.geodesic:
hess_res = self._hess_res_fun(init_params, *args, **kwargs)
else:
hess_res = None
else:
jt = None
jtj = None
hess_res = None
gradient = self._jt_op(init_params, residual, *args, **kwargs)
jtj_diag = self._jtj_diag_op(init_params, *args, **kwargs)
damping_factor = self.damping_parameter * jnp.max(jtj_diag)
delta_params = jnp.zeros_like(init_params)
if self.verbose:
print_header()
return LevenbergMarquardtState(
iter_num=jnp.asarray(0),
damping_factor=damping_factor,
increase_factor=2,
error=tree_l2_norm(gradient),
residual=residual,
loss=0.5 * (residual @ residual),
delta=delta_params,
gradient=gradient,
jt=jt,
jtj=jtj,
hess_res=hess_res,
aux=aux)
  def update_state_using_gain_ratio(self, gain_ratio, contribution_ratio_diff,
gain_ratio_test_init_state, *args,
**kwargs):
"""The function to return state variables based on gain ratio.

    Please see pages 120-121 of the book "Introduction to Optimization and
Data Fitting" by K. Madsen & H. B. Nielsen for details.
"""
def gain_ratio_test_true_func(params, damping_factor,
increase_factor, residual, gradient, jt, jtj,
hess_res, updated_params, aux):
params = updated_params
residual, aux = self._fun_with_aux(params, *args, **kwargs)
# Calculate gradient based on Eq. 6.6 of "Introduction to optimization
# and data fitting" g=JT * r, where J is jacobian and r is residual.
if self.materialize_jac:
        # Calculate the Jacobian and its transpose based on the updated coeffs.
jac = self._jac_fun(params, *args, **kwargs)
jt = jac.T
        # J^T.J is the Gauss-Newton approximation of the Hessian.
jtj = jt @ jac
gradient = jt @ residual
if self.geodesic:
hess_res = self._hess_res_fun(params, *args, **kwargs)
else:
hess_res = None
else:
jt = None
jtj = None
hess_res = None
gradient = self._jt_op(params, residual, *args, **kwargs)
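      # Accepted step: shrink the damping factor based on the gain ratio and
      # reset the increase factor used to grow damping on rejected steps.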
damping_factor = damping_factor * jax.lax.max(1 / 3, 1 -
(2 * gain_ratio - 1)**3)
increase_factor = 2
return params, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux
def gain_ratio_test_false_func(params, damping_factor,
increase_factor, residual, gradient, jt, jtj,
hess_res, updated_params, aux):
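      # Rejected step: keep the current params, inflate the damping factor, and
      # double the inflation rate for the next consecutive rejection.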
damping_factor = damping_factor * increase_factor
increase_factor = 2 * increase_factor
return params, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux
    # Instead of a conditional, evaluate both branches of the gain-ratio test
    # and blend the results with the boolean mask below. Note that only the
    # variables that are used in the rest of the program are returned.
gain_ratio_test_is_met = jnp.logical_and(gain_ratio > 0.0,
contribution_ratio_diff <= 0.0)
gain_ratio_test_is_met_ret = gain_ratio_test_true_func(
*gain_ratio_test_init_state)
gain_ratio_test_not_met_ret = gain_ratio_test_false_func(
*gain_ratio_test_init_state)
gain_ratio_test_is_met_ret = jax.tree_map(
lambda x: gain_ratio_test_is_met * x, gain_ratio_test_is_met_ret)
gain_ratio_test_not_met_ret = jax.tree_map(
lambda x: (1.0 - gain_ratio_test_is_met) * x,
gain_ratio_test_not_met_ret)
params, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux = jax.tree_map(
lambda x, y: x + y, gain_ratio_test_is_met_ret,
gain_ratio_test_not_met_ret)
return params, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux
  def update_state_using_delta_params(self, loss_curr, params, delta_params,
contribution_ratio_diff, damping_factor,
increase_factor, residual, gradient, jt,
jtj, hess_res, aux, *args, **kwargs):
"""The function to return state variables based on delta_params.

    Define the quantities required for the major conditional of the algorithm,
    which checks whether the update delta_params sufficiently decreases the
    loss via the gain ratio.
"""
updated_params = params + delta_params
residual_next = self._fun(updated_params, *args, **kwargs)
# Calculate denominator of the gain ratio based on Eq. 6.16, "Introduction
# to optimization and data fitting", L(0)-L(hlm)=0.5*hlm^T*(mu*hlm-g).
gain_ratio_denom = 0.5 * delta_params.T @ (
damping_factor * delta_params - gradient)
# Current value of loss function F=0.5*||f||^2.
loss_next = 0.5 * (residual_next @ residual_next)
gain_ratio = (loss_curr - loss_next) / gain_ratio_denom
gain_ratio_test_init_state = (params, damping_factor, increase_factor,
residual, gradient, jt, jtj, hess_res,
updated_params, aux)
    # Note that only the variables that are used in the rest of the program
    # are returned by update_state_using_gain_ratio.
params, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux = (
self.update_state_using_gain_ratio(gain_ratio, contribution_ratio_diff,
gain_ratio_test_init_state, *args,
**kwargs))
return params, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux
  def update(self, params, state: NamedTuple, *args, **kwargs) -> base.OptStep:
"""Performs one iteration of the least-squares solver.
Args:
params: pytree containing the parameters.
state: named tuple containing the solver state.
Returns:
(params, state)
"""
# Current value of the loss function F=0.5*||f||^2.
loss_curr = state.loss
# For geodesic acceleration, we calculate jtrpp=JT * r",
# where J is jacobian and r" is second order directional derivative.
if self.materialize_jac:
damping_term = state.damping_factor * jnp.identity(params.size)
      # Note that instead of taking the inverse of jtj_corr and multiplying it
      # by state.gradient, we use `solve`, which applies an LU decomposition of
      # jtj_corr to obtain the velocity. This requires fewer floating point
      # operations and therefore incurs less numerical error, which can be
      # helpful for single precision arithmetic.
jtj_corr = state.jtj + damping_term
velocity = jnp.linalg.solve(jtj_corr, state.gradient)
delta_params = velocity
if self.geodesic:
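        # Contract the residual Hessian twice with the velocity to get the
        # second directional derivative r'' used for geodesic acceleration.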
rpp = (state.hess_res @ velocity) @ velocity
# Note the same as above here that we could use inverse of jtj_corr but
# chose to use solve for higher performance and lower numerical error.
acceleration = jnp.linalg.solve(jtj_corr, state.jt)
acceleration = acceleration @ rpp
delta_params += 0.5*acceleration
else:
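      # Matrix-free path: the linear solver only needs products (J^T J) v,
      # computed with one jvp and one vjp, so J^T J is never materialized.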
matvec = lambda v: self._jtj_op(params, v, *args, **kwargs)
if isinstance(self.solver, Callable):
velocity = self.solver(
matvec, state.gradient, ridge=state.damping_factor, init=state.delta)
delta_params = velocity
if self.geodesic:
rpp = self._d2fvv_op(params, velocity, velocity, *args, **kwargs)
jtrpp = self._jt_op(params, rpp, *args, **kwargs)
acceleration = self.solver(matvec, jtrpp, ridge=state.damping_factor)
delta_params += 0.5*acceleration
elif self.solver == 'cholesky':
velocity = solve_cholesky(matvec, state.gradient, ridge=state.damping_factor)
delta_params = velocity
if self.geodesic:
rpp = self._d2fvv_op(params, velocity, velocity, *args, **kwargs)
jtrpp = self._jt_op(params, rpp, *args, **kwargs)
acceleration = solve_cholesky(matvec, jtrpp, ridge=state.damping_factor)
delta_params += 0.5*acceleration
elif self.solver == 'inv':
velocity = solve_inv(
matvec, state.gradient, ridge=state.damping_factor)
delta_params = velocity
if self.geodesic:
rpp = self._d2fvv_op(params, velocity, velocity, *args, **kwargs)
jtrpp = self._jt_op(params, rpp, *args, **kwargs)
acceleration = solve_inv(matvec, jtrpp, ridge=state.damping_factor)
delta_params += 0.5*acceleration
if self.geodesic:
contribution_ratio_diff = jnp.linalg.norm(acceleration) / jnp.linalg.norm(
velocity) - self.contribution_ratio_threshold
else:
contribution_ratio_diff = 0.0
delta_params = -delta_params
    # Check whether delta_params gives a sufficient decrease of the loss and
    # update the state variables accordingly.
params, damping_factor, increase_factor, residual, gradient, jt, jtj, hess_res, aux = (
self.update_state_using_delta_params(loss_curr, params, delta_params,
contribution_ratio_diff,
state.damping_factor,
state.increase_factor,
state.residual, state.gradient,
state.jt, state.jtj,
state.hess_res, state.aux, *args, **kwargs))
state = LevenbergMarquardtState(
iter_num=state.iter_num + 1,
damping_factor=damping_factor,
increase_factor=increase_factor,
error=tree_l2_norm(gradient),
residual=residual,
loss=0.5 * (residual @ residual),
delta=delta_params,
gradient=gradient,
jt=jt,
jtj=jtj,
hess_res=hess_res,
aux=aux)
return base.OptStep(params=params, state=state)
def __post_init__(self):
if self.has_aux:
self._fun_with_aux = self.residual_fun
self._fun = lambda *a, **kw: self._fun_with_aux(*a, **kw)[0]
else:
self._fun = self.residual_fun
self._fun_with_aux = lambda *a, **kw: (self.residual_fun(*a, **kw), None)
    # Define the Jacobian of the residual function and, for geodesic
    # acceleration, the Hessian of the residual function.
if self.materialize_jac and self.jac_fun is None:
self._jac_fun = jax.jacfwd(self._fun, argnums=(0))
if self.geodesic:
self._hess_res_fun = jax.jacfwd(
jax.jacfwd(self._fun, argnums=(0)), argnums=(0))
    elif self.materialize_jac and self.jac_fun is not None:
self._jac_fun = self.jac_fun
if self.geodesic:
self._hess_res_fun = jax.jacfwd(self.jac_fun, argnums=(0))
  def optimality_fun(self, params, *args, **kwargs):
"""Optimality function mapping compatible with ``@custom_root``."""
residual = self._fun(params, *args, **kwargs)
return self._jt_op(params, residual, *args, **kwargs)
def _jt_op(self, params, residual, *args, **kwargs):
"""Product of J^T and residual -- J: jacobian of fun at params."""
fun_with_args = lambda p: self._fun(p, *args, **kwargs)
_, vjpfun = jax.vjp(fun_with_args, params)
jt_op_val, = vjpfun(residual)
return jt_op_val
def _jtj_op(self, params, vec, *args, **kwargs):
"""Product of J^T.J with vec using vjp & jvp, where J is jacobian of fun at params."""
fun_with_args = lambda p: self._fun(p, *args, **kwargs)
_, vjpfun = jax.vjp(fun_with_args, params)
_, jvp_val = jax.jvp(fun_with_args, (params,), (vec,))
jtj_op_val, = vjpfun(jvp_val)
return jtj_op_val
def _jtj_diag_op(self, params, *args, **kwargs):
"""Diagonal elements of J^T.J, where J is jacobian of fun at params."""
diag_op = lambda v: v.T @ self._jtj_op(params, v, *args, **kwargs)
return jax.vmap(diag_op)(jnp.eye(len(params))).T
def _d2fvv_op(self, primals, tangents1, tangents2, *args, **kwargs):
"""Product with d2f.v1v2."""
fun_with_args = lambda p: self._fun(p, *args, **kwargs)
g = lambda pr: jax.jvp(fun_with_args, (pr,), (tangents1,))[1]
return jax.jvp(g, (primals,), (tangents2,))[1]
def print_header():
print(f"{'Iteration':^15}{'Cost':^15}{'||Gradient||':^15}{'Damping Factor':^15}")
def print_iteration(state: LevenbergMarquardtState):
print(f"{state.iter_num:^15}{state.loss:^15.4e}{state.error:^15.4e}{state.damping_factor:^15.4}")