Deprecate jax.tree_util.tree_multimap

This commit is contained in:
Jake VanderPlas 2022-04-01 14:51:54 -07:00
parent 1c3edc811d
commit df1ceaeeb1
14 changed files with 60 additions and 54 deletions

View File

@ -20,6 +20,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`#10069`)
* Deprecations:
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
* {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`).
## jaxlib 0.3.3 (Unreleased)

View File

@ -43,9 +43,9 @@ from jax import core
from jax import linear_util as lu
from jax import stages
from jax.core import eval_jaxpr
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
from jax.tree_util import (tree_map, tree_multimap, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
tree_multimap, treedef_is_leaf, treedef_children,
tree_map, treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)
from jax._src import device_array
@ -2731,7 +2731,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
with config_explicit_device_put_scope():
return tree_multimap(_device_put_sharded, *shards)
return tree_map(_device_put_sharded, *shards)
def device_put_replicated(x: Any, devices: Sequence[xc.Device]):

View File

@ -21,7 +21,7 @@ import numpy as np
from jax import core
from jax._src import dtypes
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_multimap, tree_structure,
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
treedef_children, treedef_is_leaf)
from jax._src.tree_util import _replace_nones
from jax import linear_util as lu
@ -286,7 +286,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
axes = []
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
try:
tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
except ValueError:
if kws:
# if keyword arguments are included in the tree, we make adapt the error

View File

@ -22,9 +22,8 @@ from typing import (Callable, Generic, Optional, Sequence, Tuple, List, TypeVar,
from jax import core
from jax import linear_util as lu
from jax.custom_transpose import custom_transpose
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
tree_multimap, treedef_is_leaf, treedef_tuple,
register_pytree_node_class)
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, treedef_is_leaf,
treedef_tuple, register_pytree_node_class)
from jax._src import custom_api_util
from jax._src import dtypes
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
@ -198,7 +197,7 @@ class custom_jvp(Generic[ReturnValue]):
zeros = _zeros_like_pytree(primal_out)
all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros
for t, jvp in zip(tangents, jvps)]
tangent_out = tree_multimap(_sum_tangents, primal_out, *all_tangents_out)
tangent_out = tree_map(_sum_tangents, primal_out, *all_tangents_out)
return primal_out, tangent_out
self.defjvp(jvp)
@ -617,8 +616,8 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
try:
if not isinstance(py_cts_in, tuple):
raise ValueError
tree_multimap(append_cts,
tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
tree_map(append_cts,
tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
except ValueError:
_, in_tree2 = tree_flatten(py_cts_in)
msg = ("Custom VJP rule must produce an output with the same container "

View File

@ -54,7 +54,7 @@ from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
split_list, cache, extend_name_stack, wrap_name)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_multimap,
treedef_children, treedef_tuple, tree_map,
tree_leaves, tree_structure)
from jax._src import ad_util
from jax.config import config
@ -1477,7 +1477,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
ys.append(y)
stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
else jax.numpy.stack((y, *ys)))
stacked_y = tree_multimap(stack, *maybe_reversed(ys))
stacked_y = tree_map(stack, *maybe_reversed(ys))
return carry, stacked_y
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
@ -2196,8 +2196,8 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
raise TypeError(
f"{what} must have same type structure, got {tree1} and {tree2}.")
if not all(_map(core.typematch, avals1, avals2)):
diff = tree_multimap(_show_diff, tree_unflatten(tree1, avals1),
tree_unflatten(tree2, avals2))
diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
tree_unflatten(tree2, avals2))
raise TypeError(f"{what} must have identical types, got\n{diff}.")

View File

@ -21,7 +21,7 @@ import jax.numpy as jnp
from jax import device_put
from jax import lax
from jax import scipy as jsp
from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure,
from jax.tree_util import (tree_leaves, tree_map, tree_structure,
tree_reduce, Partial)
from jax._src import dtypes
@ -52,11 +52,11 @@ def _vdot_real_part(x, y):
def _vdot_real_tree(x, y):
return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
return sum(tree_leaves(tree_map(_vdot_real_part, x, y)))
def _vdot_tree(x, y):
return sum(tree_leaves(tree_multimap(partial(
return sum(tree_leaves(tree_map(partial(
jnp.vdot, precision=lax.Precision.HIGHEST), x, y)))
@ -73,9 +73,9 @@ def _div(tree, scalar):
return tree_map(partial(lambda v: v / scalar), tree)
_add = partial(tree_multimap, operator.add)
_sub = partial(tree_multimap, operator.sub)
_dot_tree = partial(tree_multimap, _dot)
_add = partial(tree_map, operator.add)
_sub = partial(tree_map, operator.sub)
_dot_tree = partial(tree_map, _dot)
@Partial
@ -162,12 +162,12 @@ def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
shat = M(s)
t = A(shat)
omega_ = _vdot_tree(t, s) / _vdot_tree(t, t) # make cases?
x_ = tree_multimap(partial(jnp.where, exit_early),
_add(x, _mul(alpha_, phat)),
_add(x, _add(_mul(alpha_, phat), _mul(omega_, shat)))
)
r_ = tree_multimap(partial(jnp.where, exit_early),
s, _sub(s, _mul(omega_, t)))
x_ = tree_map(partial(jnp.where, exit_early),
_add(x, _mul(alpha_, phat)),
_add(x, _add(_mul(alpha_, phat), _mul(omega_, shat)))
)
r_ = tree_map(partial(jnp.where, exit_early),
s, _sub(s, _mul(omega_, t)))
k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
k_ = jnp.where((rho_ == 0), -10, k_)
return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_
@ -308,7 +308,7 @@ def _project_on_columns(A, v):
"""
Returns A.T.conj() @ v.
"""
v_proj = tree_multimap(
v_proj = tree_map(
lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v,
)
return tree_reduce(operator.add, v_proj)
@ -400,7 +400,7 @@ def _kth_arnoldi_iteration(k, A, M, V, H):
tol = eps * v_norm_0
unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
V = tree_map(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
h = h.at[k + 1].set(v_norm_1)
H = H.at[k, :].set(h)

View File

@ -37,7 +37,7 @@ from jax._src import dtypes as _dtypes
from jax import lax
from jax._src.config import flags, bool_env, config
from jax._src.util import prod, unzip2
from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from jax.tree_util import tree_map, tree_all, tree_reduce
from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax.interpreters import mlir
@ -222,12 +222,12 @@ def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''):
def check_eq(xs, ys, err_msg=''):
assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
tree_all(tree_multimap(assert_close, xs, ys))
tree_all(tree_map(assert_close, xs, ys))
def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol,
err_msg=err_msg)
tree_all(tree_multimap(assert_close, xs, ys))
tree_all(tree_map(assert_close, xs, ys))
def _check_dtypes_match(xs, ys):
def _assert_dtypes_match(x, y):
@ -236,13 +236,13 @@ def _check_dtypes_match(xs, ys):
else:
assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
_dtypes.canonicalize_dtype(_dtype(y)))
tree_all(tree_multimap(_assert_dtypes_match, xs, ys))
tree_all(tree_map(_assert_dtypes_match, xs, ys))
def inner_prod(xs, ys):
def contract(x, y):
return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
return tree_reduce(np.add, tree_multimap(contract, xs, ys))
return tree_reduce(np.add, tree_map(contract, xs, ys))
def _safe_subtract(x, y, *, dtype):
@ -251,9 +251,9 @@ def _safe_subtract(x, y, *, dtype):
return np.where(np.equal(x, y), np.array(0, dtype),
np.subtract(x, y, dtype=dtype))
add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x)))
sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
safe_sub = partial(tree_multimap,
add = partial(tree_map, lambda x, y: np.add(x, y, dtype=_dtype(x)))
sub = partial(tree_map, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
safe_sub = partial(tree_map,
lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))

View File

@ -20,6 +20,7 @@ import operator as op
from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, List,
Dict, Type, TypeVar, overload, TYPE_CHECKING, NamedTuple)
import textwrap
import warnings
from jax._src.lib import pytree
@ -179,7 +180,11 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
tree_multimap = tree_map
def tree_multimap(*args, **kwargs):
"""Deprecated alias of :func:`jax.tree_util.tree_map`"""
warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
'instead as a drop-in replacement.', FutureWarning)
return tree_map(*args, **kwargs)
# TODO(mattjj,phawkins): consider removing this function
def _process_pytree(process_node, tree):

View File

@ -421,7 +421,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
for dim_spec in in_spec),
dtype=tf.float32)
return tree_util.tree_multimap(polymorphic_shape_to_tensorspec, polymorphic_shapes)
return tree_util.tree_map(polymorphic_shape_to_tensorspec, polymorphic_shapes)
def CountLargeTfConstants(self, tf_fun: Callable, *args,
at_least=256):

View File

@ -47,7 +47,8 @@ from jax._src.tree_util import (
tree_flatten as tree_flatten,
tree_leaves as tree_leaves,
tree_map as tree_map,
tree_multimap as tree_multimap,
# TODO(jakevdp) remove tree_multimap once deprecation is complete.
tree_multimap,
tree_reduce as tree_reduce,
tree_structure as tree_structure,
tree_transpose as tree_transpose,

View File

@ -30,7 +30,7 @@ from jax import numpy as jnp
from jax import linear_util as lu
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce, tree_leaves
from jax.interpreters import partial_eval as pe
from jax._src import test_util as jtu
@ -147,16 +147,16 @@ def fwd_deriv(f):
class CoreTest(jtu.JaxTestCase):
def test_tree_multimap(self):
def test_tree_map(self):
xs = ({'a': 1}, [2, 3])
ys = ({'a': 10}, [20, 30])
ys_bad = ({'a': 10, 'b': 10}, [20, 30])
zs = ({'a': 11}, [22, 33])
f = lambda x, y: x + y
assert tree_multimap(f, xs, ys) == zs
assert tree_map(f, xs, ys) == zs
try:
tree_multimap(f, xs, ys_bad)
tree_map(f, xs, ys_bad)
assert False
except (TypeError, ValueError):
pass
@ -170,7 +170,7 @@ class CoreTest(jtu.JaxTestCase):
flat, treedef = tree_flatten(tree)
assert flat == [1, 2, 3, 4, 5]
tree2 = tree_unflatten(treedef, flat)
nodes_equal = tree_multimap(operator.eq, tree, tree2)
nodes_equal = tree_map(operator.eq, tree, tree2)
assert tree_reduce(operator.and_, nodes_equal)
@parameterized.named_parameters(

View File

@ -2570,7 +2570,7 @@ class LaxTest(jtu.JaxTestCase):
operands = {'x': [np.ones(5), np.arange(5)]}
init_values = {'x': [0., 0]}
result = lax.reduce(operands, init_values,
lambda x, y: tree_util.tree_multimap(lax.add, x, y),
lambda x, y: tree_util.tree_map(lax.add, x, y),
[0])
self.assertDictEqual(result, {'x': [5., 10.]})

View File

@ -354,7 +354,7 @@ class PythonPmapTest(jtu.JaxTestCase):
'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]),
'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])}
assert_allclose = partial(tree_util.tree_multimap,
assert_allclose = partial(tree_util.tree_map,
partial(self.assertAllClose, check_dtypes=False))
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
@ -372,7 +372,7 @@ class PythonPmapTest(jtu.JaxTestCase):
tree_f = lambda f: partial(tree_util.tree_map, f)
jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
assert_allclose = partial(tree_util.tree_multimap,
assert_allclose = partial(tree_util.tree_map,
partial(self.assertAllClose, check_dtypes=False))
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
@ -2381,7 +2381,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
return {'a': x}
device_count = jax.device_count()
x = jnp.arange(device_count)
tree_util.tree_multimap(self.assertAllClose, f(x), {'a': x})
tree_util.tree_map(self.assertAllClose, f(x), {'a': x})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{in_axes}_{out_axes}",

View File

@ -242,18 +242,18 @@ class TreeTest(jtu.JaxTestCase):
}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])
def testTreeMultimap(self):
def testTreeMap(self):
x = ((1, 2), (3, 4, 5))
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y)
self.assertEqual(out, (((1, [3]), (2, None)),
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
def testTreeMultimapWithIsLeafArgument(self):
def testTreeMapWithIsLeafArgument(self):
x = ((1, 2), [3, 4, 5])
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y,
is_leaf=lambda n: isinstance(n, list))
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y,
is_leaf=lambda n: isinstance(n, list))
self.assertEqual(out, (((1, [3]), (2, None)),
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))