Root finding
This section is concerned with root finding, that is finding \(x\) such that \(F(x, \theta) = 0\).
Bisection
Bisection: one-dimensional root finding using bisection.
Bisection is a suitable algorithm when \(F(x, \theta)\) is one-dimensional in \(x\).
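The idea behind bisection can be sketched in a few lines of plain Python (an illustration of the algorithm, not jaxopt's implementation): the bracketing interval is repeatedly halved, keeping the half whose endpoints still have opposite signs.

```python
def bisect(f, lower, upper, tol=1e-10):
  # Precondition: f(lower) and f(upper) have opposite signs.
  f_lower = f(lower)
  while upper - lower > tol:
    mid = 0.5 * (lower + upper)
    if f(mid) * f_lower > 0:
      # Same sign as f(lower): the root lies in the upper half.
      lower, f_lower = mid, f(mid)
    else:
      upper = mid
  return 0.5 * (lower + upper)

# Root of x^3 - x - 2 on the bracket [1, 2].
print(bisect(lambda x: x ** 3 - x - 2, 1.0, 2.0))
```

Each iteration halves the bracket, so the error shrinks geometrically regardless of the shape of \(F\), as long as \(F\) is continuous on the bracket.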
Instantiating and running the solver
First, let us consider the case \(F(x)\), i.e., without the extra argument
\(\theta\). The Bisection class requires a bracketing interval
\([\text{lower}, \text{upper}]\) such that \(F(\text{lower})\) and
\(F(\text{upper})\) have opposite signs, meaning that a root is contained
in this interval as long as \(F\) is continuous. For instance, suppose
that we want to find the root of \(F(x) = x^3 - x - 2\). We have
\(F(1) = -2\) and \(F(2) = 4\). Since the function is continuous, there
must be an \(x\) between 1 and 2 such that \(F(x) = 0\):
from jaxopt import Bisection

def F(x):
  return x ** 3 - x - 2

bisec = Bisection(optimality_fun=F, lower=1, upper=2)
print(bisec.run().params)
Bisection successfully finds the root \(x \approx 1.521\).
Notice that Bisection does not require an initialization,
since the bracketing interval is sufficient.
Differentiation
Now, let us consider the case \(F(x, \theta)\). For instance, suppose that
F takes an additional argument factor. We can easily differentiate
with respect to factor:
import jax
from jaxopt import Bisection

def F(x, factor):
  return factor * x ** 3 - x - 2

def root(factor):
  bisec = Bisection(optimality_fun=F, lower=1, upper=2)
  return bisec.run(factor=factor).params

# Derivative of root with respect to factor at 2.0.
print(jax.grad(root)(2.0))
Under the hood, we use the implicit function theorem in order to differentiate the root. See the implicit differentiation section for more details.
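To make this concrete, here is a sketch of the implicit function theorem in pure Python (an illustration, not jaxopt's implementation): if \(x^\star(\theta)\) satisfies \(F(x^\star(\theta), \theta) = 0\), then \(\partial x^\star / \partial \theta = -(\partial F / \partial x)^{-1} \, \partial F / \partial \theta\), evaluated at the root. For \(F(x, \text{factor}) = \text{factor} \cdot x^3 - x - 2\), both partial derivatives are available in closed form, and the resulting gradient matches a finite difference of the root map.

```python
def bisect(f, lower, upper, tol=1e-12):
  # Minimal bisection; f(lower) and f(upper) must have opposite signs.
  f_lower = f(lower)
  while upper - lower > tol:
    mid = 0.5 * (lower + upper)
    if f(mid) * f_lower > 0:
      lower, f_lower = mid, f(mid)
    else:
      upper = mid
  return 0.5 * (lower + upper)

def root(factor):
  return bisect(lambda x: factor * x ** 3 - x - 2, 1.0, 2.0)

factor = 2.0
x_star = root(factor)
# For F(x, factor) = factor * x^3 - x - 2, the partials in closed form:
dF_dx = 3 * factor * x_star ** 2 - 1
dF_dfactor = x_star ** 3
# Implicit function theorem: dx*/dfactor = -(dF/dx)^{-1} * dF/dfactor.
ift_grad = -dF_dfactor / dF_dx
# Sanity check against a central finite difference of the root map.
eps = 1e-6
fd_grad = (root(factor + eps) - root(factor - eps)) / (2 * eps)
print(ift_grad, fd_grad)
```

The key point is that the gradient only requires partial derivatives of \(F\) at the solution, not differentiating through the iterations of the solver itself.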
Scipy wrapper
jaxopt also provides a wrapper around scipy.optimize.root.
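For reference, the underlying SciPy routine can be called directly as follows (plain NumPy inputs, without JAX tracing or implicit differentiation, which is what the jaxopt wrapper adds):

```python
from scipy.optimize import root

def F(x):
  return x ** 3 - x - 2

# scipy.optimize.root uses Powell's hybrid method ("hybr") by default.
solution = root(F, x0=1.0)
print(solution.x)
```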
Broyden’s method
Broyden: limited-memory Broyden solver.
Broyden’s method is an iterative algorithm for solving nonlinear root-finding problems in any dimension.
It is a quasi-Newton method (like L-BFGS), meaning that it uses an approximation of the Jacobian matrix
at each iteration.
The approximation is updated at each iteration with a rank-one update.
This makes the approximation easy to invert using the Sherman-Morrison formula, provided that it does not use too many updates.
One can control the number of updates with the history_size argument.
Furthermore, Broyden’s method uses a line search to ensure the rank-one updates are stable.
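To illustrate the mechanics, here is a minimal NumPy sketch of "good" Broyden (not jaxopt's implementation: it stores the full inverse-Jacobian approximation rather than a limited history, and omits the line search):

```python
import numpy as np

def broyden_solve(f, x0, tol=1e-10, max_iter=50):
  x = np.atleast_1d(np.asarray(x0, dtype=float))
  fx = np.atleast_1d(f(x))
  H = np.eye(x.size)  # approximation of the inverse Jacobian
  for _ in range(max_iter):
    dx = -H @ fx        # quasi-Newton step
    x_new = x + dx
    f_new = np.atleast_1d(f(x_new))
    df = f_new - fx
    denom = dx @ (H @ df)
    if abs(denom) > 1e-12:
      # Sherman-Morrison rank-one update of the inverse Jacobian.
      H = H + np.outer(dx - H @ df, dx @ H) / denom
    x, fx = x_new, f_new
    if np.linalg.norm(fx) < tol:
      break
  return x

# Root of x^3 - x - 2, starting from 1.0.
print(broyden_solve(lambda x: x ** 3 - x - 2, 1.0))
```

In one dimension this reduces to the secant method; in higher dimensions the rank-one updates build up curvature information without ever forming the true Jacobian.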
Example:
import jax.numpy as jnp
from jaxopt import Broyden

def F(x):
  return x ** 3 - x - 2

broyden = Broyden(fun=F)
print(broyden.run(jnp.array(1.0)).params)
For implicit differentiation:
import jax
import jax.numpy as jnp
from jaxopt import Broyden

def F(x, factor):
  return factor * x ** 3 - x - 2

def root(factor):
  broyden = Broyden(fun=F)
  return broyden.run(jnp.array(1.0), factor=factor).params

# Derivative of root with respect to factor at 2.0.
print(jax.grad(root)(2.0))