Non-smooth optimization
This section is concerned with problems of the form

\[\min_{x} f(x, \theta) + g(x, \lambda)\]

where \(f(x, \theta)\) is differentiable (almost everywhere), \(x\) are the parameters with respect to which the function is minimized, \(\theta\) are optional extra arguments, \(g(x, \lambda)\) is possibly non-smooth, and \(\lambda\) are extra parameters \(g\) may depend on.
Proximal gradient
| jaxopt.ProximalGradient | Proximal gradient solver. |
Instantiating and running the solver
Proximal gradient is a generalization of projected gradient descent. The non-smooth term \(g\) above is specified by setting the corresponding proximal operator, which is achieved using the prox attribute of ProximalGradient.
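Built-in proximal operators follow a prox(x, hyperparams, scaling) signature. As a minimal sketch (the function below is written here purely for illustration and is not part of jaxopt), a custom prox for the indicator of the non-negative orthant, which reduces to a projection, could look like:

import jax.numpy as jnp

def prox_non_negative(x, hyperparams=None, scaling=1.0):
  # Proximal operator of the indicator of the non-negative orthant:
  # a plain projection, so hyperparams and scaling are unused.
  del hyperparams, scaling
  return jnp.maximum(x, 0.0)

Such a function can then be passed as prox=prox_non_negative when instantiating ProximalGradient.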
For instance, suppose we want to solve the following optimization problem

\[\min_{w} \frac{1}{n} ||Xw - y||_2^2 + \text{l1reg} \cdot ||w||_1\]

which corresponds to the choice \(g(w, \text{l1reg}) = \text{l1reg} \cdot ||w||_1\). The corresponding prox operator is prox_lasso.
We can therefore write:
import jax.numpy as jnp

from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

def least_squares(w, data):
  X, y = data
  residuals = jnp.dot(X, w) - y
  return jnp.mean(residuals ** 2)

# w_init, X and y are assumed to be pre-defined arrays.
l1reg = 1.0
pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
Note that prox_lasso has a hyperparameter l1reg, which controls the \(L_1\) regularization strength. As shown above, we can specify it in the run method using the hyperparams_prox argument. The remaining arguments are passed to the objective function, here least_squares.
Numerous proximal operators are available; see below.
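For example, swapping prox_lasso for prox_elastic_net adds a squared l2 term to the penalty. A minimal sketch, reusing least_squares, w_init, X and y from above; note that prox_elastic_net expects a pair of hyperparameters (see its documentation for the exact parameterization):

from jaxopt import ProximalGradient
from jaxopt.prox import prox_elastic_net

# Pair of regularization hyperparameters expected by prox_elastic_net.
hyperparams = (1.0, 0.5)
pg = ProximalGradient(fun=least_squares, prox=prox_elastic_net)
enet_sol = pg.run(w_init, hyperparams_prox=hyperparams, data=(X, y)).params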
Differentiation
In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
now differentiate the solution w.r.t. l1reg:
import jax

def solution(l1reg):
  pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
  return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

print(jax.jacobian(solution)(l1reg))
Under the hood, we use the implicit function theorem if implicit_diff=True and autodiff of unrolled iterations if implicit_diff=False. See the implicit differentiation section for more details.
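Since the solution map is differentiable, it can also be composed with a scalar criterion and differentiated end-to-end, e.g. for hyperparameter selection. A minimal sketch, assuming held-out arrays X_val and y_val exist and reusing solution and least_squares from above:

import jax

def validation_loss(l1reg):
  # Solve the lasso problem for the given l1reg, then evaluate the
  # resulting solution on held-out data.
  w = solution(l1reg)
  return least_squares(w, (X_val, y_val))

print(jax.grad(validation_loss)(l1reg))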
Block coordinate descent
| jaxopt.BlockCoordinateDescent | Block coordinate solver. |
Contrary to other solvers, jaxopt.BlockCoordinateDescent only works with composite linear objective functions.
Example:
import jax.numpy as jnp

from jaxopt import BlockCoordinateDescent
from jaxopt import objective
from jaxopt import prox

l1reg = 1.0
w_init = jnp.zeros(n_features)  # n_features, X and y are assumed to be pre-defined
bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
Proximal operators
Proximal gradient and block coordinate descent do not access \(g(x, \lambda)\) directly but instead require its associated proximal operator. It is defined as:

\[\text{prox}_{g}(x', \lambda, \eta) := \underset{x}{\text{argmin}} ~ \frac{1}{2} ||x' - x||^2 + \eta \, g(x, \lambda)\]
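For instance, for \(g(x, \text{l1reg}) = \text{l1reg} \cdot ||x||_1\), this definition yields the soft-thresholding operator. A quick sketch checking prox_lasso against the closed form (with the default scaling of 1):

import jax.numpy as jnp
from jaxopt.prox import prox_lasso

x = jnp.array([1.5, -0.2, 0.7])
l1reg = 0.5

# Closed-form soft-thresholding at level l1reg.
soft_thresholded = jnp.sign(x) * jnp.maximum(jnp.abs(x) - l1reg, 0.0)
print(prox_lasso(x, l1reg))  # expected to match soft_thresholded
print(soft_thresholded)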
The following operators are available.
| jaxopt.prox.make_prox_from_projection | Transforms a projection into a proximal operator. |
| jaxopt.prox.prox_none | Proximal operator for \(g(x) = 0\), i.e., the identity function. |
| jaxopt.prox.prox_lasso | Proximal operator for the l1 norm, i.e., soft-thresholding operator. |
| jaxopt.prox.prox_non_negative_lasso | Proximal operator for the l1 norm on the non-negative orthant. |
| jaxopt.prox.prox_elastic_net | Proximal operator for the elastic net. |
| jaxopt.prox.prox_group_lasso | Proximal operator for the l2 norm, i.e., block soft-thresholding operator. |
| jaxopt.prox.prox_ridge | Proximal operator for the squared l2 norm. |
| jaxopt.prox.prox_non_negative_ridge | Proximal operator for the squared l2 norm on the non-negative orthant. |