Implicit differentiation
Argmin differentiation
Argmin differentiation is the task of differentiating a minimization problem’s solution with respect to its inputs. Namely, given
\[x^\star(\theta) := \underset{x}{\text{argmin}} \, f(x, \theta),\]
we would like to compute the Jacobian \(\partial x^\star(\theta)\). This is usually done either by implicit differentiation or by autodiff through an algorithm’s unrolled iterates.
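As a sketch of the implicit route, assume \(f\) is smooth and \(x^\star(\theta)\) is an unconstrained minimizer, so that \(\nabla_x f(x^\star(\theta), \theta) = 0\). Differentiating this optimality condition with respect to \(\theta\) gives \(\partial x^\star(\theta) = -[\nabla_x^2 f]^{-1} \partial_\theta \nabla_x f\), evaluated at \((x^\star(\theta), \theta)\). The pure-JAX snippet below checks this formula on a toy quadratic by comparing it with autodiff through the problem's closed-form solution. The matrix A, the vector b, and the helpers x_star and implicit_jacobian are illustrative assumptions, not JAXopt code.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: f(x, theta) = 0.5 * ||A x - b||^2 + 0.5 * theta * ||x||^2.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([1.0, 0.0, -1.0])

def f(x, theta):
    return 0.5 * jnp.sum((A @ x - b) ** 2) + 0.5 * theta * jnp.sum(x ** 2)

def x_star(theta):
    # Closed-form minimizer: solve (A^T A + theta I) x = A^T b.
    return jnp.linalg.solve(A.T @ A + theta * jnp.eye(2), A.T @ b)

def implicit_jacobian(theta):
    x = x_star(theta)
    # Hessian of f in x, shape (2, 2).
    hess = jax.hessian(f, argnums=0)(x, theta)
    # Derivative of grad_x f with respect to the scalar theta, shape (2,).
    cross = jax.jacobian(jax.grad(f, argnums=0), argnums=1)(x, theta)
    # Implicit function theorem: dx*/dtheta = -hess^{-1} cross.
    return -jnp.linalg.solve(hess, cross)

theta = 0.5
print(implicit_jacobian(theta))
print(jax.jacobian(x_star)(theta))  # autodiff through the closed form, for comparison
```

The implicit formula only needs the solution itself, not the steps that produced it, which is why it can be attached to any solver.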
JAXopt solvers
All solvers in JAXopt support implicit differentiation out-of-the-box.
Most solvers have an implicit_diff=True|False option. When set to False, autodiff of unrolled iterates is used instead of implicit differentiation.
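To illustrate what the two modes compute, the following pure-JAX sketch differentiates through 3,000 unrolled gradient-descent iterates with jax.jacfwd and compares the result with the implicit-differentiation formula, which for this quadratic reduces to \(\partial x^\star(\theta) = -(A^\top A + \theta I)^{-1} x^\star(\theta)\). It mirrors the idea behind implicit_diff=False but is not JAXopt's implementation; the toy problem, step size, and iteration count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: f(x, theta) = 0.5 * ||A x - b||^2 + 0.5 * theta * ||x||^2.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([1.0, 0.0, -1.0])

def f(x, theta):
    return 0.5 * jnp.sum((A @ x - b) ** 2) + 0.5 * theta * jnp.sum(x ** 2)

def unrolled_solution(theta, num_steps=3000, step_size=0.01):
    # Plain gradient descent from zero; differentiating this function
    # propagates derivatives through every iterate (unrolling).
    def body(_, x):
        return x - step_size * jax.grad(f)(x, theta)
    return jax.lax.fori_loop(0, num_steps, body, jnp.zeros(2))

def implicit_jacobian(theta):
    # Implicit function theorem for this quadratic: dx*/dtheta = -(A^T A + theta I)^{-1} x*.
    hess = A.T @ A + theta * jnp.eye(2)
    x = jnp.linalg.solve(hess, A.T @ b)
    return -jnp.linalg.solve(hess, x)

theta = 0.5
print(jax.jacfwd(unrolled_solution)(theta))  # derivative of the unrolled iterates
print(implicit_jacobian(theta))              # derivative of the exact solution
```

The two agree only once the iterates have converged. Unrolling in reverse mode also has to store intermediate iterates, whereas the implicit route needs just the final solution and one linear solve.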
Using the ridge regression example from the unconstrained optimization section, we can write:
import jax
import jax.numpy as jnp
import jaxopt

def ridge_reg_objective(params, l2reg, X, y):
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

def ridge_reg_solution(l2reg, X, y):
    gd = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500, implicit_diff=True)
    init_params = jnp.zeros(X.shape[1])  # start from the zero vector
    return gd.run(init_params, l2reg=l2reg, X=X, y=y).params
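Because this objective is quadratic, its minimizer also has a closed form, which makes a convenient reference when checking ridge_reg_solution. Setting the gradient \(\tfrac{2}{n} X^\top (X w - y) + \lambda w\) to zero (with \(\lambda\) = l2reg) gives the linear system \(\big(\tfrac{2}{n} X^\top X + \lambda I\big) w = \tfrac{2}{n} X^\top y\). The sketch below solves it. The helper ridge_reg_closed_form and the random data are hypothetical and not part of JAXopt.

```python
import jax
import jax.numpy as jnp

def ridge_reg_objective(params, l2reg, X, y):
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

def ridge_reg_closed_form(l2reg, X, y):
    # Zero-gradient condition: ((2/n) X^T X + l2reg I) params = (2/n) X^T y.
    n, d = X.shape
    lhs = (2.0 / n) * X.T @ X + l2reg * jnp.eye(d)
    rhs = (2.0 / n) * X.T @ y
    return jnp.linalg.solve(lhs, rhs)

# Synthetic data, purely illustrative.
X = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (20,))
sol = ridge_reg_closed_form(1.0, X, y)
print(jax.grad(ridge_reg_objective)(sol, 1.0, X, y))  # should be close to zero
```

With JAXopt installed, comparing ridge_reg_solution(1.0, X, y) with sol is a reasonable sanity check. How closely they match depends on how far gradient descent gets within maxiter=500.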
Now, ridge_reg_solution is differentiable just like any other JAX function. Since ridge_reg_solution outputs a vector, we can compute its Jacobian:
print(jax.jacobian(ridge_reg_solution, argnums=0)(l2reg, X, y))
where argnums=0 specifies that we want to differentiate with respect to l2reg.
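Because l2reg is a scalar, this Jacobian has the same shape as the solution vector: one entry per coefficient. The sketch below checks that shape and compares the result with the implicit-function-theorem formula \(\partial w^\star(\lambda) = -\big(\tfrac{2}{n} X^\top X + \lambda I\big)^{-1} w^\star(\lambda)\), obtained by differentiating the zero-gradient condition in \(\lambda\). To run with JAX alone, it substitutes a hypothetical closed-form solver (ridge_reg_closed_form) for the JAXopt one. ridge_jacobian_wrt_l2reg is likewise an illustrative helper.

```python
import jax
import jax.numpy as jnp

def ridge_reg_closed_form(l2reg, X, y):
    # Hypothetical helper: exact minimizer of mean((X w - y)^2) + 0.5 * l2reg * ||w||^2.
    n, d = X.shape
    hess = (2.0 / n) * X.T @ X + l2reg * jnp.eye(d)
    return jnp.linalg.solve(hess, (2.0 / n) * X.T @ y)

def ridge_jacobian_wrt_l2reg(l2reg, X, y):
    # Differentiating hess(l2reg) @ w* = (2/n) X^T y in l2reg gives
    # w* + hess @ dw*/dl2reg = 0, so dw*/dl2reg = -hess^{-1} w*.
    n, d = X.shape
    hess = (2.0 / n) * X.T @ X + l2reg * jnp.eye(d)
    return -jnp.linalg.solve(hess, ridge_reg_closed_form(l2reg, X, y))

X = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (20,))
jac = jax.jacobian(ridge_reg_closed_form, argnums=0)(1.0, X, y)
print(jac.shape)  # (3,): one entry per coefficient, since l2reg is a scalar
print(jnp.allclose(jac, ridge_jacobian_wrt_l2reg(1.0, X, y), atol=1e-4))
```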
We can also compose ridge_reg_solution with other functions:
def validation_loss(l2reg):
    sol = ridge_reg_solution(l2reg, X_train, y_train)
    # Evaluate the fitted coefficients on held-out data.
    residuals = jnp.dot(X_val, sol) - y_val
    return jnp.mean(residuals ** 2)

print(jax.grad(validation_loss)(l2reg))
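A quick way to sanity-check such a hypergradient is to compare it with a central finite difference of the validation loss. The sketch below assumes synthetic train/validation data and replaces ridge_reg_solution with a hypothetical closed-form solver of the same objective, so it needs only JAX. The values of l2reg and eps are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def ridge_reg_closed_form(l2reg, X, y):
    # Hypothetical stand-in for ridge_reg_solution: exact minimizer of
    # mean((X w - y)^2) + 0.5 * l2reg * ||w||^2.
    n, d = X.shape
    hess = (2.0 / n) * X.T @ X + l2reg * jnp.eye(d)
    return jnp.linalg.solve(hess, (2.0 / n) * X.T @ y)

# Synthetic regression data split into train and validation sets (illustrative only).
key_x, key_w, key_noise = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(key_x, (40, 5))
w_true = jax.random.normal(key_w, (5,))
y = X @ w_true + 0.5 * jax.random.normal(key_noise, (40,))
X_train, y_train, X_val, y_val = X[:30], y[:30], X[30:], y[30:]

def validation_loss(l2reg):
    sol = ridge_reg_closed_form(l2reg, X_train, y_train)
    residuals = jnp.dot(X_val, sol) - y_val
    return jnp.mean(residuals ** 2)

l2reg, eps = 0.3, 1e-3
hypergrad = jax.grad(validation_loss)(l2reg)
# Central finite difference of the validation loss in l2reg.
finite_diff = (validation_loss(l2reg + eps) - validation_loss(l2reg - eps)) / (2 * eps)
print(hypergrad, finite_diff)
```

In single precision the finite difference is only accurate to a few digits, so a loose tolerance is appropriate when comparing the two.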
Examples
Dataset distillation (plot_dataset_distillation.py in the implicit_diff examples gallery)
Custom solvers
Decorator for adding implicit differentiation to a root solver.
Decorator for adding implicit differentiation to a fixed-point solver.
JAXopt also provides the custom_root and custom_fixed_point decorators for easily adding implicit differentiation on top of any existing solver.
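As a sketch of the same idea using only core JAX, jax.lax.custom_root also attaches implicit differentiation to an arbitrary root solver. Note that this is core JAX's function, not JAXopt's decorator. The solver is treated as a black box, and gradients come from the implicit function theorem applied to the residual function. The example computes \(\sqrt{a}\) as the root of \(x^2 - a\), so its derivative should be \(1/(2\sqrt{a})\). The Babylonian iteration and the name implicit_sqrt are illustrative choices.

```python
import jax
import jax.numpy as jnp

def implicit_sqrt(a):
    # Residual whose positive root is sqrt(a).
    f = lambda x: x ** 2 - a

    def babylonian_solve(residual, x0):
        # Black-box solver: custom_root never differentiates through these iterations.
        # It ignores `residual` and iterates x <- (x + a / x) / 2, which converges to sqrt(a).
        x = x0
        for _ in range(30):
            x = 0.5 * (x + a / x)
        return x

    # For a scalar unknown, the linearized system g(x) = y is solved by y / g(1.0).
    scalar_tangent_solve = lambda g, y: y / g(1.0)
    return jax.lax.custom_root(f, 1.0, babylonian_solve, scalar_tangent_solve)

print(implicit_sqrt(4.0))            # approximately 2.0
print(jax.grad(implicit_sqrt)(4.0))  # approximately 1 / (2 * sqrt(4.0)) = 0.25
```

Because only the residual is differentiated, the solver could be swapped for bisection or any other method without changing the gradient.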
JVPs and VJPs
Finally, we also provide lower-level routines for computing the JVPs and VJPs of roots of functions.
Jacobian-vector product of a root.
Vector-Jacobian product of a root.
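Both routines rest on the implicit function theorem. For a root \(x^\star(\theta)\) of \(F(x, \theta) = 0\), the JVP with a tangent \(v\) is \(-(\partial_x F)^{-1} \partial_\theta F \, v\). For a cotangent \(u\), the VJP is \(-(\partial_\theta F)^\top (\partial_x F)^{-\top} u\). Both are evaluated at \((x^\star(\theta), \theta)\). The sketch below implements these with dense linear solves. root_jvp_sketch and root_vjp_sketch are hypothetical helpers, not the JAXopt routines, and materializing \(\partial_x F\) as a matrix is only sensible for small problems.

```python
import jax
import jax.numpy as jnp

def root_jvp_sketch(F, x_star, theta, v):
    # JVP of theta -> x*(theta): -(dF/dx)^{-1} (dF/dtheta [v]).
    dF_dx = jax.jacobian(F, argnums=0)(x_star, theta)
    _, rhs = jax.jvp(lambda t: F(x_star, t), (theta,), (v,))
    return -jnp.linalg.solve(dF_dx, rhs)

def root_vjp_sketch(F, x_star, theta, u):
    # VJP of theta -> x*(theta): solve (dF/dx)^T w = u, then pull -w back through dF/dtheta.
    dF_dx = jax.jacobian(F, argnums=0)(x_star, theta)
    w = jnp.linalg.solve(dF_dx.T, u)
    _, pullback = jax.vjp(lambda t: F(x_star, t), theta)
    return pullback(-w)[0]

# Illustrative root problem: F(x, theta) = M x - theta, so x*(theta) = M^{-1} theta.
M = jnp.array([[3.0, 1.0], [1.0, 2.0]])
F = lambda x, theta: M @ x - theta
theta = jnp.array([1.0, -1.0])
x_star = jnp.linalg.solve(M, theta)

v = jnp.array([1.0, 0.0])
u = jnp.array([0.0, 1.0])
print(root_jvp_sketch(F, x_star, theta, v))  # should match M^{-1} v
print(root_vjp_sketch(F, x_star, theta, u))  # should match M^{-T} u
```

For large problems, the dense solves would typically be replaced by an iterative solver that only needs matrix-vector products with \(\partial_x F\).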