Perturbed optimization
The perturbed optimization module transforms a non-smooth function, such as a max or an arg-max, into a differentiable function using random perturbations (Berthet et al., 2020). This is useful for optimization algorithms that require differentiability, such as gradient descent (see, e.g., the notebook on perturbed optimizers).
Max perturbations
Consider a maximum function of the form:
\[ F(\theta) = \max_{y \in \mathcal{C}} \langle y, \theta \rangle, \]
where \(\mathcal{C}\) is a convex set.
jaxopt.perturbations.make_perturbed_max(): turns an arg-max function into a differentiable version of the max, using perturbations.
The function jaxopt.perturbations.make_perturbed_max() transforms the function \(F\) into the following differentiable function using random perturbations:
\[ F_{\varepsilon}(\theta) = \mathbb{E}\left[ F(\theta + \varepsilon Z) \right], \]
where \(Z\) is a random variable and \(\varepsilon > 0\) controls the magnitude of the perturbation. The distribution of this random variable can be specified through the keyword argument noise. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate.
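To make the construction concrete, here is a minimal JAX sketch (not the jaxopt implementation) under a few assumptions: \(\mathcal{C}\) is the unit simplex, so \(F(\theta) = \max_i \theta_i\); \(Z\) is standard Gumbel noise; \(\varepsilon = 0.5\); and the expectation is replaced by a Monte Carlo average. The helper names perturbed_max and gumbel_max_closed_form are hypothetical. For Gumbel noise on the simplex, the perturbed max has the closed form \(\varepsilon \log \sum_i e^{\theta_i/\varepsilon} + \varepsilon\gamma\), where \(\gamma\) is the Euler-Mascheroni constant, which gives a sanity check for the estimate.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

EULER_GAMMA = 0.5772156649015329  # mean of a standard Gumbel variable


def perturbed_max(theta, key, num_samples=10_000, epsilon=0.5):
    """Monte Carlo estimate of F_eps(theta) = E[F(theta + epsilon * Z)].

    Assumes C is the unit simplex, so F(theta) = max_i theta_i,
    and Z is standard Gumbel noise.
    """
    z = jax.random.gumbel(key, (num_samples,) + theta.shape)
    return jnp.mean(jnp.max(theta + epsilon * z, axis=-1))


def gumbel_max_closed_form(theta, epsilon=0.5):
    """Exact F_eps for Gumbel noise on the simplex: eps * logsumexp(theta / eps) + eps * gamma."""
    return epsilon * logsumexp(theta / epsilon) + epsilon * EULER_GAMMA


theta = jnp.array([1.0, 0.5, -0.3])
key = jax.random.PRNGKey(0)
value = perturbed_max(theta, key)
# Each sampled max is piecewise linear in theta, so jax.grad returns
# the average of the one-hot arg-max vectors of the perturbed inputs.
grad = jax.grad(perturbed_max)(theta, key)
```

The gradient of the perturbed max is the perturbed arg-max (Berthet et al., 2020); with Gumbel noise on the simplex it equals the softmax of \(\theta/\varepsilon\), so grad above is a Monte Carlo estimate of jax.nn.softmax(theta / 0.5).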
Argmax perturbations
Consider an arg-max function of the form:
\[ y^\star(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta \rangle, \]
where \(\mathcal{C}\) is a convex set.
The function jaxopt.perturbations.make_perturbed_argmax() transforms the function \(y^\star\) into the following differentiable function using random perturbations:
\[ y^\star_{\varepsilon}(\theta) = \mathbb{E}\left[ \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle \right], \]
where \(Z\) is a random variable and \(\varepsilon > 0\) controls the magnitude of the perturbation. The distribution of this random variable can be specified through the keyword argument noise. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate.
jaxopt.perturbations.make_perturbed_argmax(): transforms a function into a differentiable version with perturbations.
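A minimal JAX sketch of the perturbed arg-max follows, assuming (as an illustration, not the jaxopt implementation) that \(\mathcal{C}\) is the unit simplex, so \(y^\star(\theta)\) is a one-hot vertex, and that \(Z\) is standard Gumbel noise with \(\varepsilon = 0.5\). Averaging one-hot vectors gives a point inside the simplex, but each sample is piecewise constant in \(\theta\), so differentiating the average with jax.grad would return zero. The sketch therefore estimates the Jacobian with the formula \(J = \mathbb{E}\left[ y^\star(\theta + \varepsilon Z)\, \nabla \nu(Z)^\top \right] / \varepsilon\), where \(\nu = -\log p\) is the negative log-density of \(Z\) (Berthet et al., 2020). The function names are hypothetical. By the Gumbel-max trick, the exact answer here is \(p = \mathrm{softmax}(\theta/\varepsilon)\) with Jacobian \((\mathrm{diag}(p) - p p^\top)/\varepsilon\).

```python
import jax
import jax.numpy as jnp


def argmax_one_hot(theta):
    """Unperturbed arg-max over the unit simplex: a one-hot vertex (batched on the last axis)."""
    return jax.nn.one_hot(jnp.argmax(theta, axis=-1), theta.shape[-1])


def perturbed_argmax_and_jacobian(theta, key, num_samples=50_000, epsilon=0.5):
    """Monte Carlo estimates of y*_eps(theta) and its Jacobian under Gumbel noise."""
    z = jax.random.gumbel(key, (num_samples,) + theta.shape)
    y = argmax_one_hot(theta + epsilon * z)          # shape (num_samples, d), one-hot rows
    value = jnp.mean(y, axis=0)                      # lies in the interior of the simplex
    # Gumbel noise: nu(z) = z + exp(-z), hence grad nu(z) = 1 - exp(-z).
    grad_nu = 1.0 - jnp.exp(-z)                      # shape (num_samples, d)
    # jacobian[i, j] estimates d y*_eps_i / d theta_j.
    jacobian = jnp.einsum("ni,nj->ij", y, grad_nu) / (num_samples * epsilon)
    return value, jacobian


theta = jnp.array([1.0, 0.5, -0.3])
value, jac = perturbed_argmax_and_jacobian(theta, jax.random.PRNGKey(1))
```

The Jacobian estimate is noisy: more samples reduce its variance, and so does a larger \(\varepsilon\), at the price of a smoother surrogate that is further from the original arg-max.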
Scalar perturbations
Consider any function \(f\) that is not necessarily differentiable, e.g. a piecewise-constant function of the form:
\[ f(\theta) = g(y^\star(\theta)), \]
where \(y^\star(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta\rangle\), \(\mathcal{C}\) is a convex set, and \(g\) is an arbitrary function.
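The function jaxopt.perturbations.make_perturbed_fun() transforms the function \(f\) into the following differentiable function using random perturbations:
\[ f_{\varepsilon}(\theta) = \mathbb{E}\left[ f(\theta + \varepsilon Z) \right], \]
where \(Z\) is a random variable and \(\varepsilon > 0\) controls the magnitude of the perturbation. The distribution of this random variable can be specified through the keyword argument noise. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate. This can be particularly useful in the example given above when \(g\) is only defined on the discrete set of points that the arg-max can return (e.g. the vertices of \(\mathcal{C}\)), not on their convex hull, so that it cannot be composed with a perturbed arg-max. Perturbing \(f\) directly avoids this problem, i.e.
\[ f_{\varepsilon}(\theta) = \mathbb{E}\left[ g\left( \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle \right) \right]. \]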
jaxopt.perturbations.make_perturbed_fun(): transforms a function into a differentiable version with perturbations.
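As a sketch of why perturbing \(f\) helps, assume (hypothetically) that \(\mathcal{C}\) is the unit simplex and that \(g\) is a lookup table of per-vertex costs, so \(f(\theta)\) is piecewise constant and its gradient is zero almost everywhere. The gradient of \(f_{\varepsilon}\) can still be estimated without differentiating \(f\), via \(\nabla f_{\varepsilon}(\theta) = \mathbb{E}\left[ f(\theta + \varepsilon Z)\, \nabla \nu(Z) \right] / \varepsilon\), where \(\nu = -\log p\) is the negative log-density of \(Z\) (Berthet et al., 2020). The cost vector COSTS, the helper names, Gumbel noise and \(\varepsilon = 0.5\) are illustrative assumptions; this is not how jaxopt implements make_perturbed_fun. With Gumbel noise on the simplex, \(f_{\varepsilon}(\theta) = \langle \mathrm{softmax}(\theta/\varepsilon), \mathrm{COSTS} \rangle\) exactly, which gives a reference value.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-vertex costs: g(y) = COSTS[index of vertex y], so g is
# only defined on the vertices (one-hot vectors) of the unit simplex.
COSTS = jnp.array([3.0, -1.0, 2.0])


def f(theta):
    """Piecewise-constant f(theta) = g(y*(theta)); also works on batches of theta."""
    return COSTS[jnp.argmax(theta, axis=-1)]


def perturbed_fun_and_grad(theta, key, num_samples=50_000, epsilon=0.5):
    """Monte Carlo estimates of f_eps(theta) and its gradient under Gumbel noise."""
    z = jax.random.gumbel(key, (num_samples,) + theta.shape)
    values = f(theta + epsilon * z)                  # shape (num_samples,)
    # Gumbel noise: nu(z) = z + exp(-z), hence grad nu(z) = 1 - exp(-z).
    grad_nu = 1.0 - jnp.exp(-z)                      # shape (num_samples, d)
    value = jnp.mean(values)
    grad = jnp.mean(values[:, None] * grad_nu, axis=0) / epsilon
    return value, grad


theta = jnp.array([1.0, 0.5, -0.3])
value, grad = perturbed_fun_and_grad(theta, jax.random.PRNGKey(2))
```

Here jax.grad(f) is identically zero, while grad above is a nonzero estimate of the gradient of the smoothed function.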
Noise distributions
The functions jaxopt.perturbations.make_perturbed_max(), jaxopt.perturbations.make_perturbed_argmax() and jaxopt.perturbations.make_perturbed_fun() take a keyword argument noise that specifies the distribution of random perturbations. Pre-defined distributions for this argument are the following:
jaxopt.perturbations.Normal: normal distribution.
jaxopt.perturbations.Gumbel: Gumbel distribution.
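In the perturbation framework, a noise distribution is used through two ingredients: samples of \(Z\), and the gradient of its negative log-density \(\nu = -\log p\), which enters Monte Carlo gradient estimates such as \(\nabla_\theta \mathbb{E}[h(\theta + \varepsilon Z)] = \mathbb{E}[h(\theta + \varepsilon Z)\, \nabla \nu(Z)] / \varepsilon\) (Berthet et al., 2020). The sketch below defines two hypothetical noise classes, NormalNoise and GumbelNoise; these names and their sample and grad_nu methods are illustrative assumptions, not the API of jaxopt.perturbations.Normal or jaxopt.perturbations.Gumbel. It checks each one with Stein's identity, \(\mathbb{E}[\nabla \nu(Z)] = 0\) and \(\mathbb{E}[\nabla \nu(Z)\, Z] = 1\) per coordinate, which holds for both distributions.

```python
import jax
import jax.numpy as jnp


class NormalNoise:
    """Standard normal noise: nu(z) = z**2 / 2 + const, so grad nu(z) = z."""

    @staticmethod
    def sample(key, shape):
        return jax.random.normal(key, shape)

    @staticmethod
    def grad_nu(z):
        return z


class GumbelNoise:
    """Standard Gumbel noise: nu(z) = z + exp(-z), so grad nu(z) = 1 - exp(-z)."""

    @staticmethod
    def sample(key, shape):
        return jax.random.gumbel(key, shape)

    @staticmethod
    def grad_nu(z):
        return 1.0 - jnp.exp(-z)


def stein_check(noise, key, num_samples=200_000):
    """Estimates E[grad nu(Z)] (expected 0) and E[grad nu(Z) * Z] (expected 1)."""
    z = noise.sample(key, (num_samples,))
    g = noise.grad_nu(z)
    return jnp.mean(g), jnp.mean(g * z)


for noise in (NormalNoise, GumbelNoise):
    mean_g, mean_gz = stein_check(noise, jax.random.PRNGKey(3))
    print(noise.__name__, float(mean_g), float(mean_gz))
```

Swapping one distribution for the other only changes the sampler and grad_nu, which is why both can be offered behind a single noise argument.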
References
Berthet, Q., Blondel, M., Teboul, O., Cuturi, M., Vert, J.-P., & Bach, F. (2020). Learning with differentiable perturbed optimizers. Advances in Neural Information Processing Systems, 33.