Constrained optimization

This section is concerned with problems of the form

\[\min_{x} f(x, \theta) \textrm{ subject to } x \in \mathcal{C}(\upsilon),\]

where \(f(x, \theta)\) is differentiable (almost everywhere), \(x\) are the parameters with respect to which the function is minimized, \(\theta\) are optional additional arguments, \(\mathcal{C}(\upsilon)\) is a convex set, and \(\upsilon\) are parameters the convex set may depend on.

Projected gradient

jaxopt.ProjectedGradient(fun, projection[, ...])

Projected gradient solver.

Instantiating and running the solver

To solve constrained optimization problems, we can use projected gradient descent, which is gradient descent with an additional projection onto the constraint set. Constraints are specified by setting the projection argument. For instance, non-negativity constraints can be specified using projection_non_negative:

from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative

# ``fun`` is the objective and ``w_init`` the initial parameters; extra
# keyword arguments such as ``data`` are forwarded to ``fun`` by ``run``.
pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
pg_sol = pg.run(w_init, data=(X, y)).params
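For a self-contained illustration, here is a minimal sketch on a tiny non-negative least-squares problem; the objective fun, the synthetic data, and the initialization are illustrative assumptions, not part of the API:

import jax.numpy as jnp

from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative

def fun(w, data):
  # Least-squares objective; ``data`` is forwarded by ``run``.
  X, y = data
  return jnp.mean((X @ w - y) ** 2)

# Synthetic data for illustration only.
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, -2.0, 3.0])
w_init = jnp.zeros(X.shape[1])

pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
pg_sol = pg.run(w_init, data=(X, y)).params
print(pg_sol)  # every entry is non-negative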

Numerous projections are available; see the list below.

Specifying projection parameters

Some projections have a hyperparameter that can be specified. For instance, the hyperparameter of projection_l2_ball is the radius of the \(L_2\) ball. This can be passed using the hyperparams_proj argument of run:

from jaxopt.projection import projection_l2_ball

# The radius is forwarded to the projection at every iteration.
radius = 1.0
pg = ProjectedGradient(fun=fun, projection=projection_l2_ball)
pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
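To see the effect of the hyperparameter in isolation, the projection can also be called directly; the input vector below is hypothetical. A point of Euclidean norm 5 is rescaled onto the unit ball:

import jax.numpy as jnp
from jaxopt.projection import projection_l2_ball

# Hypothetical input; not tied to the example above.
x = jnp.array([3.0, 4.0])          # Euclidean norm 5
print(projection_l2_ball(x, 1.0))  # rescaled to norm 1: [0.6 0.8]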

Differentiation

In some applications, it is useful to differentiate the solution of the solver with respect to some hyperparameters. Continuing the previous example, we can now differentiate the solution w.r.t. radius:

import jax

def solution(radius):
  pg = ProjectedGradient(fun=fun, projection=projection_l2_ball,
                         implicit_diff=True)
  return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params

print(jax.jacobian(solution)(radius))

Under the hood, we use the implicit function theorem if implicit_diff=True (the default) and automatic differentiation of unrolled iterations if implicit_diff=False. See the implicit differentiation section for more details.
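As a sketch of the alternative mode (reusing the same hypothetical fun, w_init, data and radius as above), switching to unrolling only requires flipping the flag; the Jacobian has the same shape in both cases:

# Same hypothetical fun, w_init, X, y and radius as above.
def solution_unrolled(radius):
  pg = ProjectedGradient(fun=fun, projection=projection_l2_ball,
                         implicit_diff=False)
  return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params

print(jax.jacobian(solution_unrolled)(radius))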

Projections

The Euclidean projection onto \(\mathcal{C}(\upsilon)\) is:

\[\text{proj}_{\mathcal{C}}(x', \upsilon) := \underset{x}{\text{argmin}} ~ \|x' - x\|_2^2 \textrm{ subject to } x \in \mathcal{C}(\upsilon).\]
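For example, for the non-negative orthant this argmin has the closed form \(\max(x', 0)\), taken elementwise; a quick check with hypothetical values:

import jax.numpy as jnp
from jaxopt.projection import projection_non_negative

# Hypothetical input vector.
x = jnp.array([-1.0, 0.5, 2.0])
print(projection_non_negative(x))  # [0.  0.5 2. ]
print(jnp.maximum(x, 0.0))         # identical closed form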

The following operators are available.

jaxopt.projection.projection_non_negative(x)

Projection onto the non-negative orthant.

jaxopt.projection.projection_box(x, hyperparams)

Projection onto box constraints.

jaxopt.projection.projection_simplex(x[, value])

Projection onto a simplex.

jaxopt.projection.projection_sparse_simplex(x, ...)

Projection onto the simplex with a cardinality constraint (maximum number of non-zero elements).

jaxopt.projection.projection_l1_sphere(x[, ...])

Projection onto the l1 sphere.

jaxopt.projection.projection_l1_ball(x[, ...])

Projection onto the l1 ball.

jaxopt.projection.projection_l2_sphere(x[, ...])

Projection onto the l2 sphere.

jaxopt.projection.projection_l2_ball(x[, ...])

Projection onto the l2 ball.

jaxopt.projection.projection_linf_ball(x[, ...])

Projection onto the l-infinity ball.

jaxopt.projection.projection_hyperplane(x, ...)

Projection onto a hyperplane.

jaxopt.projection.projection_halfspace(x, ...)

Projection onto a halfspace.

jaxopt.projection.projection_affine_set(x, ...)

Projection onto an affine set.

jaxopt.projection.projection_polyhedron(x, ...)

Projection onto a polyhedron.

jaxopt.projection.projection_box_section(x, ...)

Projection onto a box section.

jaxopt.projection.projection_transport(...)

Projection onto the transportation polytope.

jaxopt.projection.projection_birkhoff(sim_matrix)

Projection onto the Birkhoff polytope, the set of doubly stochastic matrices.

Projections always have two arguments: the input to be projected and the parameters of the convex set.
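For instance, projection_box takes the box bounds as its second argument; the input and bounds below are hypothetical:

import jax.numpy as jnp
from jaxopt.projection import projection_box

# Hypothetical input and box bounds.
x = jnp.array([-2.0, 0.3, 5.0])
print(projection_box(x, (0.0, 1.0)))  # clipped to [0, 1]: [0.  0.3 1. ]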

Mirror descent

jaxopt.MirrorDescent(fun, projection_grad, ...)

Mirror descent solver.

Kullback-Leibler projections

The Kullback-Leibler projection onto \(\mathcal{C}(\upsilon)\) is:

\[\text{proj}_{\mathcal{C}}(x', \upsilon) := \underset{x}{\text{argmin}} ~ \text{KL}(x, \exp(x')) \textrm{ subject to } x \in \mathcal{C}(\upsilon).\]

The following operators are available.

jaxopt.projection.kl_projection_transport(...)

Kullback-Leibler projection onto the transportation polytope.

jaxopt.projection.kl_projection_birkhoff(...)

Kullback-Leibler projection onto the Birkhoff polytope, the set of doubly stochastic matrices.
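As a hedged sketch (the 2x2 input matrix is an arbitrary assumption), the KL projection onto the Birkhoff polytope returns a doubly stochastic matrix, which can be checked through its row and column sums:

import jax.numpy as jnp
from jaxopt.projection import kl_projection_birkhoff

# Arbitrary similarity matrix for illustration.
sim_matrix = jnp.array([[1.0, 0.5],
                        [0.2, 1.5]])
P = kl_projection_birkhoff(sim_matrix)
print(P.sum(axis=0))  # columns sum to ~1
print(P.sum(axis=1))  # rows sum to ~1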

Box constraints

For optimization with box constraints, in addition to projected gradient descent, we can use our SciPy wrapper or the native L-BFGS-B solver.

jaxopt.ScipyBoundedMinimize([method, dtype, ...])

scipy.optimize.minimize wrapper.

jaxopt.LBFGSB(fun[, value_and_grad, ...])

L-BFGS-B solver.

The following example applies non-negativity constraints, which correspond to the box constraints \([0, \infty)\):

import jax.numpy as jnp

from jaxopt import ScipyBoundedMinimize

# ``fun``, ``n_features`` and the data ``(X, y)`` are assumed defined as before.
w_init = jnp.zeros(n_features)
lbfgsb = ScipyBoundedMinimize(fun=fun, method="l-bfgs-b")
lower_bounds = jnp.zeros_like(w_init)
upper_bounds = jnp.ones_like(w_init) * jnp.inf
bounds = (lower_bounds, upper_bounds)
lbfgsb_sol = lbfgsb.run(w_init, bounds=bounds, data=(X, y)).params
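Alternatively, here is a minimal sketch of the native jaxopt.LBFGSB solver under the same assumptions; it accepts the bounds in the same (lower, upper) format through run:

from jaxopt import LBFGSB

# Reuses ``fun``, ``w_init``, ``bounds`` and ``(X, y)`` from the snippet above.
lbfgsb = LBFGSB(fun=fun)
lbfgsb_sol = lbfgsb.run(w_init, bounds=bounds, data=(X, y)).params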