jaxopt.tree_util.tree_add_scalar_mul

jaxopt.tree_util.tree_add_scalar_mul(tree_x, scalar, tree_y)

Compute tree_x + scalar * tree_y leaf by leaf, where tree_x and tree_y are pytrees with the same structure.
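The semantics can be illustrated by mapping x + scalar * y over matching leaves of the two trees. The following is a minimal sketch using jax.tree_util.tree_map, assuming both trees share the same structure. The function name tree_add_scalar_mul_sketch is hypothetical, and this is not jaxopt's own source.

```python
import jax
import jax.numpy as jnp


def tree_add_scalar_mul_sketch(tree_x, scalar, tree_y):
    """Hypothetical illustration: return tree_x + scalar * tree_y leaf-wise.

    Assumes tree_x and tree_y have identical pytree structure;
    tree_map raises an error if they do not.
    """
    return jax.tree_util.tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)


# Example pytrees: a dict of arrays, as commonly used for model parameters.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(1.0)}

# A gradient-descent-style step: params + (-0.1) * grads.
updated = tree_add_scalar_mul_sketch(params, -0.1, grads)
print(updated)
```

Passing a negative scalar expresses a gradient step of the form params - learning_rate * grads without first building a scaled copy of the gradient tree.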