Non-smooth optimization

This section is concerned with problems of the form

\[\min_{x} f(x, \theta) + g(x, \lambda)\]

where \(f(x, \theta)\) is differentiable (almost everywhere), \(x\) are the parameters with respect to which the function is minimized, \(\theta\) are optional extra arguments, \(g(x, \lambda)\) is possibly non-smooth, and \(\lambda\) are extra parameters \(g\) may depend on.

Proximal gradient

jaxopt.ProximalGradient(fun[, prox, ...])

Proximal gradient solver.

Instantiating and running the solver

Proximal gradient is a generalization of projected gradient descent. The non-smooth term \(g\) above is specified through its proximal operator, which is passed via the prox attribute of ProximalGradient.

For instance, suppose we want to solve the following optimization problem

\[\min_{w} \frac{1}{2n} ||Xw - y||^2 + \text{l1reg} \cdot ||w||_1\]

which corresponds to the choice \(g(w, \text{l1reg}) = \text{l1reg} \cdot ||w||_1\). The corresponding prox operator is prox_lasso. We can therefore write:

import jax.numpy as jnp

from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

def least_squares(w, data):
  X, y = data
  residuals = jnp.dot(X, w) - y
  return 0.5 * jnp.mean(residuals ** 2)  # (1/2n) ||Xw - y||^2

l1reg = 1.0
pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
# X, y and the initial parameters w_init are assumed to be defined elsewhere.
pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

Note that prox_lasso has a hyperparameter l1reg, which controls the \(L_1\) regularization strength. As shown above, we can specify it in the run method using the hyperparams_prox argument. The remaining arguments are passed to the objective function, here least_squares.
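
To see what the proximal operator does in isolation, we can also call prox_lasso directly on an array. A minimal sketch, using the prox_lasso(x[, l1reg, scaling]) signature listed below:

import jax.numpy as jnp

from jaxopt.prox import prox_lasso

x = jnp.array([1.5, -0.2, 0.8])
# Soft-thresholding shrinks each entry toward zero by l1reg = 0.5;
# entries smaller than 0.5 in magnitude become exactly 0.
print(prox_lasso(x, l1reg=0.5))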

Numerous proximal operators are available, see below.
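
In particular, make_prox_from_projection (listed below) turns a projection into a proximal operator, which recovers projected gradient descent as a special case. A minimal sketch, assuming jaxopt.projection.projection_non_negative is available and reusing least_squares, w_init and the data from above, for non-negative least squares:

from jaxopt import ProximalGradient
from jaxopt.prox import make_prox_from_projection
from jaxopt.projection import projection_non_negative

# The resulting prox simply projects onto the non-negative orthant.
prox_non_negative = make_prox_from_projection(projection_non_negative)
pg = ProximalGradient(fun=least_squares, prox=prox_non_negative)
nnls_sol = pg.run(w_init, data=(X, y)).params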

Differentiation

In some applications, it is useful to differentiate the solution of the solver with respect to some hyperparameters. Continuing the previous example, we can now differentiate the solution w.r.t. l1reg:

import jax

def solution(l1reg):
  pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
  return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

print(jax.jacobian(solution)(l1reg))

Under the hood, we use the implicit function theorem if implicit_diff=True and autodiff of unrolled iterations if implicit_diff=False. See the implicit differentiation section for more details.
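
This makes it possible, for example, to compute hyperparameter gradients end-to-end. A sketch, assuming hypothetical held-out data X_val, y_val and reusing solution and least_squares from above:

import jax

# Gradient of a held-out loss with respect to the regularization strength.
def validation_loss(l1reg):
  w = solution(l1reg)
  return least_squares(w, (X_val, y_val))

print(jax.grad(validation_loss)(l1reg))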

Block coordinate descent

jaxopt.BlockCoordinateDescent(fun, block_prox)

Block coordinate solver.

Unlike other solvers, jaxopt.BlockCoordinateDescent only supports composite linear objective functions, such as those provided in jaxopt.objective.

Example:

import jax.numpy as jnp

from jaxopt import BlockCoordinateDescent
from jaxopt import objective
from jaxopt import prox

l1reg = 1.0
# X, y and n_features (the number of columns of X) are assumed to be defined.
w_init = jnp.zeros(n_features)
bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
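
As a quick sanity check, here is a sketch reusing pg_sol from the proximal gradient example above; assuming objective.least_squares uses the same scaling as the least_squares function defined earlier, the two solvers target the same lasso problem and should reach essentially the same solution:

import jax.numpy as jnp

# Should be small (up to the solvers' tolerance).
print(jnp.linalg.norm(lasso_sol - pg_sol))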

Proximal operators

Proximal gradient and block coordinate descent do not access \(g(x, \lambda)\) directly but instead require its associated proximal operator. It is defined as:

\[\text{prox}_{g}(x', \lambda, \eta) := \underset{x}{\text{argmin}} ~ \frac{1}{2} ||x' - x||^2 + \eta g(x, \lambda).\]
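
For example, for \(g(x, \lambda) = \lambda ||x||_1\), this definition yields the soft-thresholding operator, applied elementwise:

\[\left[\text{prox}_{g}(x', \lambda, \eta)\right]_i = \text{sign}(x'_i) \cdot \max(|x'_i| - \eta \lambda, 0),\]

which is what prox_lasso computes.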

The following operators are available.

jaxopt.prox.make_prox_from_projection(projection)

Transforms a projection into a proximal operator.

jaxopt.prox.prox_none(x[, hyperparams, scaling])

Proximal operator for \(g(x) = 0\), i.e., the identity function.

jaxopt.prox.prox_lasso(x[, l1reg, scaling])

Proximal operator for the l1 norm, i.e., the soft-thresholding operator.

jaxopt.prox.prox_non_negative_lasso(x[, ...])

Proximal operator for the l1 norm on the non-negative orthant.

jaxopt.prox.prox_elastic_net(x[, ...])

Proximal operator for the elastic net.

jaxopt.prox.prox_group_lasso(x[, l2reg, scaling])

Proximal operator for the l2 norm, i.e., the block soft-thresholding operator.

jaxopt.prox.prox_ridge(x[, l2reg, scaling])

Proximal operator for the squared l2 norm.

jaxopt.prox.prox_non_negative_ridge(x[, ...])

Proximal operator for the squared l2 norm on the non-negative orthant.