jaxopt.BoxOSQP
- class jaxopt.BoxOSQP(matvec_Q=None, matvec_A=None, fun=None, check_primal_dual_infeasability='auto', sigma=1e-06, momentum=1.6, eq_qp_solve='cg', rho_start=0.1, rho_min=1e-06, rho_max=1000000.0, stepsize_updates_frequency=10, primal_infeasible_tol=0.001, dual_infeasible_tol=0.001, maxiter=4000, tol=0.001, termination_check_frequency=5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Operator Splitting Solver for Quadratic Programs.
JAX implementation of the celebrated GPU-OSQP [1,3], based on ADMM. Supports jit, vmap, matvecs, pytrees and fun.
Refer to the documentation of the init_state method for the meaning of the parameters in the run and update methods.
It solves convex problems of the form
\[\begin{split}\begin{aligned} \min_{x,z} \quad & \frac{1}{2}x^TQx + c^Tx\\ \textrm{s.t.} \quad & Ax=z\\ & l\leq z\leq u \\ \end{aligned}\end{split}\]
Equality constraints are obtained by setting l = u. If the inequality is one-sided, then jnp.inf can be used for u and -jnp.inf for l. Q must be a positive semidefinite (PSD) matrix.
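To illustrate the box form, the NumPy sketch below encodes a small two-variable problem as l <= A x <= u: an equality row uses l = u, and a one-sided row uses -inf for its lower bound. The helper is_feasible is hypothetical (not part of jaxopt) and simply checks a candidate x against the box.

```python
import numpy as np

# Constraints written as l <= A x <= u.
A = np.array([
    [1.0, 1.0],   # x1 + x2 = 1   -> equality, so l = u = 1
    [1.0, 0.0],   # 0 <= x1 <= 0.7
    [0.0, 1.0],   # 0 <= x2 <= 0.7
    [1.0, -1.0],  # x1 - x2 <= 0  -> one-sided, so l = -inf
])
l = np.array([1.0, 0.0, 0.0, -np.inf])
u = np.array([1.0, 0.7, 0.7, 0.0])


def is_feasible(A, x, l, u, atol=1e-8):
    """Hypothetical helper: check l <= A x <= u up to an absolute tolerance."""
    z = A @ x  # z is the auxiliary variable with z = A x
    return bool(np.all(l - atol <= z) and np.all(z <= u + atol))


print(is_feasible(A, np.array([0.3, 0.7]), l, u))
```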
The Lagrangian is given by
\[\mathcal{L} = \frac{1}{2}x^TQx + c^Tx + y^T(Ax-z) + \mu^T (z-u) + \phi^T (l-z)\]
Primal variables: \(x, z\)
Dual Eq variables: \(y\)
Dual Ineq variables: \(\mu, \phi\)
ADMM computes \(y\) at each iteration. \(\mu\) and \(\phi\) can be deduced from \(y\).
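One way to see how \(\mu\) and \(\phi\) follow from \(y\): stationarity of \(\mathcal{L}\) in \(z\) gives \(-y + \mu - \phi = 0\), i.e. \(y = \mu - \phi\). With \(\mu, \phi \geq 0\) and complementary slackness, a natural split is \(\mu = \max(y, 0)\) and \(\phi = \max(-y, 0)\). The NumPy function below is our own sketch of that relation, not jaxopt's internal code.

```python
import numpy as np


def split_dual(y):
    """Recover (mu, phi) from y = mu - phi, assuming mu, phi >= 0 and mu * phi = 0."""
    y = np.asarray(y, dtype=float)
    mu = np.maximum(y, 0.0)    # multiplier of the upper bound z <= u
    phi = np.maximum(-y, 0.0)  # multiplier of the lower bound l <= z
    return mu, phi


mu, phi = split_dual([-2.9, 0.0, 0.2])
```

A positive \(y_i\) thus signals an active upper bound, and a negative \(y_i\) an active lower bound.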
Default values for hyper-parameters come from: https://github.com/osqp/osqp/blob/master/include/constants.h
- Parameters
matvec_Q (Optional[Callable]) –
matvec_A (Optional[Callable]) –
fun (Optional[Callable]) –
check_primal_dual_infeasability (Union[str, bool]) –
sigma (float) –
momentum (float) –
eq_qp_solve (str) –
rho_start (float) –
rho_min (float) –
rho_max (float) –
stepsize_updates_frequency (int) –
primal_infeasible_tol (float) –
dual_infeasible_tol (float) –
maxiter (int) –
tol (float) –
termination_check_frequency (int) –
verbose (int) –
implicit_diff (bool) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- matvec_Q
(optional) a Callable matvec_Q(params_Q, x). By default, matvec_Q(Q, x) = tree_dot(Q, x), where the pytree Q = params_Q matches the structure of x. The shape of the primal variables may be inferred from params_obj = (params_Q, c).
- Type
Optional[Callable]
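A custom matvec_Q avoids materializing Q when it has structure. As a sketch, assume a hypothetical Q = diag(d) + U U^T (PSD whenever d >= 0), stored as the pytree params_Q = (d, U); such a function would then be passed as BoxOSQP(matvec_Q=matvec_Q) together with params_obj = ((d, U), c).

```python
import jax.numpy as jnp


def matvec_Q(params_Q, x):
    """Q @ x for the hypothetical structure Q = diag(d) + U U^T, with params_Q = (d, U)."""
    d, U = params_Q
    # Never forms the n x n matrix: cost is O(n k) for U of shape (n, k).
    return d * x + U @ (U.T @ x)


d = jnp.array([1.0, 2.0, 3.0])
U = jnp.array([[1.0], [0.0], [2.0]])
x = jnp.array([1.0, -1.0, 0.5])
qx = matvec_Q((d, U), x)
```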
- matvec_A
(optional) a Callable matvec_A(params_A, x). By default, matvec_A(A, x) = tree_dot(A, x), where the pytree A = params_A matches the structure of x.
- Type
Optional[Callable]
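Similarly, matvec_A can encode a structured constraint operator. The sketch below assumes a hypothetical A = [I; 1^T], i.e. one row per coordinate plus a row for the sum; with l = (0, ..., 0, 1) and u = (inf, ..., inf, 1) this describes the probability simplex. params_A is ignored because this A has no free parameters.

```python
import jax.numpy as jnp


def matvec_A(params_A, x):
    """A @ x for the hypothetical A = [I; 1^T]: each coordinate, then their sum."""
    del params_A  # unused: A is fixed by its structure
    return jnp.concatenate([x, jnp.sum(x, keepdims=True)])


x = jnp.array([0.2, 0.3, 0.5])
ax = matvec_A(None, x)
```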
- fun
(optional) a function with signature fun(params, params_obj) that is promised to be a convex quadratic polynomial with respect to params, i.e. fun can be written as
fun(x, params_obj) = 0.5*jnp.dot(x, jnp.dot(Q, x)) + jnp.dot(c, x) + cste
with params_obj a pytree that contains the parameters of the objective function. (Q, c) do not need to be made explicit in params_obj by the user: c will be inferred by Jaxopt,
and the operator x -> Qx will be computed upon request.
fun is incompatible with the specification of matvec_Q. Note that the shape of the primal variables can no longer be inferred from params_obj, so the user should provide it in init_params. This API is provided for convenience, but note that since fun uses JAX’s autodiff under the hood, it can be slower than matvec_Q, especially when used in conjunction with implicit differentiation.
- Type
Optional[Callable]
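Why c and x -> Qx can be recovered from fun alone: for a quadratic, the gradient at zero is c, and the Hessian is the constant Q, so a JVP of the gradient yields Qv. The sketch below is our illustration with jax.grad and jax.jvp (jaxopt's internals may differ); infer_c and make_matvec_Q are hypothetical names. Note that both need an example x (x_like) to know the shape, which is why init_params must be supplied when fun is used.

```python
import jax
import jax.numpy as jnp


def fun(x, params_obj):
    """Hypothetical convex quadratic objective; params_obj = (M, b), M symmetric PSD."""
    M, b = params_obj
    return 0.5 * jnp.dot(x, M @ x) + jnp.dot(b, x) + 3.0  # constant term is irrelevant


def infer_c(fun, params_obj, x_like):
    # For f(x) = 0.5 x^T Q x + c^T x + cste, grad f(0) = c.
    return jax.grad(fun)(jnp.zeros_like(x_like), params_obj)


def make_matvec_Q(fun, params_obj, x_like):
    grad_f = lambda x: jax.grad(fun)(x, params_obj)
    zeros = jnp.zeros_like(x_like)

    def matvec(v):
        # The Jacobian of grad f is the constant Hessian Q, so this JVP returns Q v.
        _, qv = jax.jvp(grad_f, (zeros,), (v,))
        return qv

    return matvec


M = jnp.array([[4.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 1.0])
c = infer_c(fun, (M, b), jnp.zeros(2))
matvec = make_matvec_Q(fun, (M, b), jnp.zeros(2))
qv = matvec(jnp.array([1.0, 0.0]))
```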
- check_primal_dual_infeasability
if True, populates the status field of state with one of BoxOSQP.PRIMAL_INFEASIBLE or BoxOSQP.DUAL_INFEASIBLE. If False, it improves speed but does not check feasibility. If the problem is primal or dual infeasible and jit=False, a ValueError exception is raised. If “auto”, it will be True if jit=False and False otherwise (default: “auto”).
- Type
Union[str, bool]
- sigma
ridge regularization parameter in the linear system (default: 1e-6).
- Type
float
- momentum
relaxation parameter (default: 1.6); must belong to the open interval (0, 2). momentum=1 means no relaxation, momentum<1 means under-relaxation, and momentum>1 means over-relaxation. Boyd et al. [2, p. 21] suggest choosing momentum in [1.5, 1.8].
- Type
float
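In [1], relaxation replaces a fresh iterate by an affine combination with the previous one, \(\alpha \tilde{x}^{k+1} + (1-\alpha) x^k\), where \(\alpha\) is the relaxation parameter. Assuming momentum plays the role of \(\alpha\) here, the step can be sketched in NumPy as follows.

```python
import numpy as np


def relax(new, old, momentum=1.6):
    """Relaxation step: momentum * new + (1 - momentum) * old, for momentum in (0, 2)."""
    if not 0.0 < momentum < 2.0:
        raise ValueError("momentum must lie in the open interval (0, 2)")
    return momentum * np.asarray(new) + (1.0 - momentum) * np.asarray(old)


print(relax(1.0, 0.0))  # over-relaxation pushes past the new iterate
```

With momentum=1 the previous iterate is discarded entirely; values above 1 extrapolate beyond the new iterate, which is what speeds up ADMM in practice.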
- eq_qp_solve
‘cg’, ‘cg+jacobi’ or ‘lu’ (default: ‘cg’). ‘cg’ is conjugate gradient: an indirect solver that works with matvecs or pytrees of matrices. ‘cg+jacobi’ is conjugate gradient with Jacobi preconditioning: it only works on pytrees of matrices but can provide a speedup.
‘lu’ is LU factorization: a direct solver that only works on pytrees of matrices.
- Type
str
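To make the ‘cg’ versus ‘lu’ trade-off concrete, assume the x-update solves the reduced system from [1], \((Q + \sigma I + \rho A^TA)\,x = \sigma x^k - c + A^T(\rho z^k - y^k)\), whose matrix is symmetric positive definite when \(\sigma > 0\). The NumPy sketch below solves it with a textbook conjugate gradient that only touches the matrix through products (the property that lets ‘cg’ accept matvecs) and compares with np.linalg.solve, an LU-based direct solve. It is not jaxopt's solver.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, maxiter=1000):
    """Textbook CG for a symmetric positive definite operator given as a matvec."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) <= tol:
            break
        Ap = matvec(p)
        step = rs / (p @ Ap)
        x = x + step * p
        r = r - step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def make_kkt_matvec(Q, A, sigma, rho):
    # v -> (Q + sigma I + rho A^T A) v, without forming the matrix.
    return lambda v: Q @ v + sigma * v + rho * (A.T @ (A @ v))


Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
sigma, rho = 1e-6, 0.1
x_k, z_k, y_k = np.zeros(2), np.array([1.0, 0.3, 0.7]), np.zeros(3)

rhs = sigma * x_k - c + A.T @ (rho * z_k - y_k)
x_cg = conjugate_gradient(make_kkt_matvec(Q, A, sigma, rho), rhs)
x_direct = np.linalg.solve(Q + sigma * np.eye(2) + rho * A.T @ A, rhs)
```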
- rho_start
initial learning rate (default: 1e-1).
- Type
float
- rho_min
minimum learning rate (default: 1e-6).
- Type
float
- rho_max
maximum learning rate (default: 1e6).
- Type
float
- stepsize_updates_frequency
frequency of stepsize updates (default: 10). A new stepsize is computed once every stepsize_updates_frequency updates.
- Type
int
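For intuition about rho_start, rho_min, rho_max and stepsize_updates_frequency, the sketch below implements the adaptive rule described in [1]: \(\rho\) is multiplied by the square root of the ratio between the normalized primal and dual residuals, then clipped. We assume BoxOSQP follows a rule of this kind; adapt_rho and should_update_rho are hypothetical names, and the residual norms are passed in as plain floats.

```python
import math


def adapt_rho(rho, prim_res, prim_scale, dual_res, dual_scale,
              rho_min=1e-6, rho_max=1e6, eps=1e-12):
    """OSQP-style stepsize update (sketch), clipped to [rho_min, rho_max].

    prim_res = ||A x - z||,          prim_scale = max(||A x||, ||z||)
    dual_res = ||Q x + c + A^T y||,  dual_scale = max(||Q x||, ||A^T y||, ||c||)
    """
    prim_rel = prim_res / (prim_scale + eps)
    dual_rel = dual_res / (dual_scale + eps)
    new_rho = rho * math.sqrt(prim_rel / (dual_rel + eps))
    return min(max(new_rho, rho_min), rho_max)


def should_update_rho(iteration, stepsize_updates_frequency=10):
    # Hypothetical schedule: recompute rho once every `stepsize_updates_frequency` iterations.
    return iteration % stepsize_updates_frequency == 0


print(adapt_rho(0.1, 1.0, 1.0, 0.01, 1.0))  # large primal residual -> rho grows
```

A relatively large primal residual increases \(\rho\) (enforcing Ax = z more strongly), while a large dual residual decreases it.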
- primal_infeasible_tol
relative tolerance for primal infeasibility detection (default: 1e-3).
- Type
float
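[1] detects primal infeasibility from the difference \(\delta y\) between successive dual iterates: \(\delta y\) is a certificate when \(\|A^T\delta y\|_\infty \leq \varepsilon\|\delta y\|_\infty\) and \(u^T\max(\delta y, 0) + l^T\min(\delta y, 0) \leq -\varepsilon\|\delta y\|_\infty\). The NumPy sketch below checks that criterion, assuming primal_infeasible_tol plays the role of \(\varepsilon\); the function name is hypothetical. Infinite bounds are only used where \(\delta y\) actually pushes against them, which avoids 0 * inf.

```python
import numpy as np


def is_primal_infeasibility_certificate(A, l, u, dy, eps=1e-3):
    """Sketch of the OSQP primal infeasibility test for a candidate dual direction dy."""
    dy = np.asarray(dy, dtype=float)
    norm_dy = np.max(np.abs(dy))
    if norm_dy == 0.0:
        return False
    dy_pos, dy_neg = np.maximum(dy, 0.0), np.minimum(dy, 0.0)
    # Support-function terms; masking keeps inf bounds out where dy is zero.
    upper = np.sum(np.where(dy_pos > 0, u, 0.0) * dy_pos)
    lower = np.sum(np.where(dy_neg < 0, l, 0.0) * dy_neg)
    cond_A = np.max(np.abs(A.T @ dy)) <= eps * norm_dy
    cond_bounds = upper + lower <= -eps * norm_dy
    return bool(cond_A and cond_bounds)


# Infeasible toy problem: x1 >= 1 and x1 <= 0.
A = np.array([[1.0], [1.0]])
l = np.array([1.0, -np.inf])
u = np.array([np.inf, 0.0])
print(is_primal_infeasibility_certificate(A, l, u, np.array([-1.0, 1.0])))
```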
- dual_infeasible_tol
relative tolerance for dual infeasibility detection (default: 1e-3).
- Type
float
- maxiter
maximum number of iterations (default: 4000).
- Type
int
- tol
absolute tolerance for stopping criterion (default: 1e-3).
- Type
float
- termination_check_frequency
frequency of the termination check (default: 5). The error is computed once every termination_check_frequency iterations.
- Type
int
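The stopping test can be sketched from the Lagrangian above: the primal residual is \(Ax - z\) and the dual residual is the stationarity condition in \(x\), \(Qx + c + A^Ty\). The NumPy sketch assumes an absolute infinity-norm test against tol, as in [1]; BoxOSQP's exact norm may differ. The box condition \(l \leq z \leq u\) is not checked because ADMM keeps z in the box by projection. Function names are hypothetical.

```python
import numpy as np


def kkt_residuals(Q, c, A, x, z, y):
    """Primal residual A x - z and dual residual Q x + c + A^T y."""
    return A @ x - z, Q @ x + c + A.T @ y


def is_converged(Q, c, A, x, z, y, tol=1e-3):
    # Absolute test in the infinity norm (an assumption, see above).
    r_prim, r_dual = kkt_residuals(Q, c, A, x, z, y)
    return bool(np.max(np.abs(r_prim)) <= tol and np.max(np.abs(r_dual)) <= tol)


def should_check(iteration, termination_check_frequency=5):
    # The error is only evaluated once every `termination_check_frequency` iterations.
    return iteration % termination_check_frequency == 0


Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
x = np.array([0.3, 0.7])     # solution for l = (1, 0, 0), u = (1, 0.7, 0.7)
z = A @ x
y = np.array([-2.9, 0.0, 0.2])
print(is_converged(Q, c, A, x, z, y))
```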
- verbose
If verbose=1, print error at each iteration. If verbose=2, also print stepsizes and primal/dual variables. Warning: verbose>0 will automatically disable jit.
- Type
int
- implicit_diff
whether to enable implicit differentiation (True) or autodiff of unrolled iterations (False).
- Type
bool
- implicit_diff_solve
the linear system solver to use for implicit differentiation.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
References
[1] Stellato, B., Banjac, G., Goulart, P., Bemporad, A. and Boyd, S., 2020. OSQP: An operator splitting solver for quadratic programs. Mathematical Programming Computation, 12(4), pp.637-672.
[2] Boyd, S., Parikh, N., Chu, E., Peleato, B. and Eckstein, J., 2010. Distributed Optimization and Statistical Learning via the Alternating Direction Method of Multipliers. Foundations and Trends in Machine Learning, 3(1), pp.1-122.
[3] Schubiger, M., Banjac, G. and Lygeros, J., 2020. GPU acceleration of ADMM for large-scale quadratic programming. Journal of Parallel and Distributed Computing, 144, pp.55-67.
- __init__(matvec_Q=None, matvec_A=None, fun=None, check_primal_dual_infeasability='auto', sigma=1e-06, momentum=1.6, eq_qp_solve='cg', rho_start=0.1, rho_min=1e-06, rho_max=1000000.0, stepsize_updates_frequency=10, primal_infeasible_tol=0.001, dual_infeasible_tol=0.001, maxiter=4000, tol=0.001, termination_check_frequency=5, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
- Return type
None
Methods
__init__([matvec_Q, matvec_A, fun, ...])
attribute_names()
attribute_values()
init_params(init_x, params_obj, params_eq, ...): Return default KKTSolution for initialization of the solver state.
init_state(init_params, params_obj, ...): Initializes the solver state.
l2_optimality_error(params, params_obj, ...): Computes the L2 norm of the KKT residuals.
run([init_params, params_obj, params_eq, ...]): Return primal/dual variables.
update(params, state, params_obj, params_eq, ...): Perform BoxOSQP step.
Attributes
DUAL_INFEASIBLE
PRIMAL_INFEASIBLE
SOLVED
UNSOLVED
- init_params(init_x, params_obj, params_eq, params_ineq)
Return default KKTSolution for initialization of the solver state.
- Parameters
init_x (Any) – initial primal variable.
params_obj (Union[Tuple[Any, Any], Any]) – parameters of the objective function (see doc of init_state method).
params_eq (Any) – parameters of the equality constraints (see doc of init_state method).
params_ineq (Tuple[Any, Any]) – parameters of the inequality constraints (see doc of init_state method).
- Returns
default parameters for initialization.
- Return type
init_params
- init_state(init_params, params_obj, params_eq, params_ineq)
Initializes the solver state.
- Parameters
init_params (KKTSolution) – initial primal and dual variables (KKTSolution).
params_obj (Union[Tuple[Any, Any], Any]) – parameters of the quadratic objective, can be: a tuple (Q, c) with Q a pytree of matrices, or a tuple (params_Q, c) if matvec_Q is provided, or an arbitrary pytree if fun is provided.
params_eq (Any) – parameters of the equality constraints (see doc of run method).
params_ineq (Tuple[Any, Any]) – parameters of the inequality constraints (see doc of run method).
- Returns
A BoxOSQPState object.
- l2_optimality_error(params, params_obj, params_eq, params_ineq)
Computes the L2 norm of the KKT residuals.
- Return type
OptStep
- Parameters
params (KKTSolution) –
params_obj (Tuple[Any, Any]) –
params_eq (Any) –
params_ineq (Tuple[Any, Any]) –
- run(init_params=None, params_obj=None, params_eq=None, params_ineq=None)
Return primal/dual variables.
- Parameters
init_params (Optional[Any]) – (optional) initial KKTSolution. Must be provided if fun is not None.
params_obj (Optional[Tuple[Optional[Any], Any]]) – parameters of objective, can be: a tuple (Q, c) with Q a pytree of matrices, or a tuple (params_Q, c) if matvec_Q is provided, or an arbitrary pytree if fun is provided.
params_eq (Optional[Any]) – (optional) params_A.
params_ineq (Optional[Tuple[Any, Any]]) – pair (l, u).
- Return type
OptStep