jaxopt.tree_util.tree_scalar_mul

jaxopt.tree_util.tree_scalar_mul(scalar, tree_x)

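Compute scalar * tree_x, that is, multiply every leaf of the pytree tree_x by scalar. The result has the same tree structure as tree_x.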
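To illustrate the leaf-wise semantics, here is a minimal sketch that reproduces the described behavior with jax.tree_util.tree_map. It assumes tree_scalar_mul is equivalent to mapping `lambda leaf: scalar * leaf` over the tree. The helper name scalar_mul_sketch and the example params dictionary are hypothetical and serve only for illustration.

```python
import jax.numpy as jnp
from jax import tree_util


def scalar_mul_sketch(scalar, tree_x):
    # Hypothetical stand-in for jaxopt.tree_util.tree_scalar_mul:
    # multiply every leaf by the same scalar. Containers (dicts,
    # tuples, lists) are left untouched, so the structure is preserved.
    return tree_util.tree_map(lambda leaf: scalar * leaf, tree_x)


# A small parameter pytree: a dict holding an array and a tuple of arrays.
params = {"w": jnp.array([1.0, 2.0]), "b": (jnp.array(3.0), jnp.array([4.0]))}

scaled = scalar_mul_sketch(0.5, params)
print(scaled["w"])
print(scaled["b"])
```

In practice you would call jaxopt.tree_util.tree_scalar_mul(0.5, params) directly. This pattern is typical in optimizers, for example when scaling a gradient pytree by a step size before adding it to the parameters.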