jaxopt.perturbations.make_perturbed_fun
- jaxopt.perturbations.make_perturbed_fun(fun, num_samples=1000, sigma=0.1, noise=<jaxopt._src.perturbations.Gumbel object>)
Transforms a function into a differentiable version by averaging it over random perturbations of its input.
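As a sketch of what the transformed function estimates, assume the standard perturbation construction: noise Z with density p(z) ∝ exp(-ν(z)) is scaled by sigma and added to the input, and the expectation is replaced by an average over M = num_samples draws. The requirement that noise supply a log-pdf is consistent with this construction, but the exact estimator inside jaxopt is not spelled out here. Substituting y = x + σz and differentiating the density instead of f gives:

```latex
F_\sigma(x) = \mathbb{E}_{Z \sim p}\!\left[ f(x + \sigma Z) \right]
  \approx \frac{1}{M} \sum_{m=1}^{M} f\!\left(x + \sigma Z^{(m)}\right),
\qquad
J_{F_\sigma}(x) = \frac{1}{\sigma}\, \mathbb{E}_{Z \sim p}\!\left[ f(x + \sigma Z)\, \nabla \nu(Z)^{\top} \right],
\qquad \nu(z) = -\log p(z).
```

The Jacobian expression uses only evaluations of f, never its derivative, which is why f itself may be non-smooth or piecewise constant.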
- Parameters
  - fun (Callable[[Array], float]) – the function to transform into a differentiable version. The signature currently supported for custom JVP and jit takes an input of shape [D_1, …, D_k] and returns an output of shape [R_1, …, R_r].
  - num_samples (int) – the number of perturbed outputs to average over.
  - sigma (float) – the scale of the random perturbation.
  - noise – a distribution object that must implement a sample function and a log-pdf of the desired distribution, like the default Gumbel distribution class in jaxopt.perturbations. Default is the Gumbel distribution.
- Returns
A function that takes the same input as fun plus an rng key as an extra final argument, and that can be differentiated.
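The noise argument only needs an object that can draw samples and evaluate a log-density. Below is a minimal sketch of such an object for standard logistic noise. The class name Logistic is hypothetical, and the method names sample(seed, sample_shape) and log_prob(inputs) are an assumption modelled on the default Gumbel class, so check them against the jaxopt source before relying on them.

```python
import jax
import jax.numpy as jnp


class Logistic:
  """Hypothetical standard-logistic noise object (not part of jaxopt).

  ASSUMPTION: the interface expected for `noise` is a `sample(seed,
  sample_shape)` method and a `log_prob(inputs)` method.
  """

  def sample(self, seed, sample_shape):
    # Draw i.i.d. standard logistic noise with the requested shape.
    return jax.random.logistic(seed, shape=sample_shape)

  def log_prob(self, inputs):
    # Elementwise log-density: log p(z) = -z - 2 * log(1 + exp(-z)),
    # written with softplus for numerical stability.
    return -inputs - 2.0 * jax.nn.softplus(-inputs)


noise = Logistic()
z = noise.sample(seed=jax.random.PRNGKey(0), sample_shape=(4, 3))
log_density = noise.log_prob(z)
```

If that interface assumption holds, the object would be passed as noise=Logistic() when calling make_perturbed_fun.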
Example
Given a non-smooth function such as the ReLU below:

```python
import jax
from jaxopt.perturbations import make_perturbed_fun

def fun(x):
  return jax.nn.relu(x)

pert_fun = make_perturbed_fun(fun, num_samples=200, sigma=0.01)
```
pert_fun is then a differentiable, perturbed version of fun. Since it relies on randomness, it requires an rng key:

```python
rng = jax.random.PRNGKey(0)  # any JAX PRNG key works here
pert_output = pert_fun(x, rng)
```
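To make "differentiable, perturbed version" concrete, here is a minimal pure-JAX sketch of the forward estimate and a score-function gradient estimate. It is not jaxopt's implementation: it assumes standard normal noise (instead of the Gumbel default), a scalar-valued fun, and a hypothetical helper named perturbed_value_and_grad.

```python
import jax
import jax.numpy as jnp


def perturbed_value_and_grad(fun, x, rng, num_samples=1000, sigma=0.1):
  """Hypothetical sketch: Monte Carlo perturbed value and gradient.

  Uses standard normal noise, for which grad(-log p)(z) = z, so the
  gradient estimate needs only evaluations of `fun`, not its derivative.
  """
  # One noise draw per Monte Carlo sample, shaped like the input.
  z = jax.random.normal(rng, (num_samples,) + x.shape)
  # Evaluate the scalar-valued fun on every perturbed input.
  values = jax.vmap(fun)(x + sigma * z)
  # Monte Carlo estimate of E[fun(x + sigma * Z)].
  value = jnp.mean(values)
  # Score-function estimate: E[fun(x + sigma * Z) * Z] / sigma.
  weights = values.reshape((num_samples,) + (1,) * x.ndim)
  grad = jnp.mean(weights * z, axis=0) / sigma
  return value, grad


def fun(x):
  # Non-smooth at 0: the exact derivative jumps from 0 to 1.
  return jnp.sum(jax.nn.relu(x))


x = jnp.zeros(())
value, grad = perturbed_value_and_grad(
    fun, x, jax.random.PRNGKey(0), num_samples=100_000)
```

At x = 0, where ReLU has a kink, the gradient estimate should land near 0.5 (the probability that the perturbation is positive) and the value near sigma/sqrt(2π) ≈ 0.04.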
For a batched input, apply jax.vmap to pert_fun, splitting the rng key so that each example receives its own key:

```python
batch_size = x_batch.shape[0]
rngs_batch = jax.random.split(rng, batch_size)
pert_batch = jax.vmap(pert_fun)(x_batch, rngs_batch)
```
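The batching pattern can be exercised without jaxopt by using a stand-in with the same (x, rng) signature. In this sketch perturbed_relu, NUM_SAMPLES and SIGMA are hypothetical names; the stand-in draws Gumbel noise to match the documented default but is not jaxopt's code.

```python
import jax
import jax.numpy as jnp

NUM_SAMPLES = 200  # hypothetical, mirrors num_samples in the example
SIGMA = 0.01       # hypothetical, mirrors sigma in the example


def perturbed_relu(x, rng):
  """Hypothetical stand-in for pert_fun with the same (x, rng) signature."""
  # Gumbel noise, matching the documented default distribution.
  z = jax.random.gumbel(rng, (NUM_SAMPLES,) + x.shape)
  return jnp.mean(jax.nn.relu(x + SIGMA * z), axis=0)


rng = jax.random.PRNGKey(0)
x_batch = jnp.linspace(-1.0, 1.0, 8).reshape(4, 2)

# One independent key per example, so examples do not share noise draws.
batch_size = x_batch.shape[0]
rngs_batch = jax.random.split(rng, batch_size)
pert_batch = jax.vmap(perturbed_relu)(x_batch, rngs_batch)
```

Passing the same key to every example would make all rows reuse identical noise; splitting keeps the per-example estimates independent.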
Further, if fun is jittable, then so is pert_fun.
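A small sketch of why jittability carries over: a perturbed wrapper built only from a jittable fun, random sampling, and a mean can itself be passed to jax.jit. The wrapper perturbed is hypothetical (normal noise, not jaxopt's code); num_samples is marked static because it fixes the shape of the noise array.

```python
import jax
import jax.numpy as jnp


def fun(x):
  # A jittable (but non-smooth) function.
  return jax.nn.relu(x)


def perturbed(x, rng, num_samples=200, sigma=0.01):
  """Hypothetical perturbed wrapper around `fun` (not jaxopt's code)."""
  # num_samples determines an array shape, so under jit it must be static.
  z = jax.random.normal(rng, (num_samples,) + x.shape)
  return jnp.mean(jax.vmap(fun)(x + sigma * z), axis=0)


perturbed_jit = jax.jit(perturbed, static_argnames="num_samples")

x = jnp.array([-0.5, 0.0, 0.5])
rng = jax.random.PRNGKey(0)
eager_out = perturbed(x, rng)
jit_out = perturbed_jit(x, rng)
```

With the same key, the jitted and eager calls draw the same noise, so their outputs should agree up to floating-point rounding.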