From df1ceaeeb11efc7c5af1ad2dd102857128c23b26 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 1 Apr 2022 14:51:54 -0700 Subject: [PATCH] Deprecate jax.tree_util.tree_multimap --- CHANGELOG.md | 1 + jax/_src/api.py | 6 ++-- jax/_src/api_util.py | 4 +-- jax/_src/custom_derivatives.py | 11 ++++---- jax/_src/lax/control_flow.py | 8 +++--- jax/_src/scipy/sparse/linalg.py | 28 +++++++++---------- jax/_src/test_util.py | 16 +++++------ jax/_src/tree_util.py | 7 ++++- jax/experimental/jax2tf/tests/tf_test_util.py | 2 +- jax/tree_util.py | 3 +- tests/core_test.py | 10 +++---- tests/lax_test.py | 2 +- tests/pmap_test.py | 6 ++-- tests/tree_util_test.py | 10 +++---- 14 files changed, 60 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95923e370..08c13d082 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`#10069`) * Deprecations: * {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`). + * {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`). ## jaxlib 0.3.3 (Unreleased) diff --git a/jax/_src/api.py b/jax/_src/api.py index 5846a1048..8361d753d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -43,9 +43,9 @@ from jax import core from jax import linear_util as lu from jax import stages from jax.core import eval_jaxpr -from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, +from jax.tree_util import (tree_map, tree_multimap, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, - tree_multimap, treedef_is_leaf, treedef_children, + tree_map, treedef_is_leaf, treedef_children, Partial, PyTreeDef, all_leaves, treedef_tuple) from jax._src import device_array @@ -2731,7 +2731,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): return pxla.make_sharded_device_array(stacked_aval, None, buffers) with config_explicit_device_put_scope(): - return tree_multimap(_device_put_sharded, *shards) + return tree_map(_device_put_sharded, *shards) def device_put_replicated(x: Any, devices: Sequence[xc.Device]): diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index dee77a7e7..7451c014a 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -21,7 +21,7 @@ import numpy as np from jax import core from jax._src import dtypes from jax._src.tree_util import ( - PyTreeDef, tree_flatten, tree_unflatten, tree_multimap, tree_structure, + PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure, treedef_children, treedef_is_leaf) from jax._src.tree_util import _replace_nones from jax import linear_util as lu @@ -286,7 +286,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): axes = [] add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0])) try: - tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy) + tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy) except ValueError: if kws: # if keyword arguments are included in the tree, we make adapt the error diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index a8d3f9f87..de54d3da1 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -22,9 +22,8 @@ from typing import (Callable, Generic, Optional, Sequence, Tuple, List, TypeVar, from jax import core from jax import linear_util as lu from jax.custom_transpose import custom_transpose -from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, - tree_multimap, treedef_is_leaf, treedef_tuple, - register_pytree_node_class) +from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, + treedef_tuple, register_pytree_node_class) from jax._src import custom_api_util from jax._src import dtypes from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable @@ -198,7 +197,7 @@ class custom_jvp(Generic[ReturnValue]): zeros = _zeros_like_pytree(primal_out) all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros for t, jvp in zip(tangents, jvps)] - tangent_out = tree_multimap(_sum_tangents, primal_out, *all_tangents_out) + tangent_out = tree_map(_sum_tangents, primal_out, *all_tangents_out) return primal_out, tangent_out self.defjvp(jvp) @@ -617,8 +616,8 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args): try: if not isinstance(py_cts_in, tuple): raise ValueError - tree_multimap(append_cts, - tuple(zero if ct is None else ct for ct in py_cts_in), dummy) + tree_map(append_cts, + tuple(zero if ct is None else ct for ct in py_cts_in), dummy) except ValueError: _, in_tree2 = tree_flatten(py_cts_in) msg = ("Custom VJP rule must produce an output with the same container " diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 4b0858c63..c1dd3a4f4 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -54,7 +54,7 @@ from jax._src.traceback_util import api_boundary from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, cache, extend_name_stack, wrap_name) from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, - treedef_children, treedef_tuple, tree_multimap, + treedef_children, treedef_tuple, tree_map, tree_leaves, tree_structure) from jax._src import ad_util from jax.config import config @@ -1477,7 +1477,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], ys.append(y) stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit else jax.numpy.stack((y, *ys))) - stacked_y = tree_multimap(stack, *maybe_reversed(ys)) + stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat] @@ -2196,8 +2196,8 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2): raise TypeError( f"{what} must have same type structure, got {tree1} and {tree2}.") if not all(_map(core.typematch, avals1, avals2)): - diff = tree_multimap(_show_diff, tree_unflatten(tree1, avals1), - tree_unflatten(tree2, avals2)) + diff = tree_map(_show_diff, tree_unflatten(tree1, avals1), + tree_unflatten(tree2, avals2)) raise TypeError(f"{what} must have identical types, got\n{diff}.") diff --git a/jax/_src/scipy/sparse/linalg.py b/jax/_src/scipy/sparse/linalg.py index 8d299030a..3dab02f94 100644 --- a/jax/_src/scipy/sparse/linalg.py +++ b/jax/_src/scipy/sparse/linalg.py @@ -21,7 +21,7 @@ import jax.numpy as jnp from jax import device_put from jax import lax from jax import scipy as jsp -from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure, +from jax.tree_util import (tree_leaves, tree_map, tree_structure, tree_reduce, Partial) from jax._src import dtypes @@ -52,11 +52,11 @@ def _vdot_real_part(x, y): def _vdot_real_tree(x, y): - return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y))) + return sum(tree_leaves(tree_map(_vdot_real_part, x, y))) def _vdot_tree(x, y): - return sum(tree_leaves(tree_multimap(partial( + return sum(tree_leaves(tree_map(partial( jnp.vdot, precision=lax.Precision.HIGHEST), x, y))) @@ -73,9 +73,9 @@ def _div(tree, scalar): return tree_map(partial(lambda v: v / scalar), tree) -_add = partial(tree_multimap, operator.add) -_sub = partial(tree_multimap, operator.sub) -_dot_tree = partial(tree_multimap, _dot) +_add = partial(tree_map, operator.add) +_sub = partial(tree_map, operator.sub) +_dot_tree = partial(tree_map, _dot) @Partial @@ -162,12 +162,12 @@ def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity): shat = M(s) t = A(shat) omega_ = _vdot_tree(t, s) / _vdot_tree(t, t) # make cases? - x_ = tree_multimap(partial(jnp.where, exit_early), - _add(x, _mul(alpha_, phat)), - _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat))) - ) - r_ = tree_multimap(partial(jnp.where, exit_early), - s, _sub(s, _mul(omega_, t))) + x_ = tree_map(partial(jnp.where, exit_early), + _add(x, _mul(alpha_, phat)), + _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat))) + ) + r_ = tree_map(partial(jnp.where, exit_early), + s, _sub(s, _mul(omega_, t))) k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1) k_ = jnp.where((rho_ == 0), -10, k_) return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_ @@ -308,7 +308,7 @@ def _project_on_columns(A, v): """ Returns A.T.conj() @ v. """ - v_proj = tree_multimap( + v_proj = tree_map( lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v, ) return tree_reduce(operator.add, v_proj) @@ -400,7 +400,7 @@ def _kth_arnoldi_iteration(k, A, M, V, H): tol = eps * v_norm_0 unit_v, v_norm_1 = _safe_normalize(v, thresh=tol) - V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v) + V = tree_map(lambda X, y: X.at[..., k + 1].set(y), V, unit_v) h = h.at[k + 1].set(v_norm_1) H = H.at[k, :].set(h) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 0953299bf..799395a2a 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -37,7 +37,7 @@ from jax._src import dtypes as _dtypes from jax import lax from jax._src.config import flags, bool_env, config from jax._src.util import prod, unzip2 -from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce +from jax.tree_util import tree_map, tree_all, tree_reduce from jax._src.lib import xla_bridge from jax._src import dispatch from jax.interpreters import mlir @@ -222,12 +222,12 @@ def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''): def check_eq(xs, ys, err_msg=''): assert_close = partial(_assert_numpy_allclose, err_msg=err_msg) - tree_all(tree_multimap(assert_close, xs, ys)) + tree_all(tree_map(assert_close, xs, ys)) def check_close(xs, ys, atol=None, rtol=None, err_msg=''): assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol, err_msg=err_msg) - tree_all(tree_multimap(assert_close, xs, ys)) + tree_all(tree_map(assert_close, xs, ys)) def _check_dtypes_match(xs, ys): def _assert_dtypes_match(x, y): @@ -236,13 +236,13 @@ def _check_dtypes_match(xs, ys): else: assert (_dtypes.canonicalize_dtype(_dtype(x)) == _dtypes.canonicalize_dtype(_dtype(y))) - tree_all(tree_multimap(_assert_dtypes_match, xs, ys)) + tree_all(tree_map(_assert_dtypes_match, xs, ys)) def inner_prod(xs, ys): def contract(x, y): return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1))) - return tree_reduce(np.add, tree_multimap(contract, xs, ys)) + return tree_reduce(np.add, tree_map(contract, xs, ys)) def _safe_subtract(x, y, *, dtype): @@ -251,9 +251,9 @@ def _safe_subtract(x, y, *, dtype): return np.where(np.equal(x, y), np.array(0, dtype), np.subtract(x, y, dtype=dtype)) -add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x))) -sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x))) -safe_sub = partial(tree_multimap, +add = partial(tree_map, lambda x, y: np.add(x, y, dtype=_dtype(x))) +sub = partial(tree_map, lambda x, y: np.subtract(x, y, dtype=_dtype(x))) +safe_sub = partial(tree_map, lambda x, y: _safe_subtract(x, y, dtype=_dtype(x))) conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x))) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index b8d67333c..57bbcda7c 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -20,6 +20,7 @@ import operator as op from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, List, Dict, Type, TypeVar, overload, TYPE_CHECKING, NamedTuple) import textwrap +import warnings from jax._src.lib import pytree @@ -179,7 +180,11 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any, all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest] return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) -tree_multimap = tree_map +def tree_multimap(*args, **kwargs): + """Deprecated alias of :func:`jax.tree_util.tree_map`""" + warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() ' + 'instead as a drop-in replacement.', FutureWarning) + return tree_map(*args, **kwargs) # TODO(mattjj,phawkins): consider removing this function def _process_pytree(process_node, tree): diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 2ad4f9492..eeeabbccd 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -421,7 +421,7 @@ class JaxToTfTestCase(jtu.JaxTestCase): for dim_spec in in_spec), dtype=tf.float32) - return tree_util.tree_multimap(polymorphic_shape_to_tensorspec, polymorphic_shapes) + return tree_util.tree_map(polymorphic_shape_to_tensorspec, polymorphic_shapes) def CountLargeTfConstants(self, tf_fun: Callable, *args, at_least=256): diff --git a/jax/tree_util.py b/jax/tree_util.py index 49fdb3d3f..023d7dc17 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -47,7 +47,8 @@ from jax._src.tree_util import ( tree_flatten as tree_flatten, tree_leaves as tree_leaves, tree_map as tree_map, - tree_multimap as tree_multimap, + # TODO(jakevdp) remove tree_multimap once deprecation is complete. + tree_multimap, tree_reduce as tree_reduce, tree_structure as tree_structure, tree_transpose as tree_transpose, diff --git a/tests/core_test.py b/tests/core_test.py index 9abe707d5..e5487c12c 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -30,7 +30,7 @@ from jax import numpy as jnp from jax import linear_util as lu from jax import jvp, linearize, vjp, jit, make_jaxpr from jax.core import UnshapedArray, ShapedArray -from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves +from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce, tree_leaves from jax.interpreters import partial_eval as pe from jax._src import test_util as jtu @@ -147,16 +147,16 @@ def fwd_deriv(f): class CoreTest(jtu.JaxTestCase): - def test_tree_multimap(self): + def test_tree_map(self): xs = ({'a': 1}, [2, 3]) ys = ({'a': 10}, [20, 30]) ys_bad = ({'a': 10, 'b': 10}, [20, 30]) zs = ({'a': 11}, [22, 33]) f = lambda x, y: x + y - assert tree_multimap(f, xs, ys) == zs + assert tree_map(f, xs, ys) == zs try: - tree_multimap(f, xs, ys_bad) + tree_map(f, xs, ys_bad) assert False except (TypeError, ValueError): pass @@ -170,7 +170,7 @@ class CoreTest(jtu.JaxTestCase): flat, treedef = tree_flatten(tree) assert flat == [1, 2, 3, 4, 5] tree2 = tree_unflatten(treedef, flat) - nodes_equal = tree_multimap(operator.eq, tree, tree2) + nodes_equal = tree_map(operator.eq, tree, tree2) assert tree_reduce(operator.and_, nodes_equal) @parameterized.named_parameters( diff --git a/tests/lax_test.py b/tests/lax_test.py index 72fdb86c3..bc05dfefd 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2570,7 +2570,7 @@ class LaxTest(jtu.JaxTestCase): operands = {'x': [np.ones(5), np.arange(5)]} init_values = {'x': [0., 0]} result = lax.reduce(operands, init_values, - lambda x, y: tree_util.tree_multimap(lax.add, x, y), + lambda x, y: tree_util.tree_map(lax.add, x, y), [0]) self.assertDictEqual(result, {'x': [5., 10.]}) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index f3e45deeb..8d973e8fe 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -354,7 +354,7 @@ class PythonPmapTest(jtu.JaxTestCase): 'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]), 'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])} - assert_allclose = partial(tree_util.tree_multimap, + assert_allclose = partial(tree_util.tree_map, partial(self.assertAllClose, check_dtypes=False)) assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x)) assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x)) @@ -372,7 +372,7 @@ class PythonPmapTest(jtu.JaxTestCase): tree_f = lambda f: partial(tree_util.tree_map, f) jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i') np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape)) - assert_allclose = partial(tree_util.tree_multimap, + assert_allclose = partial(tree_util.tree_map, partial(self.assertAllClose, check_dtypes=False)) assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x)) assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x)) @@ -2381,7 +2381,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): return {'a': x} device_count = jax.device_count() x = jnp.arange(device_count) - tree_util.tree_multimap(self.assertAllClose, f(x), {'a': x}) + tree_util.tree_map(self.assertAllClose, f(x), {'a': x}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": f"_{in_axes}_{out_axes}", diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index fc97b6820..84e15fbb2 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -242,18 +242,18 @@ class TreeTest(jtu.JaxTestCase): }, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]) self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None]) - def testTreeMultimap(self): + def testTreeMap(self): x = ((1, 2), (3, 4, 5)) y = (([3], None), ({"foo": "bar"}, 7, [5, 6])) - out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y) + out = tree_util.tree_map(lambda *xs: tuple(xs), x, y) self.assertEqual(out, (((1, [3]), (2, None)), ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])))) - def testTreeMultimapWithIsLeafArgument(self): + def testTreeMapWithIsLeafArgument(self): x = ((1, 2), [3, 4, 5]) y = (([3], None), ({"foo": "bar"}, 7, [5, 6])) - out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y, - is_leaf=lambda n: isinstance(n, list)) + out = tree_util.tree_map(lambda *xs: tuple(xs), x, y, + is_leaf=lambda n: isinstance(n, list)) self.assertEqual(out, (((1, [3]), (2, None)), (([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))