Quadratic programming
This section is concerned with minimizing quadratic functions subject to equality and/or inequality constraints, also known as quadratic programming.
JAXopt supports several solvers for quadratic programming. Their characteristics are summarized in the table below. The best choice depends on the use case.
Name | jit | pytree | matvec | quad. fun | precision | stability | speed | derivative | input format
---|---|---|---|---|---|---|---|---|---
EqualityConstrainedQP | yes | yes | yes | yes | ++ | + | +++ | implicit | (Q, c), (A, b)
CvxpyQP | no | no | no | no | +++ | +++ | + | implicit | (Q, c), (A, b), (G, h)
OSQP | yes | yes | yes | yes | + | ++ | ++ | implicit | (Q, c), (A, b), (G, h)
BoxOSQP | yes | yes | yes | yes | + | ++ | ++ | both | (Q, c), A, (l, u)
BoxCDQP | yes | no | no | no | ++ | +++ | ++ | both | (Q, c), (l, u)
- jit: the algorithm can be used with jit or vmap, and runs on GPU/TPU.
- pytree: the algorithm can be used with pytrees of matrices (see below).
- matvec: the QP parameters can be given as matvecs instead of dense matrices (see below).
- quad. fun: the algorithm can be used with a quadratic function argument (see below).
- precision: accuracy expected when the solver succeeds in converging.
- stability: capacity to handle badly scaled problems and matrices with poor conditioning.
- speed: typical speed on big instances to reach its maximum accuracy.
- derivative: whether differentiation is supported only via implicit differentiation, or by both implicit differentiation and unrolling.
- input format: see the subsections below.
This table is given as a rule of thumb only; on some particular instances, some solvers may behave unexpectedly better (or worse!) than others. In case of difficulties, we suggest testing different combinations of algorithms, `maxiter` and `tol` values.
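For instance, the tolerance and iteration budget are set at construction time; a minimal sketch (the `tol` and `maxiter` options below are the same ones used in the examples on this page):

```python
from jaxopt import OSQP, BoxOSQP

# Two candidate configurations for the same problem: a fast, loose solve
# and a slower, more accurate one.
qp_fast = OSQP(tol=1e-3, maxiter=1_000)
qp_accurate = BoxOSQP(tol=1e-7, maxiter=20_000)
```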
Warning
These algorithms are guaranteed to converge on convex problems only. Hence, the matrix \(Q\) must be positive semi-definite (PSD).
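If in doubt, positive semi-definiteness can be checked numerically from the eigenvalues of \(Q\); a minimal sketch using standard jax.numpy routines:

```python
import jax.numpy as jnp

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
# Q is symmetric, so its eigenvalues are real; PSD means all are >= 0.
eigvals = jnp.linalg.eigvalsh(Q)
print(jnp.all(eigvals >= -1e-8))  # small slack for floating-point error
```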
Equality-constrained QPs
The problem takes the form:

\(\min_{x} \frac{1}{2} x^\top Q x + c^\top x \quad \textrm{subject to} \quad Ax = b\)

`jaxopt.EqualityConstrainedQP`: Quadratic programming with equality constraints only.
This class is optimized for QPs with equality constraints only: it supports jit, pytrees and matvec. It is based on the KKT conditions of the problem.
Example:
```python
import jax.numpy as jnp
from jaxopt import EqualityConstrainedQP

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])

qp = EqualityConstrainedQP()
sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params
print(sol.primal)
print(sol.dual_eq)
```
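Since the solver is based on the KKT conditions, the returned solution can be sanity-checked by evaluating those conditions directly; a minimal sketch reusing the variables above, assuming the standard sign convention for the dual variables:

```python
# Stationarity of the Lagrangian: Qx + c + A^T nu should be ~0.
print(Q @ sol.primal + c + A.T @ sol.dual_eq)
# Primal feasibility: Ax - b should be ~0.
print(A @ sol.primal - b)
```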
Ill-posed problems
This solver is the fastest for well-posed problems, but it can behave poorly on badly scaled matrices or with redundant constraints. If the solver struggles to converge, it is possible to enable iterative refinement, by setting `refine_regularization` and `refine_maxiter`:
```python
from jaxopt import EqualityConstrainedQP

Q = 2 * jnp.array([[3000.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])

qp = EqualityConstrainedQP(tol=1e-5, refine_regularization=3.0, refine_maxiter=50)
sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params
print(sol.primal)
print(sol.dual_eq)
print(qp.l2_optimality_error(sol, params_obj=(Q, c), params_eq=(A, b)))
```
General QPs
The problem takes the form:

\(\min_{x} \frac{1}{2} x^\top Q x + c^\top x \quad \textrm{subject to} \quad Ax = b, \quad Gx \le h\)
CvxpyQP
The wrapper over CVXPY is a solver that runs in float64 precision. However, it is not jittable, and does not support matvec or pytrees.

`jaxopt.CvxpyQP`: Wraps CVXPY's quadratic solver with implicit diff support.
Example:
```python
from jaxopt import CvxpyQP

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
h = jnp.array([0.0, 0.0])

qp = CvxpyQP()
sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
print(sol.primal)
print(sol.dual_eq)
print(sol.dual_ineq)
```
It is also possible to specify only equality constraints or only inequality constraints by setting `params_eq` or `params_ineq` to `None`.
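For instance, the same problem restricted to its inequality constraints; a minimal sketch reusing the arrays above:

```python
# Drop the equality constraint Ax = b and keep only Gx <= h.
qp = CvxpyQP()
sol = qp.run(params_obj=(Q, c), params_eq=None, params_ineq=(G, h)).params
print(sol.primal)
```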
OSQP
This solver is a pure JAX re-implementation of the OSQP algorithm. It is jittable and supports pytrees and matvecs, but its precision is usually lower than that of CvxpyQP when run in float32. It is meant as a drop-in replacement for CvxpyQP, but it is a wrapper over BoxOSQP. Hence, we recommend using BoxOSQP directly to avoid a costly problem transformation.

`jaxopt.OSQP`: OSQP solver for general quadratic programming.
Example:
```python
from jaxopt import OSQP

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
h = jnp.array([0.0, 0.0])

qp = OSQP()
sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
print(sol.primal)
print(sol.dual_eq)
print(sol.dual_ineq)
```
See `jaxopt.BoxOSQP` for a full description of the parameters.
Box-constrained QPs, with equality
The problem takes the form:

\(\min_{x} \frac{1}{2} x^\top Q x + c^\top x \quad \textrm{subject to} \quad l \le Ax \le u\)

`jaxopt.BoxOSQP`: Operator Splitting Solver for Quadratic Programs.

`jaxopt.BoxOSQP` uses the same underlying solver as `jaxopt.OSQP`, but accepts problems in the above box-constrained format instead. The bounds `u` (resp. `l`) can be set to `inf` (resp. `-inf`) if required. Equality can be enforced with `l = u`.
Example:
```python
from jaxopt import BoxOSQP

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
l = jnp.array([1.0, -jnp.inf, -jnp.inf])
u = jnp.array([1.0, 0.0, 0.0])

qp = BoxOSQP()
sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
print(sol.primal)
print(sol.dual_eq)
print(sol.dual_ineq)
```
If required, the algorithm can be sped up by setting `check_primal_dual_infeasability` to `False`, and by setting `eq_qp_preconditioner` to `"jacobi"` (when possible).
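A minimal sketch of these two options, reusing the problem above (whether the preconditioner helps is problem-dependent, hence the "when possible"):

```python
# Disable the (relatively costly) infeasibility checks and enable the
# Jacobi preconditioner for the inner equality-constrained subproblems.
qp = BoxOSQP(check_primal_dual_infeasability=False, eq_qp_preconditioner="jacobi")
sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
```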
Note
The `tol` parameter controls the tolerance of the stopping criterion, which is based on the primal and dual residuals. For over-constrained problems, or badly scaled matrices, the residuals can be high, and it may be difficult to set `tol` appropriately. In this case, it is better to tune `maxiter` instead.
Box-constrained QPs, without equality
The problem takes the form:

\(\min_{x} \frac{1}{2} x^\top Q x + c^\top x \quad \textrm{subject to} \quad l \le x \le u\)

`jaxopt.BoxCDQP`: Coordinate descent solver for box-constrained QPs.

`jaxopt.BoxCDQP` uses a coordinate descent solver. The solver returns only the primal solution.
Example:
```python
from jaxopt import BoxCDQP

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, -1.0])
l = jnp.array([0.0, 0.0])
u = jnp.array([1.0, 1.0])

qp = BoxCDQP()
init = jnp.zeros(2)
sol = qp.run(init, params_obj=(Q, c), params_ineq=(l, u)).params
print(sol)
```
Unconstrained QPs
For completeness, we also briefly describe how to solve unconstrained quadratics of the form:

\(\min_{x} \frac{1}{2} x^\top Q x + c^\top x\)

The optimality condition is \(\nabla \left(\frac{1}{2} x^\top Q x + c^\top x\right) = Qx + c = 0\). Therefore, the problem is equivalent to solving the linear system \(Qx = -c\). Since the matrix \(Q\) is assumed PSD, one of the best algorithms is the conjugate gradient method. In JAXopt, this can be done as follows:
```python
from jaxopt.linear_solve import solve_cg

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])

matvec = lambda x: jnp.dot(Q, x)
sol = solve_cg(matvec, b=-c)
print(sol)
```
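As a sanity check, the conjugate-gradient solution can be compared against a dense direct solve; a minimal sketch reusing `Q` and `c` from above:

```python
# Direct dense solve of Qx = -c; should match the CG solution
# up to the solver's tolerance.
sol_direct = jnp.linalg.solve(Q, -c)
print(sol_direct)
```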
Pytree of matrices API
The solvers EqualityConstrainedQP, OSQP and BoxOSQP support the pytree-of-matrices API. This means that the matrices Q, A and G can be provided as block-diagonal operators whose blocks are the leaves of pytrees. This corresponds to separable problems that can be solved in parallel (one per leaf).

It offers several advantages:

- This model of parallelism succeeds even if all the problems have different shapes, contrary to the jax.vmap API.
- This formulation is more efficient than a single big matrix, especially when there are a lot of blocks, and when the blocks themselves are small.
- The tolerance is globally defined and shared by all the problems, and the number of iterations is the same for all the problems.
We illustrate below the parallel solving of two problems with different shapes:
```python
Q1 = jnp.array([[1.0, -0.5],
                [-0.5, 1.0]])
Q2 = jnp.array([[2.0]])
Q = {'problem1': Q1, 'problem2': Q2}

c1 = jnp.array([-0.4, 0.3])
c2 = jnp.array([0.1])
c = {'problem1': c1, 'problem2': c2}

a1 = jnp.array([[-0.5, 1.5]])
a2 = jnp.array([[10.0]])
A = {'problem1': a1, 'problem2': a2}

b1 = jnp.array([0.3])
b2 = jnp.array([5.0])
b = {'problem1': b1, 'problem2': b2}

qp = EqualityConstrainedQP(tol=1e-3)
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
# Solve the two problems in parallel with a single call.
sol = qp.run(**hyperparams).params
print(sol.primal['problem1'], sol.primal['problem2'])
```
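Since these solvers support jit (see the table above), the whole pytree-based solve can also be compiled; a minimal sketch reusing `qp` and the pytrees above:

```python
import jax

@jax.jit
def solve(Q, c, A, b):
    # Both separable problems are solved in parallel inside one compiled call.
    return qp.run(params_obj=(Q, c), params_eq=(A, b)).params

sol = solve(Q, c, A, b)
```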
Matvec API
The solvers EqualityConstrainedQP, OSQP and BoxOSQP support the matvec API. This means that the user can provide a function `matvec` that computes the matrix-vector product, either in the objective (x -> Qx) or in the constraints (x -> Ax, x -> Gx).

It offers several advantages:

- The code is easier to read and closer to the mathematical formulation of the problem.
- Sparse matrix-vector products are available, which can be much faster than dense ones.
- The derivatives w.r.t. (params_obj, params_eq, params_ineq) may be easier to compute than by materializing the full matrices.
- It is faster than the quadratic function API.
This is the recommended API when the matrices are not block-diagonal operators, especially when other sparsity patterns are involved, or in conjunction with implicit differentiation:
```python
from sklearn import datasets

# Objective:
#   min ||data @ x - targets||_2^2 + 2 * n * lam * ||x||_1
#
# With the BoxOSQP formulation:
#
#   min_{x, y, t} y^T y + 2 * n * lam * 1^T t
#   under         targets = data @ x - y
#                 0 <= x + t <= infinity
#                 -infinity <= x - t <= 0
data, targets = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
n = data.shape[0]
lam = 10.0

def matvec_Q(params_Q, xyt):
    del params_Q  # unused
    x, y, t = xyt
    return jnp.zeros_like(x), 2 * y, jnp.zeros_like(t)

c = (jnp.zeros(data.shape[1]),
     jnp.zeros(data.shape[0]),
     2 * n * lam * jnp.ones(data.shape[1]))

def matvec_A(params_A, xyt):
    x, y, t = xyt
    residuals = params_A @ x - y
    return residuals, x + t, x - t

l = (targets, jnp.zeros_like(c[0]), jnp.full(data.shape[1], -jnp.inf))
u = (targets, jnp.full(data.shape[1], jnp.inf), jnp.zeros_like(c[0]))

hyper_params = dict(params_obj=(None, c), params_eq=data, params_ineq=(l, u))
osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=1e-2)
sol, state = osqp.run(None, **hyper_params)
```
Quadratic function API
The solvers EqualityConstrainedQP, OSQP and BoxOSQP support the quadratic function API. This means that the whole objective function x -> 1/2 x^T Q x + c^T x + K can be provided as a function `fun` that computes the quadratic objective. The function must be differentiable w.r.t. x.

It offers several advantages:

- The code is easier to read and closer to the mathematical formulation of the problem.
- There is no need to provide the matrix Q and the vector c separately, nor to remove the constant term K.
- The derivatives w.r.t. (params_obj, params_eq, params_ineq) may be even easier to compute than by materializing the full matrices.

Take care that this API also has drawbacks:

- The function `fun` must be differentiable w.r.t. x (with JAX's AD), even if you are not interested in the derivatives of your QP.
- To extract x -> Qx and c from the function, we need to compute the Hessian-vector product and the gradient of `fun`, which may be expensive.
- With this API, `init_params` must be provided to `run`, contrary to the other APIs.
We illustrate this API with Non-Negative Least Squares (NNLS):

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jaxopt import OSQP

# min_W ||Y - U W||_F^2
# s.t. W >= 0
n, m, rank = 20, 10, 3
onp.random.seed(654)
U = jax.nn.relu(onp.random.randn(n, rank))
W_0 = jax.nn.relu(onp.random.randn(rank, m))
Y = U @ W_0

def fun(W, params_obj):
    Y, U = params_obj
    # Write the objective as an implicit quadratic polynomial in W.
    return jnp.sum(jnp.square(Y - U @ W))

def matvec_G(params_G, W):
    del params_G  # unused
    return -W

zeros = jnp.zeros_like(W_0)
hyper_params = dict(params_obj=(Y, U), params_eq=None, params_ineq=(None, zeros))

solver = OSQP(fun=fun, matvec_G=matvec_G)
init_W = jnp.zeros_like(W_0)  # mandatory with the `fun` API.
init_params = solver.init_params(init_W, **hyper_params)
W_sol = solver.run(init_params=init_params, **hyper_params).params.primal
```
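As a quick sanity check (reusing the variables above): since Y = U @ W_0 with W_0 >= 0, the optimum reaches a near-zero objective, and the box constraint should hold up to the solver's tolerance:

```python
print(jnp.all(W_sol >= -1e-5))         # box constraint, up to tolerance
print(jnp.linalg.norm(Y - U @ W_sol))  # near-zero reconstruction error
```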
This API is not recommended for large-scale problems or nested differentiations; use the matvec API instead.
Implicit differentiation pitfalls
When using implicit differentiation, the parameters with respect to which we differentiate must be passed via params_obj, params_eq or params_ineq. They should not be captured from the enclosing scope by `fun` or `matvec`. We illustrate this common mistake below:
```python
def wrong_solver(Q):  # don't do this!

    def matvec_Q(params_Q, x):
        del params_Q  # unused
        # Error! Q is captured from the enclosing scope.
        # It does not fail now, but it will fail later.
        return jnp.dot(Q, x)

    c = jnp.zeros(Q.shape[0])
    A = jnp.array([[1.0, 2.0]])
    b = jnp.array([1.0])

    eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q)
    sol = eq_qp.run(None, params_obj=(None, c), params_eq=(A, b)).params
    loss = jnp.sum(sol.primal)
    return loss

Q = jnp.array([[1.0, 0.5], [0.5, 4.0]])
_ = wrong_solver(Q)  # no error... but it will fail later.
_ = jax.grad(wrong_solver)(Q)  # raises CustomVJPException
```
Also, note that since the problems are convex, the optimum is independent of the starting point init_params; hence, derivatives w.r.t. init_params are always zero (mathematically).
The correct implementation is given below:
```python
def correct_solver(Q):

    def matvec_Q(params_Q, x):
        return jnp.dot(params_Q, x)

    c = jnp.zeros(Q.shape[0])
    A = jnp.array([[1.0, 2.0]])
    b = jnp.array([1.0])

    eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q)
    # Q is passed as a parameter, not captured from the enclosing scope.
    sol = eq_qp.run(None, params_obj=(Q, c), params_eq=(A, b)).params
    loss = jnp.sum(sol.primal)
    return loss

Q = jnp.array([[1.0, 0.5], [0.5, 4.0]])
_ = correct_solver(Q)  # no error
_ = jax.grad(correct_solver)(Q)  # no error
```