jaxopt.prox.prox_elastic_net

jaxopt.prox.prox_elastic_net(x, hyperparams=None, scaling=1.0)

Proximal operator for the elastic net.

\[\underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2 + \text{scaling} \cdot \text{hyperparams[0]} \cdot g(y)\]

where \(g(y) = ||y||_1 + \text{hyperparams[1]} \cdot 0.5 \cdot ||y||_2^2\).
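For nonnegative hyperparameters the minimizer has a closed form. Write \(\lambda = \text{scaling} \cdot \text{hyperparams[0]}\) and \(\gamma = \text{hyperparams[1]}\). This is shorthand introduced here, applied leaf by leaf when the hyperparameters are pytrees. The objective then separates over coordinates, and the first-order optimality condition for each coordinate \(i\) yields soft-thresholding followed by a rescaling:

\[\frac{1}{2}(x_i - y_i)^2 + \lambda |y_i| + \frac{\lambda \gamma}{2} y_i^2 \quad\Longrightarrow\quad 0 \in (1 + \lambda \gamma)\, y_i - x_i + \lambda\, \partial |y_i|\]

\[y_i = \frac{\operatorname{sign}(x_i) \max(|x_i| - \lambda, 0)}{1 + \lambda \gamma}\]

Coordinates with \(|x_i| \le \lambda\) are set exactly to zero. The remaining ones are soft-thresholded and then shrunk by the factor \(1 / (1 + \lambda \gamma)\).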

Parameters
  • x (Any) – input pytree.

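  • hyperparams (Optional[Tuple[Any, Any]]) – a tuple (hyperparams[0], hyperparams[1]). Here hyperparams[0] weights the whole penalty g, and hyperparams[1] weights the squared l2 term inside g. Each entry can be either a float or a pytree with the same structure as x.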

  • scaling (float) – a scaling factor that multiplies the penalty term in the objective above.

Return type

Any

Returns

output pytree, with the same structure as x.
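The operator acts independently on each leaf of x, so it can be sketched with `jax.tree_util.tree_map`. The block below is a hypothetical re-implementation for illustration, not jaxopt's own source. `elastic_net_prox_sketch` and `_broadcast_like` are made-up names, and the sketch assumes nonnegative hyperparameters. For every leaf it soft-thresholds at scaling · hyperparams[0] and then divides by 1 + scaling · hyperparams[0] · hyperparams[1]. Float hyperparameters are copied onto every leaf, while pytree hyperparameters must match the structure of x.

```python
import jax
import jax.numpy as jnp


def _broadcast_like(h, x):
  """Hypothetical helper: return h as a pytree with the same structure as x.

  A pytree that already matches x is returned unchanged; anything else
  (e.g. a Python float) is copied onto every leaf of x.
  """
  if jax.tree_util.tree_structure(h) == jax.tree_util.tree_structure(x):
    return h
  return jax.tree_util.tree_map(lambda _: h, x)


def elastic_net_prox_sketch(x, hyperparams, scaling=1.0):
  """Hypothetical sketch of the elastic-net prox (not jaxopt's source).

  For each leaf: y = soft_threshold(u, lam) / (1 + lam * gamma),
  with lam = scaling * hyperparams[0] and gamma = hyperparams[1].
  Assumes nonnegative hyperparameters.
  """
  l1 = _broadcast_like(hyperparams[0], x)
  l2 = _broadcast_like(hyperparams[1], x)

  def leaf_prox(u, lam, gamma):
    lam = scaling * lam
    # Soft-thresholding at lam: the prox of lam * ||.||_1.
    soft = jnp.sign(u) * jnp.maximum(jnp.abs(u) - lam, 0.0)
    # Extra shrinkage contributed by the (lam * gamma / 2) * ||.||_2^2 term.
    return soft / (1.0 + lam * gamma)

  return jax.tree_util.tree_map(leaf_prox, x, l1, l2)


x = {"w": jnp.array([3.0, -0.5, 1.5]), "b": jnp.array([-2.0])}

# Float hyperparams: the same strengths for every leaf.
y_float = elastic_net_prox_sketch(x, hyperparams=(1.0, 1.0))
print(y_float)  # values: w = [1.0, -0.0, 0.25], b = [-0.5]

# Pytree hyperparams: per-leaf strengths; the "b" leaf is left unpenalized.
y_tree = elastic_net_prox_sketch(
    x, hyperparams=({"w": 0.5, "b": 0.0}, {"w": 2.0, "b": 0.0}))
print(y_tree)  # values: w = [1.25, -0.0, 0.5], b = [-2.0]
```

In both calls the output has the same dict structure as x. Setting both entries to 0.0 for a leaf, as done for "b" above, returns that leaf unchanged.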