jaxopt.prox.prox_ridge

jaxopt.prox.prox_ridge(x, l2reg=1.0, scaling=1.0)

Proximal operator for the squared l2 norm (the ridge penalty). It returns the solution of:

\[\underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2 + \text{scaling} \cdot \text{l2reg} \cdot ||y||_2^2\]
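Assume scaling · l2reg ≥ 0. Then the objective above is differentiable and strictly convex in y, so setting its gradient to zero gives the minimizer in closed form. The squared norm of a pytree is a sum over its leaves, so the same shrinkage applies to every leaf. This step is derived only from the objective as displayed, not read from the library source:

\[\nabla_y \Big( \frac{1}{2} ||x - y||_2^2 + \text{scaling} \cdot \text{l2reg} \cdot ||y||_2^2 \Big) = (y - x) + 2 \cdot \text{scaling} \cdot \text{l2reg} \cdot y = 0 \quad\Longrightarrow\quad y^\star = \frac{x}{1 + 2 \cdot \text{scaling} \cdot \text{l2reg}}\]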

Parameters
  • x (Any) – input pytree.

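  • l2reg (Optional[float]) – regularization strength; it multiplies ||y||_2^2 in the objective above.

  • scaling – a scaling factor that also multiplies ||y||_2^2, so only the product scaling · l2reg affects the result. Proximal gradient methods typically pass the step size here.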

Return type

Any

Returns

output pytree with the same structure as x, containing the minimizer y.
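The objective above has a closed-form minimizer, and a short JAX sketch can check it on a pytree input. The helpers ridge_prox_sketch, sq_norm and objective are hypothetical and written only for this illustration. The sketch shrinks every leaf by 1 / (1 + 2 · scaling · l2reg), which follows from the displayed objective. It is not jaxopt's implementation, so it is worth comparing its output against jaxopt.prox.prox_ridge called with the same arguments.

```python
import jax
import jax.numpy as jnp


def ridge_prox_sketch(x, l2reg=1.0, scaling=1.0):
    """Hypothetical helper: closed-form minimizer over y of
    0.5 * ||x - y||_2^2 + scaling * l2reg * ||y||_2^2.

    Setting the gradient (y - x) + 2 * scaling * l2reg * y to zero gives
    y = x / (1 + 2 * scaling * l2reg). Not jaxopt's implementation.
    """
    factor = 1.0 / (1.0 + 2.0 * scaling * l2reg)
    # The squared norm of a pytree is a sum over its leaves, so the objective
    # separates: every leaf is shrunk by the same factor and the output keeps
    # the structure of x.
    return jax.tree_util.tree_map(lambda leaf: factor * leaf, x)


def sq_norm(tree):
    """Squared l2 norm of a pytree: sum of squares over all leaves."""
    return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(tree))


def objective(y, x, l2reg, scaling):
    """Hypothetical helper: the objective from the formula above."""
    diff = jax.tree_util.tree_map(lambda a, b: a - b, x, y)
    return 0.5 * sq_norm(diff) + scaling * l2reg * sq_norm(y)


# A pytree input: a dict holding a weight vector and a scalar bias.
x = {"w": jnp.array([1.0, -2.0, 3.0]), "b": jnp.array(0.5)}
y = ridge_prox_sketch(x, l2reg=0.5, scaling=2.0)  # factor = 1 / (1 + 2 * 2.0 * 0.5) = 1/3

# At the minimizer, the gradient of the objective with respect to y vanishes.
g = jax.grad(objective)(y, x, 0.5, 2.0)
max_grad = max(float(jnp.max(jnp.abs(leaf))) for leaf in jax.tree_util.tree_leaves(g))
print(y["w"], y["b"], max_grad)
```

Only the product scaling · l2reg enters the shrinkage factor. This matches the objective, where both parameters multiply the same ||y||_2^2 term.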