jaxopt.perturbations.make_perturbed_argmax

jaxopt.perturbations.make_perturbed_argmax(argmax_fun, num_samples=1000, sigma=0.1, noise=Gumbel(), control_variate=False)

Transforms an argmax function into a differentiable version by averaging its outputs over random perturbations of its input.

Parameters
  • argmax_fun (Callable[[Array], Array]) – the argmax function to transform into a differentiable version. The signature currently supported for the custom JVP and for jit is an input of shape [D1, …, Dk] and an output of the same shape [D1, …, Dk], with k >= 1.

  • num_samples (int) – the number of perturbed outputs to average over.

  • sigma (float) – the scale of the random perturbation.

  • noise – a distribution object that must implement a sample function and the log-pdf of the desired distribution. Default is the Gumbel distribution.

  • control_variate (bool) – a Boolean indicating whether to use a control variate in the Monte-Carlo estimate of the Jacobian.
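
The parameters above combine into a Monte-Carlo average: num_samples noise draws Z are scaled by sigma, added to the input, passed through argmax_fun, and averaged. The sketch below illustrates this forward computation from scratch, assuming standard Gumbel noise; the helper perturbed_argmax_forward is hypothetical and is not jaxopt's implementation.

```python
import jax
import jax.numpy as jnp


def argmax_fun(x):
  # One-hot vector marking the largest entry of a 1-D array.
  return jax.nn.one_hot(jnp.argmax(x), x.shape[0])


def perturbed_argmax_forward(x, rng, num_samples=1000, sigma=0.1):
  # Hypothetical helper: Monte-Carlo estimate of E[argmax_fun(x + sigma * Z)]
  # with Z drawn from a standard Gumbel distribution (an assumed noise choice).
  z = jax.random.gumbel(rng, (num_samples,) + x.shape)
  outputs = jax.vmap(argmax_fun)(x + sigma * z)  # shape [num_samples, *x.shape]
  return jnp.mean(outputs, axis=0)


x = jnp.array([1.0, 1.05, 0.2])
probs = perturbed_argmax_forward(x, jax.random.PRNGKey(0))
# For Gumbel noise the exact expectation is jax.nn.softmax(x / sigma),
# so probs should be close to it, up to Monte-Carlo error.
print(probs, jax.nn.softmax(x / 0.1))
```

With Gumbel noise, the exact expectation of a one-hot argmax is jax.nn.softmax(x / sigma) (the Gumbel-max trick), which gives a convenient sanity check for the estimate.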

Returns

A differentiable function that takes the same input as argmax_fun plus an additional rng argument (a JAX PRNG key), called with the input followed by the key.
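
Differentiating the returned function amounts to differentiating an expectation. Writing f for argmax_fun, μ for the noise density, and M for num_samples (inputs flattened to vectors), a standard score-function identity gives the Jacobian below; it assumes μ is smooth and vanishes at infinity. Reading the control_variate option as subtracting c = f(x) is an assumption about the implementation rather than something this page states.

```latex
\begin{aligned}
y_\sigma(x) &= \mathbb{E}_{Z \sim \mu}\!\left[f(x + \sigma Z)\right]
  \approx \frac{1}{M}\sum_{m=1}^{M} f(x + \sigma Z_m), \\
J_x\, y_\sigma(x) &= -\frac{1}{\sigma}\,
  \mathbb{E}_{Z \sim \mu}\!\left[f(x + \sigma Z)\,\nabla_z \log \mu(Z)^{\top}\right] \\
  &\approx -\frac{1}{\sigma M}\sum_{m=1}^{M}
  \bigl(f(x + \sigma Z_m) - c\bigr)\,\nabla_z \log \mu(Z_m)^{\top},
  \qquad c = 0 \text{ or } c = f(x).
\end{aligned}
```

Because E[∇ log μ(Z)] = 0 under these assumptions, subtracting a constant c leaves the expectation unchanged and only affects the variance of the estimate.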

Example

Given an argmax function such as:

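import jax
import jax.numpy as jnp
from jaxopt.perturbations import make_perturbed_argmax

def argmax_fun(x):
  # One-hot vector marking the largest entry of a 1-D input x.
  return jax.nn.one_hot(jnp.argmax(x), x.shape[0])

pert_argmax_fun = make_perturbed_argmax(argmax_fun,
                                        num_samples=200,
                                        sigma=0.01)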

Then pert_argmax_fun is a differentiable, perturbed version of argmax_fun. Because it relies on random sampling, it also takes an rng key:

rng = jax.random.PRNGKey(0)
pert_argmax = pert_argmax_fun(x, rng)
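
The claim that pert_argmax_fun is differentiable even though argmax_fun is piecewise constant can be illustrated with a from-scratch sketch using jax.custom_jvp. It assumes Gumbel noise and a score-function Jacobian estimator; perturbed_argmax, gumbel_score, NUM_SAMPLES and SIGMA are hypothetical names, and this is not jaxopt's source. The forward pass and the JVP reuse the same rng key, so they see the same noise samples.

```python
import jax
import jax.numpy as jnp

NUM_SAMPLES = 1000  # hypothetical: number of Monte-Carlo samples
SIGMA = 0.5         # hypothetical: perturbation scale


def argmax_fun(x):
  # One-hot vector marking the largest entry of a 1-D array.
  return jax.nn.one_hot(jnp.argmax(x), x.shape[0])


def gumbel_score(z):
  # Gradient of the standard Gumbel log-density log mu(z) = -z - exp(-z).
  return -1.0 + jnp.exp(-z)


def _perturbed_outputs(x, rng):
  # Shared by the forward pass and the JVP so both see the same noise samples.
  z = jax.random.gumbel(rng, (NUM_SAMPLES,) + x.shape)
  return z, jax.vmap(argmax_fun)(x + SIGMA * z)


@jax.custom_jvp
def perturbed_argmax(x, rng):
  _, outputs = _perturbed_outputs(x, rng)
  return jnp.mean(outputs, axis=0)


@perturbed_argmax.defjvp
def perturbed_argmax_jvp(primals, tangents):
  x, rng = primals
  x_dot, _ = tangents  # the tangent of the rng key is ignored
  z, outputs = _perturbed_outputs(x, rng)  # both of shape [M, D]
  # Estimate of J @ x_dot with J = -(1/sigma) E[f(x + sigma Z) score(Z)^T].
  weights = gumbel_score(z) @ x_dot  # shape [M]
  tangent_out = -jnp.mean(outputs * weights[:, None], axis=0) / SIGMA
  return jnp.mean(outputs, axis=0), tangent_out


x = jnp.array([1.0, 1.2, 0.5])
key = jax.random.PRNGKey(0)
w = jnp.array([0.0, 1.0, 0.0])
# Gradient of <perturbed_argmax(x), w> with respect to x, although argmax_fun
# itself has zero gradient almost everywhere.
g = jax.grad(lambda v: jnp.dot(perturbed_argmax(v, key), w))(x)
print(g)
```

For Gumbel noise the exact Jacobian is (diag(y) - y y^T) / sigma with y = jax.nn.softmax(x / sigma), so the estimate g can be checked against that closed form.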

For batched inputs, jax.vmap can be applied to this function; split the rng key first so that each example in the batch gets its own key:

batch_size = x_batch.shape[0]
rngs_batch = jax.random.split(rng, batch_size)
pert_argmax_batch = jax.vmap(pert_argmax_fun)(x_batch, rngs_batch)

Furthermore, if argmax_fun is jittable, then so is pert_argmax_fun.
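
The batching and jit remarks above can be combined as in the following sketch. To keep it runnable without jaxopt, it uses a hypothetical forward-only stand-in, perturbed_argmax, with Gumbel noise; it is an illustration under that assumption, not jaxopt's code. Each row of the batch receives its own key from jax.random.split, and because argmax_fun is jittable, the vmapped function can be wrapped in jax.jit.

```python
import jax
import jax.numpy as jnp


def argmax_fun(x):
  # One-hot vector marking the largest entry of a 1-D array.
  return jax.nn.one_hot(jnp.argmax(x), x.shape[0])


def perturbed_argmax(x, rng, num_samples=500, sigma=0.1):
  # Hypothetical stand-in for pert_argmax_fun: Gumbel-perturbed Monte-Carlo average.
  z = jax.random.gumbel(rng, (num_samples,) + x.shape)
  return jnp.mean(jax.vmap(argmax_fun)(x + sigma * z), axis=0)


# argmax_fun is jittable, so the batched, perturbed function can be jitted too.
batched = jax.jit(jax.vmap(perturbed_argmax))

x_batch = jnp.array([[3.0, 0.0, 0.0],
                     [0.0, 0.0, 3.0]])
# One key per row of the batch.
rngs_batch = jax.random.split(jax.random.PRNGKey(0), x_batch.shape[0])
out = batched(x_batch, rngs_batch)  # shape [2, 3]
print(out)
```

Passing one shared key to every row instead (for example with in_axes=(0, None)) would apply identical noise samples across the batch.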