Perturbed optimization

The perturbed optimization module makes it possible to transform a non-smooth function, such as a max or arg-max, into a differentiable function using random perturbations. This is useful for optimization algorithms that require differentiability, such as gradient descent (see, e.g., the Notebook on perturbed optimizers).

Max perturbations

Consider a maximum function of the form:

\[F(\theta) = \max_{y \in \mathcal{C}} \langle y, \theta\rangle\,,\]

where \(\mathcal{C}\) is a convex set.

jaxopt.perturbations.make_perturbed_max(...)

Turns an argmax into a differentiable version of the max with perturbations.

The function jaxopt.perturbations.make_perturbed_max() transforms the function \(F\) into the following differentiable function using random perturbations:

\[F_{\varepsilon}(\theta) = \mathbb{E}\left[ F(\theta + \varepsilon Z) \right]\,,\]

where \(Z\) is a random variable. The distribution of this random variable can be specified through the keyword argument noise. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate.
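
Below is a minimal usage sketch, not taken from the documentation: the helper one_hot_argmax stands in for the underlying argmax over the vertices of the simplex, and the keyword names num_samples and sigma (number of Monte-Carlo samples and noise scale) are assumed here.

import jax
import jax.numpy as jnp
from jaxopt import perturbations

def one_hot_argmax(theta):
  # Argmax over the vertices of the simplex: the one-hot vector
  # selecting the largest entry of theta.
  return jax.nn.one_hot(jnp.argmax(theta), theta.shape[0])

# Smoothed, differentiable approximation of F(theta) = max_i theta_i.
pert_max = perturbations.make_perturbed_max(one_hot_argmax,
                                            num_samples=1000,
                                            sigma=0.5)

theta = jnp.array([-0.6, 1.9, -0.2, 1.7, -1.0])
rng = jax.random.PRNGKey(0)
value = pert_max(theta, rng)           # Monte-Carlo estimate of F_eps(theta)
grad = jax.grad(pert_max)(theta, rng)  # well-defined gradient w.r.t. theta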

Argmax perturbations

Consider an arg-max function of the form:

\[y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta\rangle\,,\]

where \(\mathcal{C}\) is a convex set.

The function jaxopt.perturbations.make_perturbed_argmax() transforms the function \(y^\star\) into the following differentiable function using random perturbations:

\[y_{\varepsilon}^*(\theta) = \mathbb{E}\left[ \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle \right]\,,\]

where \(Z\) is a random variable. The distribution of this random variable can be specified through the keyword argument noise. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate.

jaxopt.perturbations.make_perturbed_argmax(...)

Transforms a function into a differentiable version with perturbations.
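
A minimal sketch along the same lines (one_hot_argmax and the keyword names are the same assumptions as above): the perturbed argmax returns a point inside the simplex rather than a vertex, and gradients can be propagated through it.

import jax
import jax.numpy as jnp
from jaxopt import perturbations

def one_hot_argmax(theta):
  # Vertex of the simplex selected by theta.
  return jax.nn.one_hot(jnp.argmax(theta), theta.shape[0])

pert_argmax = perturbations.make_perturbed_argmax(one_hot_argmax,
                                                  num_samples=1000,
                                                  sigma=0.5)

theta = jnp.array([-0.6, 1.9, -0.2, 1.7, -1.0])
rng = jax.random.PRNGKey(0)
y_eps = pert_argmax(theta, rng)  # soft, simplex-valued version of y*(theta)

# The output can be used inside a differentiable loss.
loss = lambda t: jnp.sum((pert_argmax(t, rng) - jnp.ones(5) / 5.0) ** 2)
grad = jax.grad(loss)(theta)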

Scalar perturbations

Consider any function \(f\) that is not necessarily differentiable, e.g. a piecewise-constant function of the form:

\[f(\theta) = g(y^*(\theta))\,,\]

where \(y^*(\theta) = \mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta\rangle\) and \(\mathcal{C}\) is a convex set.

The function jaxopt.perturbations.make_perturbed_fun() transforms the function \(f\) into the following differentiable function using random perturbations:

\[f_{\varepsilon}(\theta) = \mathbb{E}\left[ f(\theta + \varepsilon Z) \right]\,,\]

where \(Z\) is a random variable. The distribution of this random variable can be specified through the keyword argument noise. The default is a Gumbel distribution, which is a good choice for discrete variables. For continuous variables, a normal distribution is more appropriate. This can be particularly useful in the example given above, when \(g\) is only defined on the discrete set, not on its convex hull, i.e.

\[f_{\varepsilon}(\theta) = \mathbb{E}\left[ g(\mathop{\mathrm{arg\,max}}_{y \in \mathcal{C}} \langle y, \theta + \varepsilon Z \rangle) \right]\,,\]

jaxopt.perturbations.make_perturbed_fun(fun)

Transforms a function into a differentiable version with perturbations.
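
The sketch below illustrates the piecewise-constant example above. The cost vector playing the role of \(g\) is made up for illustration, and the (theta, rng) calling convention and keyword names are assumed to match the wrappers above.

import jax
import jax.numpy as jnp
from jaxopt import perturbations

def argmax_cost(theta):
  # f(theta) = g(y*(theta)): g is a cost defined only on the one-hot
  # vertices, so f is piecewise constant in theta.
  y_star = jax.nn.one_hot(jnp.argmax(theta), theta.shape[0])
  costs = jnp.array([1.0, 3.0, 0.5, 2.0, 4.0])  # hypothetical g
  return jnp.dot(costs, y_star)

pert_cost = perturbations.make_perturbed_fun(argmax_cost,
                                             num_samples=1000,
                                             sigma=0.5)

theta = jnp.array([-0.6, 1.9, -0.2, 1.7, -1.0])
rng = jax.random.PRNGKey(0)
value = pert_cost(theta, rng)           # smoothed estimate of f(theta)
grad = jax.grad(pert_cost)(theta, rng)  # informative, unlike the gradient of f, which is zero almost everywhere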

Noise distributions

The functions jaxopt.perturbations.make_perturbed_max(), jaxopt.perturbations.make_perturbed_argmax() and jaxopt.perturbations.make_perturbed_fun() take a keyword argument noise that specifies the distribution of random perturbations. Pre-defined distributions for this argument are the following:

jaxopt.perturbations.Normal()

Normal distribution.

jaxopt.perturbations.Gumbel()

Gumbel distribution.
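
For example, normal noise can be swapped in as follows (hypothetical settings, reusing the one_hot_argmax helper from the sketches above):

import jax
import jax.numpy as jnp
from jaxopt import perturbations

def one_hot_argmax(theta):
  return jax.nn.one_hot(jnp.argmax(theta), theta.shape[0])

# Same wrapper as before, with normal noise (better suited to continuous
# variables) instead of the default Gumbel distribution.
pert_argmax_normal = perturbations.make_perturbed_argmax(
    one_hot_argmax,
    num_samples=1000,
    sigma=0.5,
    noise=perturbations.Normal(),
)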

References

Berthet, Q., Blondel, M., Teboul, O., Cuturi, M., Vert, J.-P., & Bach, F. (2020). Learning with differentiable perturbed optimizers. Advances in Neural Information Processing Systems, 33.