Root finding

This section is concerned with root finding, that is finding \(x\) such that \(F(x, \theta) = 0\).

Bisection

jaxopt.Bisection(optimality_fun, lower, upper)

One-dimensional root finding using bisection.

Bisection is a suitable algorithm when \(F(x, \theta)\) is one-dimensional in \(x\).

Instantiating and running the solver

First, let us consider the case \(F(x)\), i.e., without the extra argument \(\theta\). The Bisection class requires a bracketing interval \([\text{lower}, \text{upper}]\) such that \(F(\text{lower})\) and \(F(\text{upper})\) have opposite signs, which guarantees that a root is contained in this interval as long as \(F\) is continuous. For instance, suppose that we want to find the root of \(F(x) = x^3 - x - 2\). We have \(F(1) = -2\) and \(F(2) = 4\). Since the function is continuous, there must be an \(x\) between 1 and 2 such that \(F(x) = 0\):

from jaxopt import Bisection

def F(x):
  return x ** 3 - x - 2

bisec = Bisection(optimality_fun=F, lower=1, upper=2)
print(bisec.run().params)

Bisection successfully finds the root \(x \approx 1.521\). Notice that Bisection does not require an initialization: the bracketing interval is sufficient.
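For intuition, the classical bisection iteration can be sketched in a few lines of plain Python (a simplified sketch, not jaxopt's implementation):

```python
def F(x):
  return x ** 3 - x - 2

def bisect(f, lower, upper, tol=1e-6):
  # Invariant: f(lower) and f(upper) have opposite signs, so a root
  # of the continuous function f lies in [lower, upper].
  while upper - lower > tol:
    mid = 0.5 * (lower + upper)
    if f(lower) * f(mid) <= 0:
      upper = mid  # the sign change is in the left half
    else:
      lower = mid  # the sign change is in the right half
  return 0.5 * (lower + upper)

print(bisect(F, 1.0, 2.0))  # approximately 1.5214
```

Each iteration halves the interval, so the error decreases geometrically: about \(\log_2((\text{upper} - \text{lower}) / \text{tol})\) iterations suffice.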

Differentiation

Now, let us consider the case \(F(x, \theta)\). For instance, suppose that F takes an additional argument factor. We can easily differentiate with respect to factor:

import jax

def F(x, factor):
  return factor * x ** 3 - x - 2

def root(factor):
  bisec = Bisection(optimality_fun=F, lower=1, upper=2)
  return bisec.run(factor=factor).params

# Derivative of root with respect to factor at 2.0.
print(jax.grad(root)(2.0))

Under the hood, we use the implicit function theorem in order to differentiate the root. See the implicit differentiation section for more details.
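Concretely, if \(x^\star(\theta)\) satisfies \(F(x^\star(\theta), \theta) = 0\), differentiating this identity gives \(\partial x^\star / \partial \theta = -(\partial F / \partial x)^{-1} \partial F / \partial \theta\). The value computed by jax.grad above can be checked by hand with this formula (a plain-Python sketch; bisect_root below is a hypothetical helper, not part of jaxopt):

```python
# F(x, factor) = factor * x**3 - x - 2, with root x*(factor).
# Implicit function theorem: dx*/dfactor = -(dF/dfactor) / (dF/dx).

def F(x, factor):
  return factor * x ** 3 - x - 2

def bisect_root(factor, lower=1.0, upper=2.0, tol=1e-10):
  # Find x*(factor) by bisection on x -> F(x, factor).
  while upper - lower > tol:
    mid = 0.5 * (lower + upper)
    if F(lower, factor) * F(mid, factor) <= 0:
      upper = mid
    else:
      lower = mid
  return 0.5 * (lower + upper)

factor = 2.0
x_star = bisect_root(factor)
dF_dx = 3 * factor * x_star ** 2 - 1  # partial derivative in x
dF_dfactor = x_star ** 3              # partial derivative in factor
print(-dF_dfactor / dF_dx)  # should match jax.grad(root)(2.0)
```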

Scipy wrapper

jaxopt.ScipyRootFinding([method, dtype, ...])

scipy.optimize.root wrapper.

Broyden’s method

jaxopt.Broyden(fun[, has_aux, maxiter, tol, ...])

Limited-memory Broyden solver.

Broyden’s method is an iterative algorithm for solving nonlinear equations \(F(x) = 0\) in any dimension. It is a quasi-Newton method (like L-BFGS), meaning that it maintains an approximation of the Jacobian matrix, updated at each iteration with a rank-one correction. This makes the approximation easy to invert using the Sherman-Morrison formula, provided the number of stored updates stays small. One can control the number of updates with the history_size argument. Furthermore, Broyden’s method uses a line search to ensure the rank-one updates are stable.
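The rank-one update described above can be sketched with NumPy (a full-memory "good Broyden" update via Sherman-Morrison, not jaxopt's limited-memory, line-search implementation):

```python
import numpy as np

def broyden_sketch(f, x0, maxiter=50, tol=1e-8):
  # Maintain H, an approximation of the *inverse* Jacobian of f.
  x = np.asarray(x0, dtype=float)
  H = np.eye(x.size)
  fx = f(x)
  for _ in range(maxiter):
    if np.linalg.norm(fx) < tol:
      break
    dx = -H @ fx  # quasi-Newton step
    x_new = x + dx
    fx_new = f(x_new)
    df = fx_new - fx
    # Sherman-Morrison rank-one update of the inverse Jacobian,
    # enforcing the secant condition H_new @ df = dx.
    Hdf = H @ df
    H += np.outer(dx - Hdf, dx @ H) / (dx @ Hdf)
    x, fx = x_new, fx_new
  return x

def F(x):
  return np.array([x[0] ** 3 - x[0] - 2.0])

print(broyden_sketch(F, np.array([1.0])))  # approximately [1.5214]
```

In one dimension this reduces to the secant method; the appeal of the update is that each iteration costs only matrix-vector products, with no Jacobian factorization.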

Example:

import jax.numpy as jnp
from jaxopt import Broyden

def F(x):
  return x ** 3 - x - 2

broyden = Broyden(fun=F)
print(broyden.run(jnp.array(1.0)).params)

For implicit differentiation:

import jax
import jax.numpy as jnp
from jaxopt import Broyden

def F(x, factor):
  return factor * x ** 3 - x - 2

def root(factor):
  broyden = Broyden(fun=F)
  return broyden.run(jnp.array(1.0), factor=factor).params

# Derivative of root with respect to factor at 2.0.
print(jax.grad(root)(2.0))