jaxopt.projection.projection_l2_sphere

jaxopt.projection.projection_l2_sphere(x, value=1.0)

Projection onto the l2 sphere:

\operatorname{argmin}_{y} \; \|x - y\|_2^2 \quad \text{subject to} \quad \|y\|_2 = \text{value}

Parameters
  • x (Any) – pytree to project.

  • value (float) – radius of the sphere.

Return type

Any

Returns

output pytree, with the same structure as x.
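
Example

For a nonzero x and a positive value, the problem above has the closed-form solution y = value * x / ||x||_2, where the norm is taken over all leaves of the pytree together. The sketch below illustrates that formula in plain JAX. The helper name l2_sphere_projection_sketch is hypothetical and is not part of jaxopt; it is an assumption-laden illustration of the math, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def l2_sphere_projection_sketch(x, value=1.0):
    """Hypothetical helper: rescale a pytree so its global l2 norm equals `value`.

    Assumes x is nonzero; at x = 0 the projection is not unique and this
    sketch would divide by zero.
    """
    leaves = jax.tree_util.tree_leaves(x)
    # Global l2 norm computed across every leaf of the pytree.
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
    # Scale each leaf by the same factor, which preserves the pytree structure.
    return jax.tree_util.tree_map(lambda leaf: leaf * (value / norm), x)


# A small pytree whose global l2 norm is 5 (a 3-4-5 triangle split over two leaves).
x = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
y = l2_sphere_projection_sketch(x, value=2.0)

# Recompute the global norm of the result to check that it lies on the sphere.
y_norm = jnp.sqrt(
    sum(jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(y))
)
```

With jaxopt installed, the corresponding call would be jaxopt.projection.projection_l2_sphere(x, value=2.0), using the signature documented above.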