jaxopt.tree_util.tree_sum

jaxopt.tree_util.tree_sum(tree_x)

Compute sum(tree_x): the sum of all entries across all leaves of the pytree tree_x, returned as a scalar.
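The following is a minimal sketch of the computation this description suggests, written with `jax.tree_util` only. It assumes that `tree_sum` sums every entry of every leaf and then adds the per-leaf results. The helper name `tree_sum_sketch` and the `params` pytree are hypothetical, introduced here for illustration, and this is not necessarily how the library implements it.

```python
import operator

import jax
import jax.numpy as jnp


def tree_sum_sketch(tree_x):
    # Hypothetical helper (not the library's code): reduce each leaf to a
    # scalar with jnp.sum, then add the per-leaf scalars together.
    leaf_sums = jax.tree_util.tree_map(jnp.sum, tree_x)
    return jax.tree_util.tree_reduce(operator.add, leaf_sums)


# Illustrative pytree: a dict holding a 2x3 array of ones and a length-2 vector.
params = {"w": jnp.ones((2, 3)), "b": jnp.array([1.0, 2.0])}

# Six ones from "w" plus 1.0 + 2.0 from "b".
total = tree_sum_sketch(params)
```

Under the same assumption, `jaxopt.tree_util.tree_sum(params)` would be expected to return the same scalar.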