Fixed point resolution
This section is concerned with fixed-point resolution, that is, finding \(x\) such that \(T(x, \theta) = x\).
Equivalence with root finding
If \(x\) is a fixed point of \(T\), then \(x\) is a root of \(F(x, \theta) = T(x, \theta) - x\). Reciprocally, if \(x\) is a root of some \(F(x, \theta)\), then it is also a fixed point of \(T(x, \theta) = F(x, \theta) + x\). Hence, root finding and fixed-point resolution are two views of the same problem. This section is concerned with algorithms that are more naturally seen as, or are historically associated with, fixed-point resolution. The root finding viewpoint is discussed in the root finding section.
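As an illustration, here is a minimal sketch that solves a fixed point problem through its root-finding reformulation with jaxopt's Bisection solver; the contractive map and the bracketing interval \([0, 5]\) below are arbitrary illustration choices:
import jax.numpy as jnp
from jaxopt import Bisection

def T(x, theta):  # contractive map with fixed point x = 2 * theta
  return 0.5 * x + theta

def F(x, theta):  # root-finding reformulation: F(x, theta) = T(x, theta) - x
  return T(x, theta) - x

# Hypothetical bracket [0, 5]; F changes sign on it since its root is x = 2 * theta.
bisec = Bisection(optimality_fun=F, lower=0.0, upper=5.0)
print(bisec.run(theta=jnp.array(0.5)).params)  # approximately 1.0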
Fixed point iterations
jaxopt.FixedPointIteration: Fixed point iteration method.
The fixed point iteration method consists of iterating \(x_{n+1} = T(x_n, \theta)\), which is guaranteed to converge to a fixed point if \(x \mapsto T(x, \theta)\) is a contractive map. See the Banach fixed-point theorem for more details.
Code example:
import jax.numpy as jnp
from jaxopt import FixedPointIteration

def T(x, theta):  # contractive map
  return 0.5 * x + theta

fpi = FixedPointIteration(fixed_point_fun=T)
x_init = jnp.array(0.)
theta = jnp.array(0.5)
print(fpi.run(x_init, theta).params)
FixedPointIteration successfully finds the fixed point \(x = 1\).
Differentiation
Fixed points can be differentiated with respect to \(\theta\):
import jax
import jax.numpy as jnp
from jaxopt import FixedPointIteration

def T(x, theta):  # contractive map
  return 0.5 * x + theta

fpi = FixedPointIteration(fixed_point_fun=T, implicit_diff=True)
x_init = jnp.array(0.)
theta = jnp.array(0.5)

def fixed_point(x, theta):
  return fpi.run(x, theta).params

print(jax.grad(fixed_point, argnums=1)(x_init, theta))            # gradient only
print(jax.value_and_grad(fixed_point, argnums=1)(x_init, theta))  # both value and gradient
Note that \(x(\theta)=2\theta\) so \(\nabla x(\theta)=2\).
Under the hood, we use the implicit function theorem in order to differentiate the fixed point. See the implicit differentiation section for more details.
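Concretely, differentiating the fixed-point condition \(T(x(\theta), \theta) = x(\theta)\) with respect to \(\theta\) and rearranging gives
\[
\partial x(\theta) = \left(I - \partial_x T(x(\theta), \theta)\right)^{-1} \partial_\theta T(x(\theta), \theta).
\]
For the toy map above, \(\partial_x T = 1/2\) and \(\partial_\theta T = 1\), which recovers \(\partial x(\theta) = 2\), in agreement with the gradient computed by jax.grad.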
Anderson acceleration
jaxopt.AndersonAcceleration: Anderson acceleration.
Anderson acceleration is an iterative method that computes the next iterate \(x_{n}\) as a linear combination of the \(m\) last iterates \([x_{n-m}, x_{n-m+1}, \ldots, x_{n-1}]\). The coefficients of the linear combination are computed 'on the fly' at each iteration. As a result, not only is the convergence faster, but the convergence conditions are also weaker, allowing it to tackle problems that FixedPointIteration could not. See Pollock and Rebholz (2020) for more details.
The size of the history \(m\) (denoted history_size below) plays a crucial role in the method's performance. A higher \(m\) can speed up convergence, at the cost of higher memory consumption and more numerical instability. These instabilities can be mitigated by increasing the ridge regularization hyper-parameter.
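For intuition, here is a simplified, self-contained sketch of one Anderson step, showing how the combination coefficients, the history size \(m\) and the ridge term interact; this is one common formulation written for illustration, not jaxopt's actual implementation:
import jax.numpy as jnp

def anderson_step(xs, Txs, ridge=1e-6):
  # xs, Txs: arrays of shape (m, d) holding the m last iterates x_i and their
  # images T(x_i). Returns a combination of the T(x_i) whose residuals nearly cancel.
  residuals = Txs - xs                   # r_i = T(x_i) - x_i, shape (m, d)
  G = residuals @ residuals.T            # Gram matrix of the residuals, shape (m, m)
  G = G + ridge * jnp.eye(G.shape[0])    # ridge regularization for numerical stability
  alphas = jnp.linalg.solve(G, jnp.ones(G.shape[0]))
  alphas = alphas / alphas.sum()         # enforce sum(alphas) = 1
  return alphas @ Txs                    # next iterate: sum_i alphas_i * T(x_i)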
Example:
import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def T(x, theta):  # contractive map
  return 0.5 * x + theta

aa = AndersonAcceleration(fixed_point_fun=T, history_size=5,
                          ridge=1e-6, tol=1e-5)
x_init = jnp.array(0.)
theta = jnp.array(0.5)
print(aa.run(x_init, theta).params)
For implicit differentiation:
import jax
import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def T(x, theta):  # contractive map
  return 0.5 * x + theta

aa = AndersonAcceleration(fixed_point_fun=T, history_size=5,
                          ridge=1e-6, tol=1e-5, implicit_diff=True)
x_init = jnp.array(0.)
theta = jnp.array(0.5)

def fixed_point(x, theta):
  return aa.run(x, theta).params

print(jax.grad(fixed_point, argnums=1)(x_init, theta))            # gradient only
print(jax.value_and_grad(fixed_point, argnums=1)(x_init, theta))  # both value and gradient
Accelerating JAXopt optimizers
Anderson acceleration can also be used to accelerate optimization algorithms.
To spare the user the burden of implementing Anderson acceleration for every
solver, we propose the AndersonWrapper
class. It directly takes an
optimizer as input and applies Anderson acceleration to its iterates.
jaxopt.AndersonWrapper: Wrapper for accelerating JAXopt solvers.
Its usage is transparent:
import jaxopt

# `ridge_reg_objective`, `init_params`, `l2reg`, `X` and `y` are assumed to be
# defined elsewhere (e.g., a ridge regression objective and its training data).
gd = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500, tol=1e-3)
aa = jaxopt.AndersonWrapper(solver=gd, history_size=5)
sol, aa_state = aa.run(init_params, l2reg=l2reg, X=X, y=y)
print(sol)