Fixed point resolution

This section is concerned with fixed-point resolution, that is, finding \(x\) such that \(T(x, \theta) = x\).

Equivalence with root finding

If \(x\) is a fixed point of \(T\), then \(x\) is a root of \(F(x, \theta) = T(x, \theta) - x\). Conversely, if \(x\) is a root of some \(F(x, \theta)\), then it is also a fixed point of \(T(x, \theta) = F(x, \theta) + x\). Hence, root finding and fixed-point resolution are two views of the same problem. This section is concerned with algorithms that are more naturally seen as, or are historically associated with, fixed-point resolution. The root finding viewpoint is discussed in the root finding section.
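Both conversions are one-liners. Here is a minimal sketch in plain JAX, checked on the contractive map \(T(x, \theta) = 0.5x + \theta\) used in the examples of this section, whose fixed point is \(x = 2\theta\). The helper names `root_fun_from_fixed_point` and `fixed_point_fun_from_root` are hypothetical and are not JAXopt functions.

```python
import jax.numpy as jnp

def T(x, theta):
  # Contractive map used in the examples of this section.
  return 0.5 * x + theta

def root_fun_from_fixed_point(T):
  # F(x, theta) = T(x, theta) - x: roots of F are fixed points of T.
  return lambda x, theta: T(x, theta) - x

def fixed_point_fun_from_root(F):
  # T(x, theta) = F(x, theta) + x: fixed points of T are roots of F.
  return lambda x, theta: F(x, theta) + x

F = root_fun_from_fixed_point(T)
T_back = fixed_point_fun_from_root(F)

theta = jnp.array(0.5)
x_star = 2.0 * theta          # closed-form fixed point of T
print(F(x_star, theta))       # 0.0: x_star is a root of F
print(T_back(x_star, theta))  # 1.0: x_star is a fixed point of T_back
```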

Fixed point iterations

jaxopt.FixedPointIteration(fixed_point_fun)

Fixed point iteration method.

The fixed point iteration method simply consists of iterating \(x_{n+1}=T(x_n, \theta)\), which is guaranteed to converge to the unique fixed point if \(x\mapsto T(x,\theta)\) is a contraction. See the Banach fixed-point theorem for more details.
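The iteration itself fits in a few lines. The following is a minimal plain-JAX sketch that assumes a stopping rule on the step size \(\|x_{n+1} - x_n\|\). JAXopt's own stopping criterion and defaults may differ, and `fixed_point_iteration` is a hypothetical helper, not the JAXopt class.

```python
import jax.numpy as jnp

def fixed_point_iteration(T, x0, theta, maxiter=100, tol=1e-5):
  """Iterate x_{n+1} = T(x_n, theta) until the step size drops below tol."""
  x = x0
  for n in range(1, maxiter + 1):
    x_next = T(x, theta)
    step = jnp.sqrt(jnp.sum((x_next - x) ** 2))  # Euclidean norm of the step
    x = x_next
    if step <= tol:
      break
  return x, n  # approximate fixed point and number of evaluations of T

def T(x, theta):  # contractive map with contraction factor 0.5
  return 0.5 * x + theta

x, n = fixed_point_iteration(T, jnp.array(0.), jnp.array(0.5))
print(x, n)  # x is close to 1, reached after n = 17 iterations
```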

Code example:

import jax.numpy as jnp
from jaxopt import FixedPointIteration

def T(x, theta):  # contractive map
  return 0.5 * x + theta

fpi = FixedPointIteration(fixed_point_fun=T)
x_init = jnp.array(0.)
theta = jnp.array(0.5)
print(fpi.run(x_init, theta).params)

FixedPointIteration successfully finds the fixed point x = 1.
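For this affine example the iterates have a closed form, which makes the linear convergence rate of the Banach theorem explicit (the contraction factor is \(1/2\)). Starting from \(x_0 = 0\) with \(\theta = 0.5\):

```latex
x_{n+1} - 1 = (0.5\,x_n + 0.5) - 1 = \tfrac{1}{2}\,(x_n - 1)
\quad\Longrightarrow\quad
x_n - 1 = \left(\tfrac{1}{2}\right)^{n} (x_0 - 1) = -2^{-n}
\quad\Longrightarrow\quad
x_n = 1 - 2^{-n}.
```

Each iteration halves the distance to the fixed point, so reaching an accuracy of \(10^{-5}\) takes about \(\log_2 10^5 \approx 17\) iterations.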

Differentiation

Fixed points can be differentiated with respect to \(\theta\):

import jax
import jax.numpy as jnp
from jaxopt import FixedPointIteration

def T(x, theta):  # contractive map
  return 0.5 * x + theta

fpi = FixedPointIteration(fixed_point_fun=T, implicit_diff=True)
x_init = jnp.array(0.)
theta = jnp.array(0.5)

def fixed_point(x, theta):
  return fpi.run(x, theta).params

print(jax.grad(fixed_point, argnums=1)(x_init, theta))  # only gradient
print(jax.value_and_grad(fixed_point, argnums=1)(x_init, theta))  # both value and gradient

Indeed, solving \(x = 0.5x + \theta\) gives \(x(\theta)=2\theta\), so \(\nabla x(\theta)=2\).

Under the hood, we use the implicit function theorem in order to differentiate the fixed point. See the implicit differentiation section for more details.
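Differentiating the identity \(x^\star(\theta) = T(x^\star(\theta), \theta)\) with respect to \(\theta\) gives \(\partial x^\star(\theta) = (I - \partial_1 T)^{-1}\,\partial_2 T\), with both Jacobians of \(T\) evaluated at \((x^\star(\theta), \theta)\). The matrix \(I - \partial_1 T\) is invertible because \(T\) is a contraction in \(x\). The sketch below evaluates this formula with dense Jacobians in plain JAX for an elementwise, vector-valued version of the example. It only illustrates the theorem, since dense Jacobians are practical only for small problems, and `implicit_jacobian` is a hypothetical helper, not part of JAXopt.

```python
import jax
import jax.numpy as jnp

def T(x, theta):  # contractive map from the example above, applied elementwise
  return 0.5 * x + theta

def implicit_jacobian(T, x_star, theta):
  """Jacobian of the fixed point x*(theta) via the implicit function theorem."""
  A = jax.jacobian(T, argnums=0)(x_star, theta)  # dT/dx at the fixed point, (n, n)
  B = jax.jacobian(T, argnums=1)(x_star, theta)  # dT/dtheta at the fixed point, (n, p)
  n = x_star.shape[0]
  return jnp.linalg.solve(jnp.eye(n) - A, B)     # (I - A)^{-1} B

theta = jnp.array([0.5, 1.0])
x_star = 2.0 * theta  # closed-form fixed point of T
print(implicit_jacobian(T, x_star, theta))  # 2 * identity, matching x*(theta) = 2 theta
```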

Anderson acceleration

jaxopt.AndersonAcceleration(fixed_point_fun)

Anderson acceleration.

Anderson acceleration is an iterative method that computes the next iterate \(x_{n}\) as a linear combination of the \(m\) last iterates \([x_{n-m},x_{n-m+1},\ldots,x_{n-1}]\). The coefficients of the linear combination are computed on the fly at each iteration. As a result, not only is convergence faster, but the convergence conditions are also weaker, which allows it to tackle problems that FixedPointIteration cannot. See Pollock and Rebholz (2020) for more details.
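The following is a minimal, unjitted plain-JAX sketch of one common variant of the method, not JAXopt's implementation. In this variant, the weights minimize the norm of the combined residuals \(g_i = T(x_i, \theta) - x_i\) subject to summing to one, a ridge term regularizes the Gram matrix, and the next iterate mixes the images \(T(x_i, \theta)\). JAXopt's AndersonAcceleration may differ in details such as a mixing parameter or how the linear system is solved. The helper `anderson` and its defaults are assumptions, and it expects a 1-D array.

```python
import jax.numpy as jnp

def anderson(T, x0, theta, history_size=5, ridge=1e-6, maxiter=100, tol=1e-5):
  """Anderson acceleration for a 1-D array x0 (illustrative sketch)."""
  xs = [x0]            # past iterates x_i
  fs = [T(x0, theta)]  # their images T(x_i, theta)
  for _ in range(maxiter):
    X = jnp.stack(xs[-history_size:])  # (k, n) with k <= history_size
    F = jnp.stack(fs[-history_size:])
    G = F - X                          # residuals g_i = T(x_i, theta) - x_i
    if jnp.linalg.norm(G[-1]) <= tol:  # latest residual small enough: stop
      break
    k = G.shape[0]
    # Weights minimize ||sum_i alpha_i g_i||^2 + ridge * ||alpha||^2
    # subject to sum_i alpha_i = 1 (closed form via the Gram matrix).
    gram = G @ G.T + ridge * jnp.eye(k)
    alpha = jnp.linalg.solve(gram, jnp.ones(k))
    alpha = alpha / jnp.sum(alpha)
    x_next = alpha @ F                 # mix the images T(x_i, theta)
    xs.append(x_next)
    fs.append(T(x_next, theta))
  return xs[-1], len(xs) - 1

def T(x, theta):  # same contractive map as in the examples of this section
  return 0.5 * x + theta

x, num_steps = anderson(T, jnp.zeros(2), jnp.array([0.5, 1.0]))
print(x, num_steps)  # x is close to the fixed point 2 * theta = [1., 2.]
```

On this affine example, the second step already lands on the fixed point up to the ridge term. The first two residuals are collinear, and the weights \(\alpha = (-1, 2)\) cancel them exactly.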

The size of the history \(m\) (denoted history_size below) plays a crucial role in the method’s performance. A larger \(m\) can speed up convergence, at the cost of higher memory consumption and more numerical instability. This instability can be mitigated by increasing the ridge regularization hyperparameter (ridge below).
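To see why the ridge term matters, consider collinear residuals. These arise, for instance, after two plain steps on the linear map \(T(x, \theta) = 0.5x + \theta\) starting from \(x = 0\) with \(\theta = (0.5, 1)\). The following is a minimal plain-JAX sketch that assumes the common formulation in which the weights solve a linear system built from the Gram matrix of the residuals. JAXopt's exact formulation may differ. Without regularization the system is singular, while a tiny ridge term yields weights \(\alpha \approx (-1, 2)\) that sum to one and cancel the residuals.

```python
import jax.numpy as jnp

# Residuals g_i = T(x_i, theta) - x_i of the first two iterates, stacked as rows.
# They are collinear (g_1 = g_0 / 2), so their Gram matrix is singular.
G = jnp.array([[0.5, 1.0],
               [0.25, 0.5]])
gram = G @ G.T
print(jnp.linalg.det(gram))  # exactly zero: no unique unregularized solution

ridge = 1e-6
alpha = jnp.linalg.solve(gram + ridge * jnp.eye(2), jnp.ones(2))
alpha = alpha / jnp.sum(alpha)  # enforce sum_i alpha_i = 1
print(alpha)      # approximately [-1.  2.]
print(alpha @ G)  # combined residual, approximately [0. 0.]
```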

Example:

import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def T(x, theta):  # contractive map
  return 0.5 * x + theta

aa = AndersonAcceleration(fixed_point_fun=T, history_size=5,
                          ridge=1e-6, tol=1e-5)
x_init = jnp.array(0.)
theta = jnp.array(0.5)
print(aa.run(x_init, theta).params)

For implicit differentiation:

import jax
import jax.numpy as jnp
from jaxopt import AndersonAcceleration

def T(x, theta):  # contractive map
  return 0.5 * x + theta

aa = AndersonAcceleration(fixed_point_fun=T, history_size=5,
                          ridge=1e-6, tol=1e-5, implicit_diff=True)
x_init = jnp.array(0.)
theta = jnp.array(0.5)

def fixed_point(x, theta):
  return aa.run(x, theta).params

print(jax.grad(fixed_point, argnums=1)(x_init, theta))  # only gradient
print(jax.value_and_grad(fixed_point, argnums=1)(x_init, theta))  # both value and gradient

Accelerating JAXopt optimizers

Anderson acceleration can also be used to accelerate optimization algorithms. To spare users the burden of implementing Anderson acceleration for every solver, we provide the AndersonWrapper class. It takes a solver as input and applies Anderson acceleration to its iterates.
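Applying Anderson acceleration to a solver's iterates makes sense because one update of an iterative solver is itself a fixed-point map, whose fixed points are the solutions the solver seeks. Below is a minimal plain-JAX sketch of this observation for gradient descent on a ridge regression objective. The data, the step size, and the names `ridge_objective` and `gd_step` are illustrative assumptions, not JAXopt code. Iterating the update converges to the closed-form ridge solution, so these iterates can be fed to any fixed-point accelerator, Anderson acceleration included.

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic data, for illustration only.
X = jax.random.normal(jax.random.PRNGKey(0), (50, 3))
y = X @ jnp.array([1.0, -2.0, 0.5])
l2reg = 0.1

def ridge_objective(w):
  return 0.5 * jnp.mean((X @ w - y) ** 2) + 0.5 * l2reg * jnp.sum(w ** 2)

stepsize = 0.5

@jax.jit
def gd_step(w):
  # One gradient-descent update viewed as a fixed-point map T(w) = w - stepsize * grad f(w).
  return w - stepsize * jax.grad(ridge_objective)(w)

w = jnp.zeros(3)
for _ in range(200):
  w = gd_step(w)  # plain fixed-point iteration of the solver update

# A fixed point of gd_step satisfies grad f(w) = 0, i.e. the ridge normal equations.
n = X.shape[0]
w_closed = jnp.linalg.solve(X.T @ X / n + l2reg * jnp.eye(3), X.T @ y / n)
print(jnp.max(jnp.abs(w - w_closed)))  # small: both are the same minimizer
```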

jaxopt.AndersonWrapper(solver[, ...])

Wrapper for accelerating JAXopt solvers.

Its usage is transparent:

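import jaxopt

# Assumes ridge_reg_objective(params, l2reg, X, y), init_params, l2reg, X and y are already defined.
gd = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500, tol=1e-3)
aa = jaxopt.AndersonWrapper(solver=gd, history_size=5)
sol, aa_state = aa.run(init_params, l2reg=l2reg, X=X, y=y)
print(sol)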