jaxopt.projection.projection_sparse_simplex

jaxopt.projection.projection_sparse_simplex(x, max_nz, use_approx_max_nz=False, value=1.0)

Projection onto the simplex with a cardinality constraint (a maximum number of non-zero elements).

\[\underset{p}{\text{argmin}} ~ \|x - p\|_2^2 \quad \textrm{subject to} \quad p \ge 0,\ p^\top 1 = \text{value},\ \|p\|_0 \le \text{max\_nz}\]
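As a worked example, assuming the greedy selector and simplex projector (GSSP) approach of the paper listed under References (keep the max_nz largest entries of x, project them onto the simplex, set the rest to zero), take x = (0.5, 0.2, 0.9, 0.1), max_nz = 2 and value = 1. The kept entries are shifted down by a common threshold τ so that they sum to value; both kept entries exceed τ, so neither is clipped to zero:

\[\begin{aligned}
S &= \{1, 3\} \quad (\text{1-based positions of the two largest entries, } 0.5 \text{ and } 0.9) \\
\tau &= \frac{(0.5 + 0.9) - \text{value}}{2} = 0.2 \\
p &= \big(0.5 - \tau,\ 0,\ 0.9 - \tau,\ 0\big) = (0.3,\ 0,\ 0.7,\ 0)
\end{aligned}\]

The result has max_nz = 2 non-zero entries, is non-negative, and sums to 1.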
Parameters
  • x (Array) – vector to project, an array of shape (n,).

  • max_nz (int) – maximum number of non-zero values to keep.

  • use_approx_max_nz (bool) – when True, use jax.lax.approx_max_k to select the largest values of x and their indices approximately rather than exactly (default: False).

  • value (float) – value that the entries of p should sum to (default: 1.0).
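The roles of max_nz and value can be illustrated with a small NumPy reference sketch. It assumes the GSSP approach from the referenced paper: select the max_nz largest entries of x exactly (the counterpart of use_approx_max_nz=False), apply the standard sort-based Euclidean projection onto the simplex with sum value, and zero out everything else. This is not JAXopt's implementation; project_simplex and sparse_simplex_sketch are hypothetical names, and the sketch assumes value > 0.

```python
import numpy as np


def project_simplex(y, value=1.0):
    """Hypothetical helper: Euclidean projection of y onto {p >= 0, sum(p) = value}.

    Sort-based method; assumes value > 0 so at least one index passes the test below.
    """
    u = np.sort(y)[::-1]                      # entries in decreasing order
    css = np.cumsum(u) - value                # cumulative sums minus the target sum
    k = np.arange(1, y.size + 1)
    rho = np.nonzero(u - css / k > 0)[0][-1]  # last position where the condition holds
    tau = css[rho] / (rho + 1.0)              # common threshold subtracted from y
    return np.maximum(y - tau, 0.0)


def sparse_simplex_sketch(x, max_nz, value=1.0):
    """Hypothetical GSSP-style sketch: keep the max_nz largest entries, project, zero the rest."""
    x = np.asarray(x, dtype=float)
    idx = np.argsort(x)[::-1][:max_nz]        # indices of the max_nz largest entries (exact)
    p = np.zeros_like(x)
    p[idx] = project_simplex(x[idx], value)   # project only the selected entries
    return p
```

For example, sparse_simplex_sketch([0.5, 0.2, 0.9, 0.1], max_nz=2) should give (0.3, 0, 0.7, 0) up to floating-point rounding. Setting max_nz to the length of x reduces it to the ordinary projection onto the simplex.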

Return type

Array

Returns

projected vector, an array with the same shape as x.
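The returned vector should satisfy all three constraints of the problem above. A hypothetical helper, check_sparse_simplex (not part of JAXopt), can verify this for a result converted to NumPy; the tolerance atol is an assumption, chosen to absorb floating-point error.

```python
import numpy as np


def check_sparse_simplex(p, max_nz, value=1.0, atol=1e-6):
    """Hypothetical check: True if p >= 0, sum(p) == value and ||p||_0 <= max_nz (within atol)."""
    p = np.asarray(p, dtype=float)
    nonneg = bool(np.all(p >= -atol))                       # p >= 0
    sums_ok = bool(np.isclose(p.sum(), value, atol=atol))   # p^T 1 = value
    sparse_ok = int(np.count_nonzero(np.abs(p) > atol)) <= max_nz  # ||p||_0 <= max_nz
    return nonneg and sums_ok and sparse_ok
```

For a one-dimensional x, check_sparse_simplex(np.asarray(jaxopt.projection.projection_sparse_simplex(x, 2)), 2) is expected to return True.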

References

Sparse projections onto the simplex. Anastasios Kyrillidis, Stephen Becker, Volkan Cevher, and Christoph Koch. ICML 2013. https://arxiv.org/abs/1206.1529