Implicit differentiation

Argmin differentiation

Argmin differentiation is the task of differentiating a minimization problem’s solution with respect to its inputs. Namely, given

\[x^\star(\theta) := \underset{x}{\text{argmin}} f(x, \theta),\]

we would like to compute the Jacobian \(\partial x^\star(\theta)\). This is usually done either by implicit differentiation or by autodiff through an algorithm’s unrolled iterates.
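For a smooth unconstrained problem, implicit differentiation rests on the stationarity condition \(\nabla_1 f(x^\star(\theta), \theta) = 0\). Differentiating this condition with respect to \(\theta\), and assuming the Hessian \(\nabla_{11} f\) is invertible at the solution, gives

\[\partial x^\star(\theta) = -\left[\nabla_{11} f(x^\star(\theta), \theta)\right]^{-1} \nabla_{12} f(x^\star(\theta), \theta),\]

so the Jacobian can be obtained by solving a linear system at the solution, without backpropagating through the solver's iterations.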

JAXopt solvers

All solvers in JAXopt support implicit differentiation out-of-the-box. Most solvers have an implicit_diff=True|False option. When set to False, autodiff of unrolled iterates is used instead of implicit differentiation.

Using the ridge regression example from the unconstrained optimization section, we can write:

def ridge_reg_objective(params, l2reg, X, y):
  residuals = jnp.dot(X, params) - y
  return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

def ridge_reg_solution(l2reg, X, y):
  gd = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500, implicit_diff=True)
  return gd.run(init_params, l2reg=l2reg, X=X, y=y).params
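To run the snippet end-to-end, a setup along the following lines can be used; the synthetic data, the zero initialization and the value of l2reg are illustrative assumptions, not part of the JAXopt API:

import jax
import jax.numpy as jnp
import jaxopt

# Illustrative synthetic data and hyperparameters (assumed for this example).
key_X, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_X, (20, 5))   # 20 samples, 5 features
y = jax.random.normal(key_y, (20,))     # 20 targets
l2reg = 10.0                            # regularization strength
init_params = jnp.zeros(X.shape[1])     # zero initialization for gradient descent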

Now, ridge_reg_solution is differentiable just like any other JAX function. Since ridge_reg_solution outputs a vector, we can compute its Jacobian:

print(jax.jacobian(ridge_reg_solution, argnums=0)(l2reg, X, y))

where argnums=0 specifies that we want to differentiate with respect to l2reg.

We can also compose ridge_reg_solution with other functions:

def validation_loss(l2reg):
  sol = ridge_reg_solution(l2reg, X_train, y_train)
  residuals = jnp.dot(X_val, sol) - y_val
  return jnp.mean(residuals ** 2)

print(jax.grad(validation_loss)(l2reg))
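Since the gradient above is an ordinary JAX value, it can in turn drive hyperparameter optimization. The sketch below runs a few steps of gradient descent on l2reg; the train/validation split, step size and iteration count are arbitrary illustrative choices:

# Illustrative train/validation split of the synthetic data above.
X_train, X_val = X[:15], X[15:]
y_train, y_val = y[:15], y[15:]

# A few steps of gradient descent on the hyperparameter itself.
stepsize = 0.01
for _ in range(100):
  l2reg = l2reg - stepsize * jax.grad(validation_loss)(l2reg)
print(l2reg)

In practice one would often optimize the logarithm of l2reg instead, to keep the regularization strength positive.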


Custom solvers

jaxopt.implicit_diff.custom_root(optimality_fun)

Decorator for adding implicit differentiation to a root solver.

jaxopt.implicit_diff.custom_fixed_point(...)

Decorator for adding implicit differentiation to a fixed point solver.

JAXopt also provides the custom_root and custom_fixed_point decorators, for easily adding implicit differentiation on top of any existing solver.
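As a sketch of how custom_root can be used, the gradient of the ridge objective above serves as the optimality condition and a closed-form ridge solver is decorated with it; the solver body (a direct linear solve) is just one illustrative choice:

from jaxopt.implicit_diff import custom_root

# The optimality condition is the gradient of the objective w.r.t. params.
optimality_fun = jax.grad(ridge_reg_objective)

@custom_root(optimality_fun)
def ridge_closed_form_solver(init_params, l2reg, X, y):
  # Closed-form minimizer of the ridge objective above; init_params is unused.
  del init_params
  n_samples, n_features = X.shape
  A = (2.0 / n_samples) * X.T @ X + l2reg * jnp.eye(n_features)
  b = (2.0 / n_samples) * X.T @ y
  return jnp.linalg.solve(A, b)

# The decorated solver is differentiable w.r.t. l2reg by implicit differentiation.
print(jax.jacobian(ridge_closed_form_solver, argnums=1)(None, l2reg, X, y))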

JVPs and VJPs

Finally, we also provide lower-level routines for computing the JVPs and VJPs of roots of functions.

jaxopt.implicit_diff.root_jvp(...[, solve])

Jacobian-vector product of a root.

jaxopt.implicit_diff.root_vjp(...[, solve])

Vector-Jacobian product of a root.
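As a minimal sketch of root_vjp, assuming the signature root_vjp(optimality_fun, sol, args, cotangent), the VJPs of the ridge solution with respect to (l2reg, X, y) can be computed as follows:

from jaxopt.implicit_diff import root_vjp

# Optimality condition of the ridge problem and a solution obtained above.
optimality_fun = jax.grad(ridge_reg_objective)
sol = ridge_reg_solution(l2reg, X, y)

# The cotangent has the same structure as sol; root_vjp returns one VJP
# per entry of args (signature assumed as stated above).
cotangent = jnp.ones_like(sol)
vjps = root_vjp(optimality_fun, sol, args=(l2reg, X, y), cotangent=cotangent)
print(vjps[0])  # VJP with respect to l2reg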
