mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate jax.tree_util.tree_multimap
This commit is contained in:
parent
1c3edc811d
commit
df1ceaeeb1
@ -20,6 +20,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`#10069`)
|
||||
* Deprecations:
|
||||
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
|
||||
* {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`).
|
||||
|
||||
## jaxlib 0.3.3 (Unreleased)
|
||||
|
||||
|
@ -43,9 +43,9 @@ from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax import stages
|
||||
from jax.core import eval_jaxpr
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
from jax.tree_util import (tree_map, tree_multimap, tree_flatten, tree_unflatten,
|
||||
tree_structure, tree_transpose, tree_leaves,
|
||||
tree_multimap, treedef_is_leaf, treedef_children,
|
||||
tree_map, treedef_is_leaf, treedef_children,
|
||||
Partial, PyTreeDef, all_leaves, treedef_tuple)
|
||||
|
||||
from jax._src import device_array
|
||||
@ -2731,7 +2731,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
|
||||
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
|
||||
|
||||
with config_explicit_device_put_scope():
|
||||
return tree_multimap(_device_put_sharded, *shards)
|
||||
return tree_map(_device_put_sharded, *shards)
|
||||
|
||||
|
||||
def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
|
||||
|
@ -21,7 +21,7 @@ import numpy as np
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.tree_util import (
|
||||
PyTreeDef, tree_flatten, tree_unflatten, tree_multimap, tree_structure,
|
||||
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
|
||||
treedef_children, treedef_is_leaf)
|
||||
from jax._src.tree_util import _replace_nones
|
||||
from jax import linear_util as lu
|
||||
@ -286,7 +286,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
|
||||
axes = []
|
||||
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
|
||||
try:
|
||||
tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
|
||||
tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
|
||||
except ValueError:
|
||||
if kws:
|
||||
# if keyword arguments are included in the tree, we make adapt the error
|
||||
|
@ -22,9 +22,8 @@ from typing import (Callable, Generic, Optional, Sequence, Tuple, List, TypeVar,
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
tree_multimap, treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class)
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, treedef_is_leaf,
|
||||
treedef_tuple, register_pytree_node_class)
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import dtypes
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||
@ -198,7 +197,7 @@ class custom_jvp(Generic[ReturnValue]):
|
||||
zeros = _zeros_like_pytree(primal_out)
|
||||
all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros
|
||||
for t, jvp in zip(tangents, jvps)]
|
||||
tangent_out = tree_multimap(_sum_tangents, primal_out, *all_tangents_out)
|
||||
tangent_out = tree_map(_sum_tangents, primal_out, *all_tangents_out)
|
||||
return primal_out, tangent_out
|
||||
|
||||
self.defjvp(jvp)
|
||||
@ -617,7 +616,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
|
||||
try:
|
||||
if not isinstance(py_cts_in, tuple):
|
||||
raise ValueError
|
||||
tree_multimap(append_cts,
|
||||
tree_map(append_cts,
|
||||
tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
|
||||
except ValueError:
|
||||
_, in_tree2 = tree_flatten(py_cts_in)
|
||||
|
@ -54,7 +54,7 @@ from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
|
||||
split_list, cache, extend_name_stack, wrap_name)
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
||||
treedef_children, treedef_tuple, tree_multimap,
|
||||
treedef_children, treedef_tuple, tree_map,
|
||||
tree_leaves, tree_structure)
|
||||
from jax._src import ad_util
|
||||
from jax.config import config
|
||||
@ -1477,7 +1477,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
ys.append(y)
|
||||
stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
|
||||
else jax.numpy.stack((y, *ys)))
|
||||
stacked_y = tree_multimap(stack, *maybe_reversed(ys))
|
||||
stacked_y = tree_map(stack, *maybe_reversed(ys))
|
||||
return carry, stacked_y
|
||||
|
||||
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
|
||||
@ -2196,7 +2196,7 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
|
||||
raise TypeError(
|
||||
f"{what} must have same type structure, got {tree1} and {tree2}.")
|
||||
if not all(_map(core.typematch, avals1, avals2)):
|
||||
diff = tree_multimap(_show_diff, tree_unflatten(tree1, avals1),
|
||||
diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
|
||||
tree_unflatten(tree2, avals2))
|
||||
raise TypeError(f"{what} must have identical types, got\n{diff}.")
|
||||
|
||||
|
@ -21,7 +21,7 @@ import jax.numpy as jnp
|
||||
from jax import device_put
|
||||
from jax import lax
|
||||
from jax import scipy as jsp
|
||||
from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure,
|
||||
from jax.tree_util import (tree_leaves, tree_map, tree_structure,
|
||||
tree_reduce, Partial)
|
||||
|
||||
from jax._src import dtypes
|
||||
@ -52,11 +52,11 @@ def _vdot_real_part(x, y):
|
||||
|
||||
|
||||
def _vdot_real_tree(x, y):
|
||||
return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
|
||||
return sum(tree_leaves(tree_map(_vdot_real_part, x, y)))
|
||||
|
||||
|
||||
def _vdot_tree(x, y):
|
||||
return sum(tree_leaves(tree_multimap(partial(
|
||||
return sum(tree_leaves(tree_map(partial(
|
||||
jnp.vdot, precision=lax.Precision.HIGHEST), x, y)))
|
||||
|
||||
|
||||
@ -73,9 +73,9 @@ def _div(tree, scalar):
|
||||
return tree_map(partial(lambda v: v / scalar), tree)
|
||||
|
||||
|
||||
_add = partial(tree_multimap, operator.add)
|
||||
_sub = partial(tree_multimap, operator.sub)
|
||||
_dot_tree = partial(tree_multimap, _dot)
|
||||
_add = partial(tree_map, operator.add)
|
||||
_sub = partial(tree_map, operator.sub)
|
||||
_dot_tree = partial(tree_map, _dot)
|
||||
|
||||
|
||||
@Partial
|
||||
@ -162,11 +162,11 @@ def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
|
||||
shat = M(s)
|
||||
t = A(shat)
|
||||
omega_ = _vdot_tree(t, s) / _vdot_tree(t, t) # make cases?
|
||||
x_ = tree_multimap(partial(jnp.where, exit_early),
|
||||
x_ = tree_map(partial(jnp.where, exit_early),
|
||||
_add(x, _mul(alpha_, phat)),
|
||||
_add(x, _add(_mul(alpha_, phat), _mul(omega_, shat)))
|
||||
)
|
||||
r_ = tree_multimap(partial(jnp.where, exit_early),
|
||||
r_ = tree_map(partial(jnp.where, exit_early),
|
||||
s, _sub(s, _mul(omega_, t)))
|
||||
k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
|
||||
k_ = jnp.where((rho_ == 0), -10, k_)
|
||||
@ -308,7 +308,7 @@ def _project_on_columns(A, v):
|
||||
"""
|
||||
Returns A.T.conj() @ v.
|
||||
"""
|
||||
v_proj = tree_multimap(
|
||||
v_proj = tree_map(
|
||||
lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v,
|
||||
)
|
||||
return tree_reduce(operator.add, v_proj)
|
||||
@ -400,7 +400,7 @@ def _kth_arnoldi_iteration(k, A, M, V, H):
|
||||
|
||||
tol = eps * v_norm_0
|
||||
unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
|
||||
V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
|
||||
V = tree_map(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
|
||||
|
||||
h = h.at[k + 1].set(v_norm_1)
|
||||
H = H.at[k, :].set(h)
|
||||
|
@ -37,7 +37,7 @@ from jax._src import dtypes as _dtypes
|
||||
from jax import lax
|
||||
from jax._src.config import flags, bool_env, config
|
||||
from jax._src.util import prod, unzip2
|
||||
from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
||||
from jax.tree_util import tree_map, tree_all, tree_reduce
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src import dispatch
|
||||
from jax.interpreters import mlir
|
||||
@ -222,12 +222,12 @@ def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''):
|
||||
|
||||
def check_eq(xs, ys, err_msg=''):
|
||||
assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
|
||||
tree_all(tree_multimap(assert_close, xs, ys))
|
||||
tree_all(tree_map(assert_close, xs, ys))
|
||||
|
||||
def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
|
||||
assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol,
|
||||
err_msg=err_msg)
|
||||
tree_all(tree_multimap(assert_close, xs, ys))
|
||||
tree_all(tree_map(assert_close, xs, ys))
|
||||
|
||||
def _check_dtypes_match(xs, ys):
|
||||
def _assert_dtypes_match(x, y):
|
||||
@ -236,13 +236,13 @@ def _check_dtypes_match(xs, ys):
|
||||
else:
|
||||
assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
|
||||
_dtypes.canonicalize_dtype(_dtype(y)))
|
||||
tree_all(tree_multimap(_assert_dtypes_match, xs, ys))
|
||||
tree_all(tree_map(_assert_dtypes_match, xs, ys))
|
||||
|
||||
|
||||
def inner_prod(xs, ys):
|
||||
def contract(x, y):
|
||||
return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
|
||||
return tree_reduce(np.add, tree_multimap(contract, xs, ys))
|
||||
return tree_reduce(np.add, tree_map(contract, xs, ys))
|
||||
|
||||
|
||||
def _safe_subtract(x, y, *, dtype):
|
||||
@ -251,9 +251,9 @@ def _safe_subtract(x, y, *, dtype):
|
||||
return np.where(np.equal(x, y), np.array(0, dtype),
|
||||
np.subtract(x, y, dtype=dtype))
|
||||
|
||||
add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x)))
|
||||
sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
|
||||
safe_sub = partial(tree_multimap,
|
||||
add = partial(tree_map, lambda x, y: np.add(x, y, dtype=_dtype(x)))
|
||||
sub = partial(tree_map, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
|
||||
safe_sub = partial(tree_map,
|
||||
lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
|
||||
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))
|
||||
|
||||
|
@ -20,6 +20,7 @@ import operator as op
|
||||
from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, List,
|
||||
Dict, Type, TypeVar, overload, TYPE_CHECKING, NamedTuple)
|
||||
import textwrap
|
||||
import warnings
|
||||
|
||||
from jax._src.lib import pytree
|
||||
|
||||
@ -179,7 +180,11 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
|
||||
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
|
||||
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
|
||||
|
||||
tree_multimap = tree_map
|
||||
def tree_multimap(*args, **kwargs):
|
||||
"""Deprecated alias of :func:`jax.tree_util.tree_map`"""
|
||||
warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
|
||||
'instead as a drop-in replacement.', FutureWarning)
|
||||
return tree_map(*args, **kwargs)
|
||||
|
||||
# TODO(mattjj,phawkins): consider removing this function
|
||||
def _process_pytree(process_node, tree):
|
||||
|
@ -421,7 +421,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
for dim_spec in in_spec),
|
||||
dtype=tf.float32)
|
||||
|
||||
return tree_util.tree_multimap(polymorphic_shape_to_tensorspec, polymorphic_shapes)
|
||||
return tree_util.tree_map(polymorphic_shape_to_tensorspec, polymorphic_shapes)
|
||||
|
||||
def CountLargeTfConstants(self, tf_fun: Callable, *args,
|
||||
at_least=256):
|
||||
|
@ -47,7 +47,8 @@ from jax._src.tree_util import (
|
||||
tree_flatten as tree_flatten,
|
||||
tree_leaves as tree_leaves,
|
||||
tree_map as tree_map,
|
||||
tree_multimap as tree_multimap,
|
||||
# TODO(jakevdp) remove tree_multimap once deprecation is complete.
|
||||
tree_multimap,
|
||||
tree_reduce as tree_reduce,
|
||||
tree_structure as tree_structure,
|
||||
tree_transpose as tree_transpose,
|
||||
|
@ -30,7 +30,7 @@ from jax import numpy as jnp
|
||||
from jax import linear_util as lu
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.core import UnshapedArray, ShapedArray
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce, tree_leaves
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import test_util as jtu
|
||||
@ -147,16 +147,16 @@ def fwd_deriv(f):
|
||||
|
||||
class CoreTest(jtu.JaxTestCase):
|
||||
|
||||
def test_tree_multimap(self):
|
||||
def test_tree_map(self):
|
||||
xs = ({'a': 1}, [2, 3])
|
||||
ys = ({'a': 10}, [20, 30])
|
||||
ys_bad = ({'a': 10, 'b': 10}, [20, 30])
|
||||
zs = ({'a': 11}, [22, 33])
|
||||
|
||||
f = lambda x, y: x + y
|
||||
assert tree_multimap(f, xs, ys) == zs
|
||||
assert tree_map(f, xs, ys) == zs
|
||||
try:
|
||||
tree_multimap(f, xs, ys_bad)
|
||||
tree_map(f, xs, ys_bad)
|
||||
assert False
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
@ -170,7 +170,7 @@ class CoreTest(jtu.JaxTestCase):
|
||||
flat, treedef = tree_flatten(tree)
|
||||
assert flat == [1, 2, 3, 4, 5]
|
||||
tree2 = tree_unflatten(treedef, flat)
|
||||
nodes_equal = tree_multimap(operator.eq, tree, tree2)
|
||||
nodes_equal = tree_map(operator.eq, tree, tree2)
|
||||
assert tree_reduce(operator.and_, nodes_equal)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
@ -2570,7 +2570,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
operands = {'x': [np.ones(5), np.arange(5)]}
|
||||
init_values = {'x': [0., 0]}
|
||||
result = lax.reduce(operands, init_values,
|
||||
lambda x, y: tree_util.tree_multimap(lax.add, x, y),
|
||||
lambda x, y: tree_util.tree_map(lax.add, x, y),
|
||||
[0])
|
||||
self.assertDictEqual(result, {'x': [5., 10.]})
|
||||
|
||||
|
@ -354,7 +354,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]),
|
||||
'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])}
|
||||
|
||||
assert_allclose = partial(tree_util.tree_multimap,
|
||||
assert_allclose = partial(tree_util.tree_map,
|
||||
partial(self.assertAllClose, check_dtypes=False))
|
||||
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
|
||||
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
|
||||
@ -372,7 +372,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
tree_f = lambda f: partial(tree_util.tree_map, f)
|
||||
jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
|
||||
np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
|
||||
assert_allclose = partial(tree_util.tree_multimap,
|
||||
assert_allclose = partial(tree_util.tree_map,
|
||||
partial(self.assertAllClose, check_dtypes=False))
|
||||
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
|
||||
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
|
||||
@ -2381,7 +2381,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
return {'a': x}
|
||||
device_count = jax.device_count()
|
||||
x = jnp.arange(device_count)
|
||||
tree_util.tree_multimap(self.assertAllClose, f(x), {'a': x})
|
||||
tree_util.tree_map(self.assertAllClose, f(x), {'a': x})
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_{in_axes}_{out_axes}",
|
||||
|
@ -242,17 +242,17 @@ class TreeTest(jtu.JaxTestCase):
|
||||
}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
|
||||
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])
|
||||
|
||||
def testTreeMultimap(self):
|
||||
def testTreeMap(self):
|
||||
x = ((1, 2), (3, 4, 5))
|
||||
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
|
||||
out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
|
||||
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y)
|
||||
self.assertEqual(out, (((1, [3]), (2, None)),
|
||||
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
|
||||
|
||||
def testTreeMultimapWithIsLeafArgument(self):
|
||||
def testTreeMapWithIsLeafArgument(self):
|
||||
x = ((1, 2), [3, 4, 5])
|
||||
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
|
||||
out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y,
|
||||
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y,
|
||||
is_leaf=lambda n: isinstance(n, list))
|
||||
self.assertEqual(out, (((1, [3]), (2, None)),
|
||||
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
|
||||
|
Loading…
x
Reference in New Issue
Block a user