Constrained optimization
This section is concerned with problems of the form

\[\min_{x \in \mathcal{C}(\upsilon)} f(x, \theta),\]

where \(f(x, \theta)\) is differentiable (almost everywhere), \(x\) are the parameters with respect to which the function is minimized, \(\theta\) are optional additional arguments, \(\mathcal{C}(\upsilon)\) is a convex set, and \(\upsilon\) are parameters the convex set may depend on.
Projected gradient
jaxopt.ProjectedGradient | Projected gradient solver.
Instantiating and running the solver
To solve constrained optimization problems, we can use projected gradient
descent, which is gradient descent with an additional projection onto the
constraint set. Constraints are specified by setting the projection
argument. For instance, non-negativity constraints can be specified using
projection_non_negative:
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative
pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
pg_sol = pg.run(w_init, data=(X, y)).params
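The snippet above assumes that fun, w_init and the data (X, y) are already defined. For context, here is a minimal end-to-end sketch, assuming a least-squares objective and synthetic data; the definitions of fun, X, y and w_init are illustrative only:

import jax.numpy as jnp
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative

def fun(w, data):
  # Least-squares objective: 0.5 * ||X w - y||^2.
  X, y = data
  residuals = jnp.dot(X, w) - y
  return 0.5 * jnp.sum(residuals ** 2)

# Synthetic data, for illustration only.
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w_init = jnp.zeros(X.shape[1])

pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
pg_sol = pg.run(w_init, data=(X, y)).params  # All entries of the solution are >= 0.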
Numerous projections are available; see the Projections section below.
Specifying projection parameters
Some projections have a hyperparameter that can be specified. For
instance, the hyperparameter of projection_l2_ball is the radius of the
\(L_2\) ball. This can be passed using the hyperparams_proj argument of run:
from jaxopt.projection import projection_l2_ball
radius = 1.0
pg = ProjectedGradient(fun=fun, projection=projection_l2_ball)
pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
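Hyperparameters may also be tuples or pytrees. As a sketch, projection_box takes a pair of (lower, upper) bounds, which can be passed the same way (reusing the fun, w_init and data assumed above):

from jaxopt import ProjectedGradient
from jaxopt.projection import projection_box

# The hyperparameters of projection_box are the (lower, upper) bounds.
box_hyperparams = (0.0, 1.0)
pg = ProjectedGradient(fun=fun, projection=projection_box)
pg_sol = pg.run(w_init, hyperparams_proj=box_hyperparams, data=(X, y)).params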
Differentiation
In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
now differentiate the solution w.r.t. radius:
import jax

def solution(radius):
  pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=True)
  return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params

print(jax.jacobian(solution)(radius))
Under the hood, we use the implicit function theorem if implicit_diff=True and
autodiff of unrolled iterations if implicit_diff=False. See the implicit
differentiation section for more details.
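For comparison, the same derivative can be computed by differentiating through the unrolled iterations; a sketch, under the same assumptions as the example above:

def solution_unrolled(radius):
  # implicit_diff=False differentiates through the solver's iterations instead
  # of applying the implicit function theorem at the solution.
  pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=False)
  return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params

print(jax.jacobian(solution_unrolled)(radius))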
Projections
The Euclidean projection onto \(\mathcal{C}(\upsilon)\) is:

\[\text{proj}_{\mathcal{C}}(x', \upsilon) := \operatorname*{argmin}_{x \in \mathcal{C}(\upsilon)} \|x' - x\|^2 .\]
The following operators are available.
jaxopt.projection.projection_non_negative | Projection onto the non-negative orthant.
jaxopt.projection.projection_box | Projection onto box constraints.
jaxopt.projection.projection_simplex | Projection onto a simplex.
jaxopt.projection.projection_sparse_simplex | Projection onto the simplex with cardinality constraint (maximum number of non-zero elements).
jaxopt.projection.projection_l1_sphere | Projection onto the l1 sphere.
jaxopt.projection.projection_l1_ball | Projection onto the l1 ball.
jaxopt.projection.projection_l2_sphere | Projection onto the l2 sphere.
jaxopt.projection.projection_l2_ball | Projection onto the l2 ball.
jaxopt.projection.projection_linf_ball | Projection onto the l-infinity ball.
jaxopt.projection.projection_hyperplane | Projection onto a hyperplane.
jaxopt.projection.projection_halfspace | Projection onto a halfspace.
jaxopt.projection.projection_affine_set | Projection onto an affine set.
jaxopt.projection.projection_polyhedron | Projection onto a polyhedron.
jaxopt.projection.projection_box_section | Projection onto a box section.
jaxopt.projection.projection_transport | Projection onto the transportation polytope.
jaxopt.projection.projection_birkhoff | Projection onto the Birkhoff polytope, the set of doubly stochastic matrices.
Projections always have two arguments: the input to be projected and the parameters of the convex set.
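As a quick illustration of this two-argument convention (a sketch; the input values are arbitrary):

import jax.numpy as jnp
from jaxopt.projection import projection_simplex, projection_box

x = jnp.array([1.5, -0.3, 0.8])

# Projection onto the probability simplex; the second argument is the value
# the coordinates should sum to (defaults to 1.0).
print(projection_simplex(x, 1.0))

# Projection onto a box; the second argument packs the (lower, upper) bounds.
print(projection_box(x, (0.0, 1.0)))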
Mirror descent
jaxopt.MirrorDescent | Mirror descent solver.
Kullback-Leibler projections
The Kullback-Leibler projection onto \(\mathcal{C}(\upsilon)\) is:

\[\text{proj}_{\mathcal{C}}^{\text{KL}}(x', \upsilon) := \operatorname*{argmin}_{x \in \mathcal{C}(\upsilon)} \text{KL}(x, x') .\]
The following operators are available.
jaxopt.projection.kl_projection_transport | Kullback-Leibler projection onto the transportation polytope.
jaxopt.projection.kl_projection_birkhoff | Kullback-Leibler projection onto the Birkhoff polytope, the set of doubly stochastic matrices.
Box constraints
For optimization with box constraints, in addition to projected gradient descent, we can use our SciPy wrapper.
jaxopt.ScipyBoundedMinimize | scipy.optimize.minimize wrapper.
jaxopt.LBFGSB | L-BFGS-B solver.
This example shows how to apply non-negativity constraints, which can be achieved by setting box constraints \([0, \infty)\):
import jax.numpy as jnp
from jaxopt import ScipyBoundedMinimize

n_features = X.shape[1]
w_init = jnp.zeros(n_features)
lbfgsb = ScipyBoundedMinimize(fun=fun, method="l-bfgs-b")
lower_bounds = jnp.zeros_like(w_init)
upper_bounds = jnp.ones_like(w_init) * jnp.inf
bounds = (lower_bounds, upper_bounds)
lbfgsb_sol = lbfgsb.run(w_init, bounds=bounds, data=(X, y)).params
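The same bound constraints can also be handled by the native L-BFGS-B solver listed above; a sketch reusing the fun, w_init and bounds from the previous snippet:

from jaxopt import LBFGSB

# Native (pure JAX) L-BFGS-B; bounds are passed to run as (lower, upper).
lbfgsb_native = LBFGSB(fun=fun)
lbfgsb_native_sol = lbfgsb_native.run(w_init, bounds=bounds, data=(X, y)).params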