jaxopt.BoxOSQP

class jaxopt.BoxOSQP(matvec_Q=None, matvec_A=None, fun=None, check_primal_dual_infeasability='auto', sigma=1e-06, momentum=1.6, eq_qp_solve='cg', rho_start=0.1, rho_min=1e-06, rho_max=1000000.0, stepsize_updates_frequency=10, primal_infeasible_tol=0.001, dual_infeasible_tol=0.001, maxiter=4000, tol=0.001, termination_check_frequency=5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')

Operator Splitting Solver for Quadratic Programs.

JAX implementation of the GPU-OSQP solver [1, 3], based on ADMM. Supports jit, vmap, matvecs, pytrees, and objectives given as a function (fun).

Refer to the documentation of the init_state method for the meaning of the parameters of the run and update methods.

It solves convex problems of the form

\[\begin{split}\begin{aligned} \min_{x,z} \quad & \frac{1}{2}x^TQx + c^Tx\\ \textrm{s.t.} \quad & Ax=z\\ & l\leq z\leq u \end{aligned}\end{split}\]

Equality constraints are obtained by setting l = u. If an inequality is one-sided, jnp.inf can be used for u and -jnp.inf for l.

Q must be a positive semidefinite (PSD) matrix.

The Lagrangian is given by

\[\mathcal{L} = \frac{1}{2}x^TQx + c^Tx + y^T(Ax-z) + \mu^T (z-u) + \phi^T (l-z)\]

Primal variables: \(x, z\)

Dual Eq variables: \(y\)

Dual Ineq variables: \(\mu, \phi\)

ADMM computes \(y\) at each iteration. \(\mu\) and \(\phi\) can be deduced from \(y\).
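Setting the gradient of the Lagrangian with respect to z to zero gives y = mu - phi, with mu, phi >= 0. One natural way to deduce the two multipliers, sketched below as an assumption rather than jaxopt's actual code (split_dual is a hypothetical name), is to take the positive and negative parts of y: a positive y_i signals an active upper bound z_i = u_i, a negative y_i an active lower bound z_i = l_i.

```python
import jax.numpy as jnp

def split_dual(y):
    """Split y = mu - phi into nonnegative multipliers (hypothetical helper).

    mu is the multiplier of z <= u and phi the multiplier of l <= z.
    """
    mu = jnp.maximum(y, 0.0)    # y_i > 0: upper bound z_i = u_i is active
    phi = jnp.maximum(-y, 0.0)  # y_i < 0: lower bound z_i = l_i is active
    return mu, phi

y = jnp.array([0.5, -2.0, 0.0])
mu, phi = split_dual(y)  # mu == [0.5, 0, 0], phi == [0, 2, 0]
```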

Default values for the hyper-parameters come from: https://github.com/osqp/osqp/blob/master/include/constants.h

Parameters
  • matvec_Q (Optional[Callable]) –

  • matvec_A (Optional[Callable]) –

  • fun (Optional[Callable]) –

  • check_primal_dual_infeasability (Union[str, bool]) –

  • sigma (float) –

  • momentum (float) –

  • eq_qp_solve (str) –

  • rho_start (float) –

  • rho_min (float) –

  • rho_max (float) –

  • stepsize_updates_frequency (int) –

  • primal_infeasible_tol (float) –

  • dual_infeasible_tol (float) –

  • maxiter (int) –

  • tol (float) –

  • termination_check_frequency (int) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

matvec_Q

(optional) a Callable matvec_Q(params_Q, x). By default, matvec_Q(Q, x) = tree_dot(Q, x), where the pytree Q = params_Q matches the structure of x. The shape of the primal variables can be inferred from params_obj = (params_Q, c).

Type

Optional[Callable]

matvec_A

(optional) a Callable matvec_A(params_A, x). By default, matvec_A(A, x) = tree_dot(A, x), where the pytree A = params_A matches the structure of x.

Type

Optional[Callable]

fun

(optional) a function with signature fun(params, params_obj) that is promised to be a convex quadratic polynomial in params, i.e. fun can be written as

fun(x, params_obj) = 0.5*jnp.dot(x, jnp.dot(Q, x)) + jnp.dot(c, x) + const

with params_obj a pytree that contains the parameters of the objective function. (Q, c) do not need to be given explicitly in params_obj: c is inferred by Jaxopt, and the operator x -> Qx is computed upon request.

fun is incompatible with matvec_Q. Note that the shape of the primal variables can no longer be inferred from params_obj, so the user must provide it through init_params. This API is provided for convenience, but since fun relies on JAX's autodiff under the hood, it can be slower than matvec_Q, especially in conjunction with implicit differentiation.

Type

Optional[Callable]

check_primal_dual_infeasability

If True, populates the status field of the state with BoxOSQP.PRIMAL_INFEASIBLE or BoxOSQP.DUAL_INFEASIBLE when infeasibility is detected. If False, feasibility is not checked, which improves speed. If the problem is primal or dual infeasible and jit=False, a ValueError is raised. If “auto”, it is True when jit=False and False otherwise (default: “auto”).

Type

Union[str, bool]

sigma

ridge regularization parameter added to the linear system solved at each iteration (default: 1e-6).

Type

float

momentum

relaxation parameter (default: 1.6); must lie in the open interval (0, 2). momentum=1 => no relaxation. momentum<1 => under-relaxation. momentum>1 => over-relaxation. Boyd et al. [2, p. 21] suggest choosing momentum in [1.5, 1.8].

Type

float
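Following Algorithm 1 of [1], the relaxation parameter (called alpha there, momentum here) blends the fresh iterate with the previous one before projecting onto [l, u]; the x-update is relaxed the same way. The sketch below shows only the z- and y-updates; the helper name is hypothetical and jaxopt's internal code may be organized differently.

```python
import jax.numpy as jnp

def relaxed_z_y_update(z_tilde, z_prev, y_prev, l, u, rho, momentum):
    """Relaxed z- and y-updates of OSQP-style ADMM (hypothetical helper)."""
    # momentum = 1 recovers plain ADMM; momentum > 1 over-relaxes.
    z_relaxed = momentum * z_tilde + (1.0 - momentum) * z_prev
    # Projection onto the box [l, u].
    z_new = jnp.clip(z_relaxed + y_prev / rho, l, u)
    # Dual ascent step on the constraint Ax = z.
    y_new = y_prev + rho * (z_relaxed - z_new)
    return z_new, y_new

z_new, y_new = relaxed_z_y_update(
    z_tilde=jnp.array([0.5, 2.0]), z_prev=jnp.zeros(2), y_prev=jnp.zeros(2),
    l=jnp.zeros(2), u=jnp.ones(2), rho=1.0, momentum=1.6)
# z_relaxed = [0.8, 3.2], so z_new = [0.8, 1.0] and y_new = [0.0, 2.2]
```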

eq_qp_solve

method used to solve the linear system at each iteration: ‘cg’, ‘cg+jacobi’ or ‘lu’ (default: ‘cg’). ‘cg’ is conjugate gradient: an indirect solver that works with matvecs or pytrees of matrices. ‘cg+jacobi’ is conjugate gradient with Jacobi preconditioning: it only works on pytrees of matrices but can provide a speedup. ‘lu’ is LU factorization: a direct solver that only works on pytrees of matrices.

Type

str
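Each iteration of OSQP [1] solves a linear system with the matrix Q + sigma*I + rho*A^T A, where sigma is the ridge term above. The sketch below, assuming a scalar rho and dense arrays, shows how the three eq_qp_solve options differ: 'cg' only needs matrix-vector products, while 'cg+jacobi' and 'lu' need the explicit matrix. It is built on jax.scipy.sparse.linalg.cg and jnp.linalg.solve rather than jaxopt's internal code, and solve_eq_qp is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def solve_eq_qp(Q, A, c, x, z, y, sigma, rho, method="cg"):
    """Solve (Q + sigma*I + rho*A^T A) x_new = sigma*x - c + A^T (rho*z - y)."""
    rhs = sigma * x - c + A.T @ (rho * z - y)

    def matvec(v):
        # Matrix-free product with Q + sigma*I + rho*A^T A (all that 'cg' needs).
        return Q @ v + sigma * v + rho * (A.T @ (A @ v))

    if method == "cg":
        return cg(matvec, rhs)[0]
    # The remaining methods need the explicit matrix.
    M = Q + sigma * jnp.eye(Q.shape[0]) + rho * (A.T @ A)
    if method == "cg+jacobi":
        inv_diag = 1.0 / jnp.diag(M)
        return cg(matvec, rhs, M=lambda v: inv_diag * v)[0]  # Jacobi preconditioner
    if method == "lu":
        return jnp.linalg.solve(M, rhs)  # dense direct solve (LU-based)
    raise ValueError(f"unknown method: {method!r}")

Q = jnp.array([[4.0, 1.0], [1.0, 2.0]])
A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
c, x = jnp.array([1.0, 1.0]), jnp.zeros(2)
z, y = jnp.array([1.0, 0.0, 0.0]), jnp.zeros(3)
x_lu = solve_eq_qp(Q, A, c, x, z, y, sigma=1e-6, rho=0.1, method="lu")
x_cg = solve_eq_qp(Q, A, c, x, z, y, sigma=1e-6, rho=0.1, method="cg")
x_jac = solve_eq_qp(Q, A, c, x, z, y, sigma=1e-6, rho=0.1, method="cg+jacobi")
```

All three calls solve the same system; they differ only in what they require from Q and A.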

rho_start

initial stepsize rho of ADMM (default: 1e-1).

Type

float

rho_min

minimum stepsize rho (default: 1e-6).

Type

float

rho_max

maximum stepsize rho (default: 1e6).

Type

float

stepsize_updates_frequency

frequency of stepsize updates (default: 10): a new stepsize is computed once every stepsize_updates_frequency iterations.

Type

int
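OSQP [1] adapts rho by comparing the primal and dual residuals, each scaled by the size of the quantities it is built from: a dominant primal residual increases rho, a dominant dual residual decreases it. The sketch below follows that rule, clips the result to [rho_min, rho_max], and only refreshes it every stepsize_updates_frequency iterations. The helper names and the eps safeguard are assumptions, and jaxopt's exact update may differ.

```python
import jax.numpy as jnp

def inf_norm(v):
    return jnp.max(jnp.abs(v))

def adapt_rho(rho, Q, c, A, x, z, y, rho_min=1e-6, rho_max=1e6, eps=1e-10):
    """OSQP-style stepsize update [1]: balance scaled primal and dual residuals."""
    Ax, Qx, Aty = A @ x, Q @ x, A.T @ y
    primal = inf_norm(Ax - z) / (jnp.maximum(inf_norm(Ax), inf_norm(z)) + eps)
    dual = inf_norm(Qx + c + Aty) / (
        jnp.maximum(jnp.maximum(inf_norm(Qx), inf_norm(Aty)), inf_norm(c)) + eps)
    # Large primal residual -> increase rho; large dual residual -> decrease it.
    return jnp.clip(rho * jnp.sqrt(primal / (dual + eps)), rho_min, rho_max)

def stepsize_at(iter_num, rho, Q, c, A, x, z, y, stepsize_updates_frequency=10):
    # A new stepsize is computed only once every `stepsize_updates_frequency` iterations.
    refresh = iter_num % stepsize_updates_frequency == 0
    return jnp.where(refresh, adapt_rho(rho, Q, c, A, x, z, y), rho)

eye2 = jnp.eye(2)
args = (eye2, jnp.zeros(2), eye2, jnp.ones(2), jnp.zeros(2), -jnp.ones(2))  # Q, c, A, x, z, y
rho_new = stepsize_at(0, 0.1, *args)   # primal residual dominates: rho grows
rho_same = stepsize_at(3, 0.1, *args)  # not an update iteration: rho unchanged
```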

primal_infeasible_tol

relative tolerance for primal infeasibility detection (default: 1e-3).

Type

float

dual_infeasible_tol

relative tolerance for dual infeasibility detection (default: 1e-3).

Type

float
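In OSQP [1], primal infeasibility is certified by a vector dy (the difference of successive y iterates) with A^T dy = 0 and u^T max(dy, 0) + l^T min(dy, 0) < 0. The sketch below checks a tolerance-scaled version of that certificate; using eps relative to the norm of dy as a stand-in for primal_infeasible_tol is an assumption about how jaxopt applies the tolerance, and the helper name is hypothetical.

```python
import jax.numpy as jnp

def is_primal_infeasibility_certificate(dy, A, l, u, eps=1e-3):
    """Check whether dy certifies that {x : l <= Ax <= u} is empty (hypothetical helper)."""
    norm_dy = jnp.max(jnp.abs(dy))
    # jnp.where avoids 0 * inf = nan for zero components facing infinite bounds;
    # a nonzero component facing an infinite bound makes the support +inf.
    support = (jnp.sum(jnp.where(dy > 0, u * dy, 0.0))
               + jnp.sum(jnp.where(dy < 0, l * dy, 0.0)))
    At_dy_small = jnp.max(jnp.abs(A.T @ dy)) <= eps * norm_dy
    return (norm_dy > 0) & At_dy_small & (support < -eps * norm_dy)

# x1 + x2 = 1 together with x1 + x2 <= 0 has no solution.
A = jnp.array([[1.0, 1.0], [1.0, 1.0]])
l = jnp.array([1.0, -jnp.inf])
u = jnp.array([1.0, 0.0])
certified = is_primal_infeasibility_certificate(jnp.array([-1.0, 1.0]), A, l, u)
rejected = is_primal_infeasibility_certificate(jnp.array([1.0, -1.0]), A, l, u)
```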

maxiter

maximum number of iterations (default: 4000).

Type

int

tol

absolute tolerance for the stopping criterion (default: 1e-3).

Type

float
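A sketch of the two residuals behind a tol-based stopping rule: the primal residual measures the violation of Ax = z, and the dual residual is the gradient of the Lagrangian with respect to x. This purely absolute test is an assumption; OSQP [1] also includes relative terms, and jaxopt's exact criterion may differ. The helper names are hypothetical.

```python
import jax.numpy as jnp

def kkt_residuals(Q, c, A, x, z, y):
    """Infinity-norm primal and dual residuals of the QP."""
    primal = jnp.max(jnp.abs(A @ x - z))          # violation of Ax = z
    dual = jnp.max(jnp.abs(Q @ x + c + A.T @ y))  # gradient of the Lagrangian in x
    return primal, dual

def converged(Q, c, A, x, z, y, tol=1e-3):
    primal, dual = kkt_residuals(Q, c, A, x, z, y)
    return (primal <= tol) & (dual <= tol)

Q = jnp.array([[4.0, 1.0], [1.0, 2.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
x, z, y = jnp.array([0.25, 0.75]), jnp.array([1.0]), jnp.array([-2.75])
ok = converged(Q, c, A, x, z, y)            # exact KKT point of x1 + x2 = 1
not_ok = converged(Q, c, A, x + 0.1, z, y)  # perturbed point violates Ax = z
```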

termination_check_frequency

frequency of the termination check (default: 5): the error is computed once every termination_check_frequency iterations.

Type

int

verbose

If verbose=1, print error at each iteration. If verbose=2, also print stepsizes and primal/dual variables. Warning: verbose>0 will automatically disable jit.

Type

int

implicit_diff

whether to use implicit differentiation (True) or autodiff of the unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

References

[1] Stellato, B., Banjac, G., Goulart, P., Bemporad, A. and Boyd, S., 2020. OSQP: An operator splitting solver for quadratic programs. Mathematical Programming Computation, 12(4), pp.637-672.

[2] Boyd, S., Parikh, N., Chu, E., Peleato, B. and Eckstein, J., 2010. Distributed Optimization and Statistical Learning via the Alternating Direction Method of Multipliers. Foundations and Trends in Machine Learning, 3(1), pp.1-122.

[3] Schubiger, M., Banjac, G. and Lygeros, J., 2020. GPU acceleration of ADMM for large-scale quadratic programming. Journal of Parallel and Distributed Computing, 144, pp.55-67.

__init__(matvec_Q=None, matvec_A=None, fun=None, check_primal_dual_infeasability='auto', sigma=1e-06, momentum=1.6, eq_qp_solve='cg', rho_start=0.1, rho_min=1e-06, rho_max=1000000.0, stepsize_updates_frequency=10, primal_infeasible_tol=0.001, dual_infeasible_tol=0.001, maxiter=4000, tol=0.001, termination_check_frequency=5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')

Return type

None

Methods

__init__([matvec_Q, matvec_A, fun, ...])

attribute_names()

attribute_values()

init_params(init_x, params_obj, params_eq, ...)

Return default KKTSolution for initialization of the solver state.

init_state(init_params, params_obj, ...)

Initializes the solver state.

l2_optimality_error(params, params_obj, ...)

Computes the L2 norm of the KKT residuals.

run([init_params, params_obj, params_eq, ...])

Return primal/dual variables.

update(params, state, params_obj, params_eq, ...)

Performs one BoxOSQP step.

Attributes

DUAL_INFEASIBLE

PRIMAL_INFEASIBLE

SOLVED

UNSOLVED

check_primal_dual_infeasability

dual_infeasible_tol

eq_qp_solve

fun

implicit_diff

implicit_diff_solve

jit

matvec_A

matvec_Q

maxiter

momentum

primal_infeasible_tol

rho_max

rho_min

rho_start

sigma

stepsize_updates_frequency

termination_check_frequency

tol

unroll

verbose

init_params(init_x, params_obj, params_eq, params_ineq)

Return default KKTSolution for initialization of the solver state.

Parameters
  • init_x (Any) – initial primal variable.

  • params_obj (Union[Tuple[Any, Any], Any]) – parameters of the objective function (see doc of init_state method).

  • params_eq (Any) – parameters of the equality constraints (see doc of init_state method).

  • params_ineq (Tuple[Any, Any]) – parameters of the inequality constraints (see doc of init_state method).

Returns

default parameters for initialization.

Return type

KKTSolution

init_state(init_params, params_obj, params_eq, params_ineq)

Initializes the solver state.

Parameters
  • init_params (KKTSolution) – initial primal and dual variables (KKTSolution).

  • params_obj (Union[Tuple[Any, Any], Any]) – parameters of the quadratic objective, can be: a tuple (Q, c) with Q a pytree of matrices, or a tuple (params_Q, c) if matvec_Q is provided, or an arbitrary pytree if fun is provided.

  • params_eq (Any) – parameters of the equality constraints (see doc of run method).

  • params_ineq (Tuple[Any, Any]) – parameters of the inequality constraints (see doc of run method).

Returns

A BoxOSQPState object.

l2_optimality_error(params, params_obj, params_eq, params_ineq)

Computes the L2 norm of the KKT residuals.

Return type

OptStep

Parameters
  • params (KKTSolution) –

  • params_obj (Tuple[Any, Any]) –

  • params_eq (Any) –

  • params_ineq (Tuple[Any, Any]) –

run(init_params=None, params_obj=None, params_eq=None, params_ineq=None)

Return primal/dual variables.

Parameters
  • init_params (Optional[Any]) – (optional) initial KKTSolution. Must be provided if fun is not None.

  • params_obj (Optional[Tuple[Optional[Any], Any]]) – parameters of objective, can be: a tuple (Q, c) with Q a pytree of matrices, or a tuple (params_Q, c) if matvec_Q is provided, or an arbitrary pytree if fun is provided.

  • params_eq (Optional[Any]) – (optional) params_A.

  • params_ineq (Optional[Tuple[Any, Any]]) – pair (l, u).

Return type

OptStep

update(params, state, params_obj, params_eq, params_ineq)

Performs one BoxOSQP step.

Parameters
  • params (KKTSolution) –

  • state (BoxOSQPState) –

  • params_obj (Union[Tuple[Any, Any], Any]) –

  • params_eq (Any) –

  • params_ineq (Tuple[Any, Any]) –