jaxopt.perturbations.make_perturbed_max
- jaxopt.perturbations.make_perturbed_max(argmax_fun, num_samples=1000, sigma=0.1, noise=<jaxopt._src.perturbations.Gumbel object>)
Turns an argmax into a differentiable version of the max, using perturbations.
- Parameters
argmax_fun (Callable[[Array], Array]) – the argmax function to transform into a differentiable version. Signature for argmax_fun currently supported for custom jvp and jit is: input [D1, …, Dk], output [D1, …, Dk], k >= 1.
num_samples (int) – an int, the number of perturbed outputs to average over.
sigma (float) – a float, the scale of the random perturbation.
noise – a distribution object that must implement a sample function and a log-pdf of the desired distribution, similar to the examples above. Default is Gumbel distribution.
- Returns
A function that takes the same inputs as argmax_fun, plus an rng key, and that can be differentiated.
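To make the roles of num_samples and sigma concrete, here is a minimal sketch in plain JAX of the kind of quantity a perturbed max estimates: the average, over random perturbations, of the max of the perturbed input. The helper name perturbed_max_estimate is hypothetical, and the sketch is an assumption about the general technique, not a reproduction of the library's code. In particular it only illustrates the forward value and does not set up any custom differentiation rule.

```python
import jax
import jax.numpy as jnp


def argmax_fun(x):
    # One-hot encoding of the argmax, same shape as the input.
    return jax.nn.one_hot(jnp.argmax(x), x.shape[0])


def perturbed_max_estimate(x, rng, num_samples=1000, sigma=0.1):
    """Hypothetical helper: Monte Carlo estimate of E[max(x + sigma * Z)]."""
    # Draw num_samples Gumbel noise vectors, one per perturbed copy of x.
    noise = jax.random.gumbel(rng, (num_samples,) + x.shape)
    perturbed = x + sigma * noise
    # For each perturbed input, the inner product with its one-hot argmax
    # is the max entry of that perturbed input.
    one_hots = jax.vmap(argmax_fun)(perturbed)
    max_values = jnp.sum(perturbed * one_hots, axis=-1)
    # Average over the samples.
    return jnp.mean(max_values)


x = jnp.array([0.1, 0.7, 0.2])
rng = jax.random.PRNGKey(0)
estimate = perturbed_max_estimate(x, rng, num_samples=1000, sigma=0.01)
```

With a small sigma the estimate stays close to the plain max of x. A larger sigma smooths the function more, at the cost of moving further from the true max.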
Example
Given an argmax function, such as:
import jax
import jax.numpy as jnp
from jaxopt.perturbations import make_perturbed_max

def argmax_fun(x):
  return jax.nn.one_hot(jnp.argmax(x), x.shape[0])

pert_max_fun = make_perturbed_max(argmax_fun, num_samples=200, sigma=0.01)
Then pert_max_fun is a differentiable, perturbed version of the max associated with argmax_fun. Since it relies on randomness, it requires an rng key:
pert_max = pert_max_fun(x, rng)
When handling a batched input, vmap can be applied to this function, with some care in splitting the rng key:
batch_size = x_batch.shape[0]
rngs_batch = jax.random.split(rng, batch_size)
pert_max_batch = jax.vmap(pert_max_fun)(x_batch, rngs_batch)
Furthermore, if the argmax_fun is jittable, then so is pert_max_fun.