jaxopt.prox.prox_non_negative_ridge

jaxopt.prox.prox_non_negative_ridge(x, l2reg=1.0, scaling=1.0)

Proximal operator for the squared l2 norm on the non-negative orthant.

$$\underset{y \ge 0}{\operatorname{argmin}} \; \frac{1}{2} \|x - y\|_2^2 + \text{scaling} \cdot \text{l2reg} \cdot \|y\|_2^2$$

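The objective separates across coordinates, so the minimizer has a closed form. The short derivation below is our own working from the objective above, not taken from the library, and assumes scaling · l2reg ≥ 0. For each coordinate, write λ = scaling · l2reg:

```latex
\begin{aligned}
&g(y_i) = \tfrac{1}{2}(x_i - y_i)^2 + \lambda\, y_i^2, \qquad y_i \ge 0, \\
&g'(y_i) = (y_i - x_i) + 2\lambda\, y_i = 0
  \;\Rightarrow\; \tilde{y}_i = \frac{x_i}{1 + 2\lambda}
  \quad \text{(unconstrained minimizer)}, \\
&g \text{ is a strictly convex 1-D quadratic, so the constrained minimizer is } \max(\tilde{y}_i, 0): \\
&y_i^\star = \frac{\max(x_i, 0)}{1 + 2\,\text{scaling} \cdot \text{l2reg}}.
\end{aligned}
```

In words: negative entries are set to zero and the remaining entries are shrunk by the factor 1 / (1 + 2 · scaling · l2reg).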
Parameters
  • x (Any) – input pytree.

  • l2reg (Optional[float]) – regularization strength.

  • scaling (float) – a scaling factor.

Returns

output pytree, with the same structure as x.
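As an illustration of the closed form, here is a minimal JAX sketch that applies it leaf-by-leaf to a pytree. The function names `nonneg_ridge_prox` and `objective` are hypothetical and this is not the jaxopt source; it only assumes the objective stated above, and it checks the output against a few feasible perturbations.

```python
import jax
import jax.numpy as jnp


def nonneg_ridge_prox(x, l2reg=1.0, scaling=1.0):
    """Hypothetical sketch (not the jaxopt source) of
    argmin_{y >= 0} 0.5 * ||x - y||^2 + scaling * l2reg * ||y||^2,
    applied to every leaf of the pytree ``x``."""
    denom = 1.0 + 2.0 * scaling * l2reg  # assumes scaling * l2reg >= 0
    # Clip negatives to zero, then shrink by the ridge factor.
    return jax.tree_util.tree_map(lambda u: jnp.maximum(u, 0.0) / denom, x)


def objective(y, x, l2reg=1.0, scaling=1.0):
    """The objective from the docstring, for a single array leaf."""
    return 0.5 * jnp.sum((x - y) ** 2) + scaling * l2reg * jnp.sum(y ** 2)


x = {"w": jnp.array([-1.0, 0.0, 3.0]), "b": jnp.array(6.0)}
# With l2reg=0.5 and scaling=2.0 the shrinkage denominator is 1 + 2*2*0.5 = 3.
y = nonneg_ridge_prox(x, l2reg=0.5, scaling=2.0)
print(y)
```

The output keeps the dict structure of `x`, matching the "same structure as x" contract described under Returns.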