Copyright 2022 Google LLC
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Perturbed optimizers
We review in this notebook a universal method to transform any optimizer \(y^*\) into a differentiable approximation \(y_\varepsilon^*\), using perturbations following the method of Berthet et al. (2020). JAXopt provides an implementation that we illustrate here on some examples.
Concretely, for an optimizer function \(y^*\) defined by

$$y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y\in \mathcal{C}} \langle y, \theta \rangle\,,$$

we consider, for a random \(Z\) drawn from a distribution \(\mu\) with continuous, positive density, the perturbed optimizer

$$y_\varepsilon^*(\theta) = \mathbf{E}\big[\mathop{\mathrm{arg\,max}}_{y\in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle\big]\,.$$
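As a quick illustration of this definition, the expectation can be estimated by plain Monte-Carlo averaging. The helper below is a hand-rolled sketch (its name and the choice of Gumbel noise for \(Z\) are ours, for illustration only); the JAXopt function perturbations.make_perturbed_argmax used later does this for us and also defines the corresponding derivative rules.

import jax
import jax.numpy as jnp

def monte_carlo_perturbed_argmax(argmax_fun, theta, rng,
                                 sigma=0.5, num_samples=1000):
  # Estimate y_eps(theta) = E[argmax_fun(theta + sigma * Z)] by averaging the
  # optimizer over num_samples noisy copies of theta (here Z is i.i.d. Gumbel).
  z = jax.random.gumbel(rng, shape=(num_samples,) + theta.shape)
  perturbed_outputs = jax.vmap(argmax_fun)(theta[jnp.newaxis] + sigma * z)
  return jnp.mean(perturbed_outputs, axis=0)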
%%capture
%pip install jaxopt

# Activate TPUs if available. Recent JAX versions (>= 0.4.0) no longer support
# the legacy Colab TPU runtime: setup_tpu then raises a RuntimeError, which we
# also catch so the notebook falls back to CPU/GPU.
try:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
except (KeyError, RuntimeError):
  print("TPU not found, continuing without it.")
import jax
import jax.numpy as jnp
import jaxopt
import time
from jaxopt import perturbations
Argmax one-hot
We consider an optimizer, such as the following argmax_one_hot function. It transforms a real-valued vector into a binary vector with a 1 at the coefficient with the largest value and 0 elsewhere. It corresponds to \(y^*\) for \(\mathcal{C}\) being the unit simplex. We run it on an example input values.
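To see why this corresponds to \(y^*\), recall that a linear function over the unit simplex \(\Delta^{n-1} = \{y \in \mathbb{R}^n : y \ge 0,\ \sum_i y_i = 1\}\) attains its maximum at a vertex, i.e. at a standard basis vector, so that (ignoring ties)

$$y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \Delta^{n-1}} \langle y, \theta \rangle = e_{i^*} \quad \text{with} \quad i^* = \mathop{\mathrm{arg\,max}}_i \theta_i\,.$$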
One-hot function
def argmax_one_hot(x, axis=-1):
  return jax.nn.one_hot(jnp.argmax(x, axis=axis), x.shape[axis])
values = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])
one_hot_vec = argmax_one_hot(values)
print(one_hot_vec)
[0. 1. 0. 0. 0.]
One-hot with perturbations
Our implementation transforms the argmax_one_hot function into a perturbed one that we call pert_one_hot. In this case we use Gumbel noise for the perturbation.
N_SAMPLES = 100_000
SIGMA = 0.5
GUMBEL = perturbations.Gumbel()
rng = jax.random.PRNGKey(1)
pert_one_hot = perturbations.make_perturbed_argmax(
    argmax_fun=argmax_one_hot,
    num_samples=N_SAMPLES,
    sigma=SIGMA,
    noise=GUMBEL)
In this particular case, it is equal to the usual softmax function. This is not always true; in general, there is no closed form for \(y_\varepsilon^*\).
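The closed form here follows from the standard Gumbel-max argument: with i.i.d. Gumbel noise, each coordinate of the expectation is the probability that the corresponding perturbed coordinate is the largest,

$$y_\varepsilon^*(\theta)_i = \mathbf{P}\Big(\mathop{\mathrm{arg\,max}}_j\, (\theta_j + \varepsilon Z_j) = i\Big) = \frac{\exp(\theta_i/\varepsilon)}{\sum_j \exp(\theta_j/\varepsilon)} = \mathrm{softmax}(\theta/\varepsilon)_i\,,$$

where \(\varepsilon\) plays the role of SIGMA below.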
rngs = jax.random.split(rng, 2)
rng = rngs[0]
pert_argmax = pert_one_hot(values, rng)
print(f'computation with {N_SAMPLES} samples, sigma = {SIGMA}')
print(f'perturbed argmax = {pert_argmax}')
soft_max = jax.nn.softmax(values/SIGMA)
print(f'softmax = {soft_max}')
print(f'norm of softmax = {jnp.linalg.norm(soft_max):.2e}')
print(f'norm of difference = {jnp.linalg.norm(pert_argmax - soft_max):.2e}')
computation with 100000 samples, sigma = 0.5
perturbed argmax = [0.0055 0.81842 0.01212 0.16157 0.00239]
softmax = [0.00549293 0.8152234 0.01222475 0.16459078 0.00246813]
norm of softmax = 8.32e-01
norm of difference = 4.40e-03
Gradients for one-hot with perturbations
The perturbed optimizer \(y_\varepsilon^*\) is differentiable, and its gradient can be computed with stochastic estimation automatically, using jax.grad.
We create a scalar loss \(\ell\), implemented as loss_simplex, of the perturbed optimizer \(y^*_\varepsilon\). For values equal to a vector \(\theta\), we can compute the gradients of \(\ell(y^*_\varepsilon(\theta))\) with respect to values automatically.
# Example loss function
def loss_simplex(values, rng):
  n = values.shape[0]
  v_true = jnp.arange(n) + 2
  y_true = v_true / jnp.sum(v_true)
  y_pred = pert_one_hot(values, rng)
  return jnp.sum((y_true - y_pred) ** 2)
loss_simplex(values, rngs[1])
Array(0.5865911, dtype=float32)
We can compute the gradient of \(\ell\) directly. The computation of the Jacobian \(\partial_\theta y^*_\varepsilon(\theta)\) is implemented automatically, using an estimation method given by Berthet et al. (2020) [Prop. 3.1].
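For instance, the full Jacobian can also be materialized with jax.jacobian; it is a stochastic estimate, so its entries vary with the random key and the number of samples.

# Stochastic estimate of the Jacobian of the perturbed one-hot at `values`.
jacobian_estimate = jax.jacobian(pert_one_hot)(values, rngs[1])
print(jacobian_estimate.shape)  # (5, 5): output dimension x input dimension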
# Gradient of the loss w.r.t input values
gradient = jax.grad(loss_simplex)(values, rngs[1])
print(gradient)
[-0.02052322 0.46736273 -0.02747887 -0.39873555 -0.00571656]
We illustrate the use of this method by running 200 steps of gradient descent on \(\theta_t\) so that it minimizes this loss.
# Doing 200 steps of gradient descent on the values to match the target vector
steps = 200
values_t = values
eta = 0.5
grad_func = jax.jit(jax.grad(loss_simplex))
for t in range(steps):
  rngs = jax.random.split(rngs[1], 2)
  values_t = values_t - eta * grad_func(values_t, rngs[1])
rngs = jax.random.split(rngs[1], 2)
n = values.shape[0]
v_true = jnp.arange(n) + 2
y_true = v_true / jnp.sum(v_true)
print(f'initial values = {values}')
print(f'initial one-hot = {argmax_one_hot(values)}')
print(f'initial diff. one-hot = {pert_one_hot(values, rngs[0])}')
print()
print(f'values after GD = {values_t}')
print(f'one-hot after GD = {argmax_one_hot(values_t)}')
print(f'diff. one-hot after GD = {pert_one_hot(values_t, rngs[1])}')
print(f'target diff. one-hot = {y_true}')
initial values = [-0.6 1.9 -0.2 1.1 -1. ]
initial one-hot = [0. 1. 0. 0. 0.]
initial diff. one-hot = [0.0056 0.81554997 0.01239 0.16398999 0.00247 ]
values after GD = [-0.07073233 0.13270897 0.2768847 0.38671777 0.4782017 ]
one-hot after GD = [0. 0. 0. 0. 1.]
diff. one-hot after GD = [0.09843 0.15089999 0.19826 0.25197 0.30043998]
target diff. one-hot = [0.1 0.15 0.2 0.25 0.3 ]
Differentiable ranking
Ranking function
We consider an optimizer, such as the following ranking function. It transforms a real-valued vector of size \(n\) into a vector whose coefficients are a permutation of \(\{0,\ldots, n-1\}\), giving the rank of each coefficient of the original vector. It corresponds to \(y^*\) for \(\mathcal{C}\) being the permutahedron. We run it on an example input values.
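This correspondence can be checked directly: for \(\mathcal{C}\) the permutahedron, i.e. the convex hull of all permutations of \((0, 1, \ldots, n-1)\), the rearrangement inequality shows that \(\langle y, \theta \rangle\) is maximized by assigning the largest weight \(n-1\) to the largest coefficient of \(\theta\), the next weight to the next largest, and so on, so that (ignoring ties)

$$y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta \rangle = \mathrm{ranks}(\theta)\,, \qquad \mathrm{ranks}(\theta)_i = \#\{j : \theta_j < \theta_i\}\,.$$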
# Function outputting a vector of ranks
def ranking(values):
  return jnp.argsort(jnp.argsort(values))
# Example on random values
n = 6
rng = jax.random.PRNGKey(0)
values = jax.random.normal(rng, (n,))
print(f'values = {values}')
print(f'ranking = {ranking(values)}')
values = [ 0.18784384 -1.2833426 0.6494181 1.2490593 0.24447003 -0.11744965]
ranking = [2 0 4 5 3 1]
Ranking with perturbations
As above, our implementation transforms this function into a perturbed one that we call pert_ranking
. In this case we use Gumbel noise for the perturbation.
N_SAMPLES = 100
SIGMA = 0.2
GUMBEL = perturbations.Gumbel()
pert_ranking = perturbations.make_perturbed_argmax(
    ranking,
    num_samples=N_SAMPLES,
    sigma=SIGMA,
    noise=GUMBEL)
# Expectation of the perturbed ranks on these values
rngs = jax.random.split(rng, 2)
diff_ranks = pert_ranking(values, rngs[0])
print(f'values = {values}')
print(f'diff_ranks = {diff_ranks}')
values = [ 0.18784384 -1.2833426 0.6494181 1.2490593 0.24447003 -0.11744965]
diff_ranks = [2.37 0.02 3.85 4.96 2.4099998 1.39 ]
Gradients for ranking with perturbations
As above, the perturbed optimizer \(y_\varepsilon^*\) is differentiable, and its gradient can be computed with stochastic estimation automatically, using jax.grad.

We showcase this on a loss of \(y_\varepsilon^*(\theta)\) that can be directly differentiated w.r.t. the values equal to \(\theta\).
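Concretely, the loss used below compares the perturbed ranks to the target ranking of an already-sorted vector:

$$\ell(\theta) = \|y_\varepsilon^*(\theta) - y_{\text{true}}\|^2\,, \qquad y_{\text{true}} = \mathrm{ranks}((0, 1, \ldots, n-1)) = (0, 1, \ldots, n-1)\,.$$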
# Example loss function
def loss_example(values, rng):
  n = values.shape[0]
  y_true = ranking(jnp.arange(n))
  y_pred = pert_ranking(values, rng)
  return jnp.sum((y_true - y_pred) ** 2)
print(loss_example(values, rngs[1]))
28.9336
# Gradient of the objective w.r.t input values
gradient = jax.grad(loss_example)(values, rngs[1])
print(gradient)
[ 14.139462 -2.6558158 -19.498537 16.295418 -2.338868 -21.901724 ]
As above, we showcase this example on gradient descent to minimize this loss.
steps = 20
values_t = values
eta = 0.1
grad_func = jax.jit(jax.grad(loss_example))
for t in range(steps):
  rngs = jax.random.split(rngs[1], 2)
  values_t = values_t - eta * grad_func(values_t, rngs[1])
rngs = jax.random.split(rngs[1], 2)
y_true = ranking(jnp.arange(n))
print(f'initial values = {values}')
print(f'initial ranks = {ranking(values)}')
print(f'initial diff. ranks = {pert_ranking(values, rngs[0])}')
print()
print(f'values after GD = {values_t}')
print(f'ranks after GD = {ranking(values_t)}')
print(f'diff. ranks after GD = {pert_ranking(values_t, rngs[1])}')
print(f'target diff. ranks = {y_true}')
initial values = [ 0.18784384 -1.2833426 0.6494181 1.2490593 0.24447003 -0.11744965]
initial ranks = [2 0 4 5 3 1]
initial diff. ranks = [2.44 0. 3.79 4.98 2.51 1.28]
values after GD = [-2.9923885 -1.9453204 -1.259742 -0.69805354 0.33311206 1.7650208 ]
ranks after GD = [0 1 2 3 4 5]
diff. ranks after GD = [0. 1.01 2.05 2.95 3.99 5. ]
target diff. ranks = [0 1 2 3 4 5]