jaxopt.tree_util.tree_l2_norm

jaxopt.tree_util.tree_l2_norm(tree_x, squared=False)

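Compute the l2 norm ||tree_x|| of the pytree tree_x, taken over the entries of all its leaves. If squared is True, return the squared norm ||tree_x||^2 instead.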
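
As a rough illustration of what this norm means for a pytree, the sketch below computes it with plain JAX. The helper name tree_l2_norm_sketch and the params dictionary are hypothetical and not part of jaxopt; the sketch assumes real-valued leaves and that the norm treats all leaves as one concatenated vector. It is not the library's implementation.

```python
import jax
import jax.numpy as jnp


def tree_l2_norm_sketch(tree_x, squared=False):
    """Hypothetical stand-in: l2 norm over all leaves of a pytree (real-valued leaves assumed)."""
    # Flatten the pytree into a list of array leaves.
    leaves = jax.tree_util.tree_leaves(tree_x)
    # Sum the squared entries of every leaf.
    sqnorm = sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)
    # Return the squared norm directly, or take the square root.
    return sqnorm if squared else jnp.sqrt(sqnorm)


# Example pytree: a dict of arrays, as is common for model parameters.
params = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}

norm = tree_l2_norm_sketch(params)                    # sqrt(3^2 + 0^2 + 4^2) = 5.0
sq_norm = tree_l2_norm_sketch(params, squared=True)   # 3^2 + 0^2 + 4^2 = 25.0
```

With jaxopt installed, jaxopt.tree_util.tree_l2_norm(params) and jaxopt.tree_util.tree_l2_norm(params, squared=True) would be expected to give the same values for this input.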