# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tree utilities."""
import functools
import itertools
import operator
import jax
from jax import tree_util as tu
import jax.numpy as jnp
import numpy as onp
# Re-export the core `jax.tree_util` primitives under module-level names so
# that the helpers below (and callers of this module) can use them directly.
tree_flatten = tu.tree_flatten
tree_leaves = tu.tree_leaves
tree_map = tu.tree_map
tree_reduce = tu.tree_reduce
tree_unflatten = tu.tree_unflatten
def broadcast_pytrees(*trees):
  """Expands single-leaf arguments to the structure shared by the others.

  Args:
    *trees: pytrees. Every argument that is *not* a leaf pytree (i.e. a single
      array) must have the same treedef.

  Returns:
    A tuple with one pytree per input, all sharing the common treedef. Inputs
    that were a single leaf are replaced by a pytree in which that leaf is
    repeated at every position. If every input is a single leaf, `trees` is
    returned unchanged.

  Raises:
    ValueError: If two or more non-leaf pytrees in `*trees` differ in their
      structure (treedef).
  """
  flat_trees = []
  leaf_flags = []
  common_def = None
  for tree in trees:
    flat, tree_def = tu.tree_flatten(tree)
    single_leaf = tu.treedef_is_leaf(tree_def)
    if not single_leaf:
      # The first structured argument fixes the reference treedef; every
      # later structured argument must match it.
      common_def = common_def or tree_def
      if tree_def != common_def:
        raise ValueError('Pytrees are not broadcastable.: '
                         f'{common_def} != {tree_def}')
    flat_trees.append(flat)
    leaf_flags.append(single_leaf)

  if common_def is None:
    # All Pytrees are leaves.
    return trees

  # Number of copies needed to fill every leaf slot of the common treedef.
  width = max(len(flat) for flat in flat_trees)
  return tuple(
      common_def.unflatten(
          itertools.repeat(flat[0], width) if single_leaf else flat)
      for flat, single_leaf in zip(flat_trees, leaf_flags))
def _leafwise_binary(op, doc):
  """Builds a `tree_map` partial applying `op` leaf by leaf, documented by `doc`."""
  mapped = functools.partial(tree_map, op)
  mapped.__doc__ = doc
  return mapped


# Leaf-wise arithmetic between pytrees with matching structure.
tree_add = _leafwise_binary(operator.add, "Tree addition.")
tree_sub = _leafwise_binary(operator.sub, "Tree subtraction.")
tree_mul = _leafwise_binary(operator.mul, "Tree multiplication.")
tree_div = _leafwise_binary(operator.truediv, "Tree division.")
def tree_scalar_mul(scalar, tree_x):
  """Compute scalar * tree_x.

  Args:
    scalar: left-hand factor applied to every leaf.
    tree_x: pytree whose leaves are scaled.

  Returns:
    A pytree with the structure of `tree_x` whose leaves are `scalar * leaf`.
  """
  def scale(leaf):
    return scalar * leaf

  return tree_map(scale, tree_x)
def tree_add_scalar_mul(tree_x, scalar, tree_y):
  """Compute tree_x + scalar * tree_y.

  Args:
    tree_x: pytree of addends.
    scalar: factor applied to every leaf of `tree_y`.
    tree_y: pytree with the same structure as `tree_x`.

  Returns:
    A pytree whose leaves are `x + scalar * y` for matching leaves `x`, `y`.
  """
  def axpy(x_leaf, y_leaf):
    return x_leaf + scalar * y_leaf

  return tree_map(axpy, tree_x, tree_y)
# Leaf-level `jnp.vdot` pinned to the highest available matmul precision.
_vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)


def _vdot_safe(a, b):
  """Returns `_vdot(a, b)` after converting both operands with `jnp.asarray`."""
  return _vdot(jnp.asarray(a), jnp.asarray(b))


def tree_vdot(tree_x, tree_y):
  """Compute the inner product <tree_x, tree_y>.

  Args:
    tree_x: pytree of arrays or scalars.
    tree_y: pytree with the same structure as `tree_x`.

  Returns:
    The sum over matching leaves of `jnp.vdot(x, y)`, or 0 when the trees
    contain no leaves.
  """
  vdots = tree_map(_vdot_safe, tree_x, tree_y)
  # Seed the reduction with 0: without an initializer, reducing a tree that
  # has no leaves raises TypeError, whereas `tree_vdot_real` returns 0 there.
  return tree_reduce(operator.add, vdots, initializer=0)
def _vdot_real(x, y):
  """Real part of the dot product `_vdot(x, y)`.

  The result is real valued even when the inputs are complex, so the
  real-imaginary cross-terms of the product are neglected.
  """
  # An alternative is _vdot(x.real, y.real), plus _vdot(x.imag, y.imag) when
  # both inputs are complex. NOTE: without jit the single call below is
  # faster than that variant; with jit there is no difference.
  inner = _vdot(x, y)
  return inner.real
def tree_vdot_real(tree_x, tree_y):
  """Compute the real part of the inner product <tree_x, tree_y>.

  Args:
    tree_x: pytree of arrays.
    tree_y: pytree with the same structure as `tree_x`.

  Returns:
    The sum over matching leaves of `_vdot_real(x, y)`.
  """
  leafwise = tree_map(_vdot_real, tree_x, tree_y)
  return sum(tree_leaves(leafwise))
def tree_dot(tree_x, tree_y):
  """Compute leaves-wise dot product between pytree of arrays.

  Useful to store block diagonal linear operators: each leaf of the tree
  corresponds to a block.

  Args:
    tree_x: pytree of arrays.
    tree_y: pytree with the same structure as `tree_x`.

  Returns:
    A pytree whose leaves are `jnp.dot(x, y)` for matching leaves `x`, `y`.
  """
  blockwise = tree_map(jnp.dot, tree_x, tree_y)
  return blockwise
def tree_sum(tree_x):
  """Compute sum(tree_x).

  Args:
    tree_x: pytree of arrays or scalars.

  Returns:
    The sum of all entries of all leaves, or 0 when the tree has no leaves.
  """
  sums = tree_map(jnp.sum, tree_x)
  # Seed the reduction with 0 so that a tree without leaves yields 0 instead
  # of the TypeError raised when reducing an empty sequence.
  return tree_reduce(operator.add, sums, initializer=0)
def tree_l2_norm(tree_x, squared=False):
  """Compute the l2 norm ||tree_x||.

  Args:
    tree_x: pytree whose leaves expose `.real` and `.imag`.
    squared: if True, return the squared norm instead of the norm.

  Returns:
    The sum over all entries of `real**2 + imag**2`, square-rooted unless
    `squared` is True.
  """
  def squared_modulus(leaf):
    return jnp.square(leaf.real) + jnp.square(leaf.imag)

  sqnorm = tree_sum(tree_map(squared_modulus, tree_x))
  return sqnorm if squared else jnp.sqrt(sqnorm)
def tree_zeros_like(tree_x):
  """Creates an all-zero tree with the same structure as tree_x.

  Args:
    tree_x: pytree of arrays.

  Returns:
    A pytree in which every leaf is replaced by `jnp.zeros_like(leaf)`.
  """
  zeros = tree_map(jnp.zeros_like, tree_x)
  return zeros
def tree_ones_like(tree_x):
  """Creates an all-ones tree with the same structure as tree_x.

  Args:
    tree_x: pytree of arrays.

  Returns:
    A pytree in which every leaf is replaced by `jnp.ones_like(leaf)`.
  """
  ones = tree_map(jnp.ones_like, tree_x)
  return ones
def tree_average(trees, weights):
  """Return the linear combination of a batch of trees.

  The batch is stored as a single pytree whose leaves carry the batch index
  on their leading axis.

  Args:
    trees: tree of arrays with shape (m,...)
    weights: array of shape (m,)

  Returns:
    a single tree in which each leaf is the contraction of `weights` with the
    leading axis of the corresponding leaf of `trees`
  """
  def combine(stacked_leaf):
    return jnp.tensordot(weights, stacked_leaf, axes=1)

  return tree_map(combine, trees)
def tree_gram(a):
  """Compute the Gram matrix of a pytree holding a batch of vectors.

  Args:
    a: pytree of arrays of shape (m,...)

  Returns:
    array of shape (m,m) of all dot products; entry [i, j] is `tree_vdot`
    applied to the j-th and the i-th slices of `a`, in that order.
  """
  # Inner vmap sweeps the first argument, outer vmap sweeps the second one.
  over_first = jax.vmap(tree_vdot, in_axes=(0, None))
  pairwise = jax.vmap(over_first, in_axes=(None, 0))
  return pairwise(a, a)
def tree_inf_norm(tree_x):
  """Computes the infinity norm of a pytree.

  Args:
    tree_x: pytree of arrays.

  Returns:
    The largest absolute value found among all entries of all leaves.
  """
  flat_leaves = [jnp.ravel(leaf) for leaf in tree_leaves(tree_x)]
  stacked = jnp.concatenate(flat_leaves)
  return jnp.max(jnp.abs(stacked))
def tree_where(cond, a, b):
  """jnp.where for trees.

  Mimics the broadcasting semantics of jnp.where: each of cond, a and b can
  be either a pytree or a single array (including scalars) broadcastable to
  the leaves of the other arguments.

  Args:
    cond: pytree of boolean arrays, or single array broadcastable to the
      shapes of leaves of `a` and `b`.
    a: pytree of arrays, or single array broadcastable to the shapes of leaves
      of `cond` and `b`.
    b: pytree of arrays, or single array broadcastable to the shapes of leaves
      of `cond` and `a`.

  Returns:
    pytree of arrays, or single array
  """
  # Bring the three arguments to a common tree structure first, then select
  # leaf by leaf.
  aligned = broadcast_pytrees(cond, a, b)
  return tree_map(jnp.where, *aligned)
def tree_negative(tree):
  """Computes elementwise negation -x.

  Args:
    tree: pytree of arrays or scalars.

  Returns:
    A pytree with the same structure whose leaves are `-1 * leaf`.
  """
  return tree_map(lambda leaf: -1 * leaf, tree)
def tree_reciproqual(tree):
  """Computes elementwise inverse 1/x.

  Args:
    tree: pytree of arrays.

  Returns:
    A pytree with the same structure whose leaves are `jnp.reciprocal(leaf)`.
  """
  return tree_map(jnp.reciprocal, tree)
def tree_mean(tree):
  """Mean reduction for trees.

  Each leaf is first reduced with `jnp.mean`; the result is the plain average
  of those per-leaf means, so every leaf has the same weight whatever its
  size.

  Args:
    tree: pytree of arrays.

  Returns:
    The sum of the per-leaf means divided by the number of leaves.
  """
  per_leaf = tree_map(jnp.mean, tree)
  num_leaves = len(tree_leaves(per_leaf))
  return tree_sum(per_leaf) / num_leaves
def tree_single_dtype(tree, convert_in_jax_dtype=True):
  """The dtype for all values in a tree, provided that all leaves share the same type.

  Only leaves that are `bool`, `int`, `float`, `complex`, `onp.ndarray` or
  `jnp.ndarray` instances are inspected; any other leaf is ignored.

  Args:
    tree: tree to get the dtype of
    convert_in_jax_dtype: whether to convert the types in JAX precision.
      If True, dtypes are read after `jnp.asarray`; namely, a numpy int64 type
      is converted in a jax.numpy int32 type by default unless one enables
      double precision using jax.config.update("jax_enable_x64", True).
      If False, dtypes are read after `onp.asarray`.

  Returns:
    dtype shared by all inspected leaves of the tree, or None if no leaf was
    inspected

  Raises:
    ValueError: if the inspected leaves have more than one dtype.
  """
  to_array = jnp.asarray if convert_in_jax_dtype else onp.asarray
  inspected_kinds = (bool, int, float, complex, onp.ndarray, jnp.ndarray)
  found = {
      to_array(leaf).dtype
      for leaf in tu.tree_leaves(tree)
      if isinstance(leaf, inspected_kinds)
  }
  if not found:
    return None
  if len(found) > 1:
    raise ValueError("Found more than one dtype in the tree.")
  return found.pop()
def get_real_dtype(dtype):
  """Dtype corresponding to the real part of a complex dtype.

  Args:
    dtype: a dtype, or anything `onp.dtype` can interpret as one (dtype name,
      scalar type, ...).

  Returns:
    The dtype of the real part when `dtype` is a complex floating dtype;
    otherwise `dtype` itself, unchanged.
  """
  try:
    # Ask numpy whether the dtype is complex instead of comparing against a
    # hand-written list of names, so that dtype names, scalar types and every
    # complex width are recognised.
    is_complex = onp.issubdtype(dtype, onp.complexfloating)
  except TypeError:
    # Not interpretable as a dtype: hand it back untouched, as for any
    # non-complex input.
    return dtype
  if not is_complex:
    return dtype
  return onp.dtype(dtype).type(0).real.dtype
def tree_conj(tree):
  """Complex conjugate of a tree.

  Args:
    tree: pytree of arrays.

  Returns:
    A pytree with the same structure whose leaves are `jnp.conj(leaf)`.
  """
  conjugated = tree_map(jnp.conj, tree)
  return conjugated
def tree_real(tree):
  """Real part of a tree.

  Args:
    tree: pytree of arrays.

  Returns:
    A pytree with the same structure whose leaves are `jnp.real(leaf)`.
  """
  real_parts = tree_map(jnp.real, tree)
  return real_parts
def tree_imag(tree):
  """Imaginary part of a tree.

  Args:
    tree: pytree of arrays.

  Returns:
    A pytree with the same structure whose leaves are `jnp.imag(leaf)`.
  """
  imag_parts = tree_map(jnp.imag, tree)
  return imag_parts