mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate jax.tree_util.tree_multimap
This commit is contained in:
parent
1c3edc811d
commit
df1ceaeeb1
@ -20,6 +20,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`#10069`)
|
* `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`#10069`)
|
||||||
* Deprecations:
|
* Deprecations:
|
||||||
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
|
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
|
||||||
|
* {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`).
|
||||||
|
|
||||||
## jaxlib 0.3.3 (Unreleased)
|
## jaxlib 0.3.3 (Unreleased)
|
||||||
|
|
||||||
|
@ -43,9 +43,9 @@ from jax import core
|
|||||||
from jax import linear_util as lu
|
from jax import linear_util as lu
|
||||||
from jax import stages
|
from jax import stages
|
||||||
from jax.core import eval_jaxpr
|
from jax.core import eval_jaxpr
|
||||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
from jax.tree_util import (tree_map, tree_multimap, tree_flatten, tree_unflatten,
|
||||||
tree_structure, tree_transpose, tree_leaves,
|
tree_structure, tree_transpose, tree_leaves,
|
||||||
tree_multimap, treedef_is_leaf, treedef_children,
|
tree_map, treedef_is_leaf, treedef_children,
|
||||||
Partial, PyTreeDef, all_leaves, treedef_tuple)
|
Partial, PyTreeDef, all_leaves, treedef_tuple)
|
||||||
|
|
||||||
from jax._src import device_array
|
from jax._src import device_array
|
||||||
@ -2731,7 +2731,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
|
|||||||
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
|
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
|
||||||
|
|
||||||
with config_explicit_device_put_scope():
|
with config_explicit_device_put_scope():
|
||||||
return tree_multimap(_device_put_sharded, *shards)
|
return tree_map(_device_put_sharded, *shards)
|
||||||
|
|
||||||
|
|
||||||
def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
|
def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
|
||||||
|
@ -21,7 +21,7 @@ import numpy as np
|
|||||||
from jax import core
|
from jax import core
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src.tree_util import (
|
from jax._src.tree_util import (
|
||||||
PyTreeDef, tree_flatten, tree_unflatten, tree_multimap, tree_structure,
|
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
|
||||||
treedef_children, treedef_is_leaf)
|
treedef_children, treedef_is_leaf)
|
||||||
from jax._src.tree_util import _replace_nones
|
from jax._src.tree_util import _replace_nones
|
||||||
from jax import linear_util as lu
|
from jax import linear_util as lu
|
||||||
@ -286,7 +286,7 @@ def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
|
|||||||
axes = []
|
axes = []
|
||||||
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
|
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
|
||||||
try:
|
try:
|
||||||
tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
|
tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
if kws:
|
if kws:
|
||||||
# if keyword arguments are included in the tree, we make adapt the error
|
# if keyword arguments are included in the tree, we make adapt the error
|
||||||
|
@ -22,9 +22,8 @@ from typing import (Callable, Generic, Optional, Sequence, Tuple, List, TypeVar,
|
|||||||
from jax import core
|
from jax import core
|
||||||
from jax import linear_util as lu
|
from jax import linear_util as lu
|
||||||
from jax.custom_transpose import custom_transpose
|
from jax.custom_transpose import custom_transpose
|
||||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, treedef_is_leaf,
|
||||||
tree_multimap, treedef_is_leaf, treedef_tuple,
|
treedef_tuple, register_pytree_node_class)
|
||||||
register_pytree_node_class)
|
|
||||||
from jax._src import custom_api_util
|
from jax._src import custom_api_util
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||||
@ -198,7 +197,7 @@ class custom_jvp(Generic[ReturnValue]):
|
|||||||
zeros = _zeros_like_pytree(primal_out)
|
zeros = _zeros_like_pytree(primal_out)
|
||||||
all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros
|
all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros
|
||||||
for t, jvp in zip(tangents, jvps)]
|
for t, jvp in zip(tangents, jvps)]
|
||||||
tangent_out = tree_multimap(_sum_tangents, primal_out, *all_tangents_out)
|
tangent_out = tree_map(_sum_tangents, primal_out, *all_tangents_out)
|
||||||
return primal_out, tangent_out
|
return primal_out, tangent_out
|
||||||
|
|
||||||
self.defjvp(jvp)
|
self.defjvp(jvp)
|
||||||
@ -617,8 +616,8 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
|
|||||||
try:
|
try:
|
||||||
if not isinstance(py_cts_in, tuple):
|
if not isinstance(py_cts_in, tuple):
|
||||||
raise ValueError
|
raise ValueError
|
||||||
tree_multimap(append_cts,
|
tree_map(append_cts,
|
||||||
tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
|
tuple(zero if ct is None else ct for ct in py_cts_in), dummy)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
_, in_tree2 = tree_flatten(py_cts_in)
|
_, in_tree2 = tree_flatten(py_cts_in)
|
||||||
msg = ("Custom VJP rule must produce an output with the same container "
|
msg = ("Custom VJP rule must produce an output with the same container "
|
||||||
|
@ -54,7 +54,7 @@ from jax._src.traceback_util import api_boundary
|
|||||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
|
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
|
||||||
split_list, cache, extend_name_stack, wrap_name)
|
split_list, cache, extend_name_stack, wrap_name)
|
||||||
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
||||||
treedef_children, treedef_tuple, tree_multimap,
|
treedef_children, treedef_tuple, tree_map,
|
||||||
tree_leaves, tree_structure)
|
tree_leaves, tree_structure)
|
||||||
from jax._src import ad_util
|
from jax._src import ad_util
|
||||||
from jax.config import config
|
from jax.config import config
|
||||||
@ -1477,7 +1477,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
|||||||
ys.append(y)
|
ys.append(y)
|
||||||
stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
|
stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
|
||||||
else jax.numpy.stack((y, *ys)))
|
else jax.numpy.stack((y, *ys)))
|
||||||
stacked_y = tree_multimap(stack, *maybe_reversed(ys))
|
stacked_y = tree_map(stack, *maybe_reversed(ys))
|
||||||
return carry, stacked_y
|
return carry, stacked_y
|
||||||
|
|
||||||
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
|
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
|
||||||
@ -2196,8 +2196,8 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"{what} must have same type structure, got {tree1} and {tree2}.")
|
f"{what} must have same type structure, got {tree1} and {tree2}.")
|
||||||
if not all(_map(core.typematch, avals1, avals2)):
|
if not all(_map(core.typematch, avals1, avals2)):
|
||||||
diff = tree_multimap(_show_diff, tree_unflatten(tree1, avals1),
|
diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
|
||||||
tree_unflatten(tree2, avals2))
|
tree_unflatten(tree2, avals2))
|
||||||
raise TypeError(f"{what} must have identical types, got\n{diff}.")
|
raise TypeError(f"{what} must have identical types, got\n{diff}.")
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ import jax.numpy as jnp
|
|||||||
from jax import device_put
|
from jax import device_put
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax import scipy as jsp
|
from jax import scipy as jsp
|
||||||
from jax.tree_util import (tree_leaves, tree_map, tree_multimap, tree_structure,
|
from jax.tree_util import (tree_leaves, tree_map, tree_structure,
|
||||||
tree_reduce, Partial)
|
tree_reduce, Partial)
|
||||||
|
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
@ -52,11 +52,11 @@ def _vdot_real_part(x, y):
|
|||||||
|
|
||||||
|
|
||||||
def _vdot_real_tree(x, y):
|
def _vdot_real_tree(x, y):
|
||||||
return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
|
return sum(tree_leaves(tree_map(_vdot_real_part, x, y)))
|
||||||
|
|
||||||
|
|
||||||
def _vdot_tree(x, y):
|
def _vdot_tree(x, y):
|
||||||
return sum(tree_leaves(tree_multimap(partial(
|
return sum(tree_leaves(tree_map(partial(
|
||||||
jnp.vdot, precision=lax.Precision.HIGHEST), x, y)))
|
jnp.vdot, precision=lax.Precision.HIGHEST), x, y)))
|
||||||
|
|
||||||
|
|
||||||
@ -73,9 +73,9 @@ def _div(tree, scalar):
|
|||||||
return tree_map(partial(lambda v: v / scalar), tree)
|
return tree_map(partial(lambda v: v / scalar), tree)
|
||||||
|
|
||||||
|
|
||||||
_add = partial(tree_multimap, operator.add)
|
_add = partial(tree_map, operator.add)
|
||||||
_sub = partial(tree_multimap, operator.sub)
|
_sub = partial(tree_map, operator.sub)
|
||||||
_dot_tree = partial(tree_multimap, _dot)
|
_dot_tree = partial(tree_map, _dot)
|
||||||
|
|
||||||
|
|
||||||
@Partial
|
@Partial
|
||||||
@ -162,12 +162,12 @@ def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
|
|||||||
shat = M(s)
|
shat = M(s)
|
||||||
t = A(shat)
|
t = A(shat)
|
||||||
omega_ = _vdot_tree(t, s) / _vdot_tree(t, t) # make cases?
|
omega_ = _vdot_tree(t, s) / _vdot_tree(t, t) # make cases?
|
||||||
x_ = tree_multimap(partial(jnp.where, exit_early),
|
x_ = tree_map(partial(jnp.where, exit_early),
|
||||||
_add(x, _mul(alpha_, phat)),
|
_add(x, _mul(alpha_, phat)),
|
||||||
_add(x, _add(_mul(alpha_, phat), _mul(omega_, shat)))
|
_add(x, _add(_mul(alpha_, phat), _mul(omega_, shat)))
|
||||||
)
|
)
|
||||||
r_ = tree_multimap(partial(jnp.where, exit_early),
|
r_ = tree_map(partial(jnp.where, exit_early),
|
||||||
s, _sub(s, _mul(omega_, t)))
|
s, _sub(s, _mul(omega_, t)))
|
||||||
k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
|
k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
|
||||||
k_ = jnp.where((rho_ == 0), -10, k_)
|
k_ = jnp.where((rho_ == 0), -10, k_)
|
||||||
return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_
|
return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_
|
||||||
@ -308,7 +308,7 @@ def _project_on_columns(A, v):
|
|||||||
"""
|
"""
|
||||||
Returns A.T.conj() @ v.
|
Returns A.T.conj() @ v.
|
||||||
"""
|
"""
|
||||||
v_proj = tree_multimap(
|
v_proj = tree_map(
|
||||||
lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v,
|
lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v,
|
||||||
)
|
)
|
||||||
return tree_reduce(operator.add, v_proj)
|
return tree_reduce(operator.add, v_proj)
|
||||||
@ -400,7 +400,7 @@ def _kth_arnoldi_iteration(k, A, M, V, H):
|
|||||||
|
|
||||||
tol = eps * v_norm_0
|
tol = eps * v_norm_0
|
||||||
unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
|
unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
|
||||||
V = tree_multimap(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
|
V = tree_map(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)
|
||||||
|
|
||||||
h = h.at[k + 1].set(v_norm_1)
|
h = h.at[k + 1].set(v_norm_1)
|
||||||
H = H.at[k, :].set(h)
|
H = H.at[k, :].set(h)
|
||||||
|
@ -37,7 +37,7 @@ from jax._src import dtypes as _dtypes
|
|||||||
from jax import lax
|
from jax import lax
|
||||||
from jax._src.config import flags, bool_env, config
|
from jax._src.config import flags, bool_env, config
|
||||||
from jax._src.util import prod, unzip2
|
from jax._src.util import prod, unzip2
|
||||||
from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
from jax.tree_util import tree_map, tree_all, tree_reduce
|
||||||
from jax._src.lib import xla_bridge
|
from jax._src.lib import xla_bridge
|
||||||
from jax._src import dispatch
|
from jax._src import dispatch
|
||||||
from jax.interpreters import mlir
|
from jax.interpreters import mlir
|
||||||
@ -222,12 +222,12 @@ def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''):
|
|||||||
|
|
||||||
def check_eq(xs, ys, err_msg=''):
|
def check_eq(xs, ys, err_msg=''):
|
||||||
assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
|
assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
|
||||||
tree_all(tree_multimap(assert_close, xs, ys))
|
tree_all(tree_map(assert_close, xs, ys))
|
||||||
|
|
||||||
def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
|
def check_close(xs, ys, atol=None, rtol=None, err_msg=''):
|
||||||
assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol,
|
assert_close = partial(_assert_numpy_close, atol=atol, rtol=rtol,
|
||||||
err_msg=err_msg)
|
err_msg=err_msg)
|
||||||
tree_all(tree_multimap(assert_close, xs, ys))
|
tree_all(tree_map(assert_close, xs, ys))
|
||||||
|
|
||||||
def _check_dtypes_match(xs, ys):
|
def _check_dtypes_match(xs, ys):
|
||||||
def _assert_dtypes_match(x, y):
|
def _assert_dtypes_match(x, y):
|
||||||
@ -236,13 +236,13 @@ def _check_dtypes_match(xs, ys):
|
|||||||
else:
|
else:
|
||||||
assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
|
assert (_dtypes.canonicalize_dtype(_dtype(x)) ==
|
||||||
_dtypes.canonicalize_dtype(_dtype(y)))
|
_dtypes.canonicalize_dtype(_dtype(y)))
|
||||||
tree_all(tree_multimap(_assert_dtypes_match, xs, ys))
|
tree_all(tree_map(_assert_dtypes_match, xs, ys))
|
||||||
|
|
||||||
|
|
||||||
def inner_prod(xs, ys):
|
def inner_prod(xs, ys):
|
||||||
def contract(x, y):
|
def contract(x, y):
|
||||||
return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
|
return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
|
||||||
return tree_reduce(np.add, tree_multimap(contract, xs, ys))
|
return tree_reduce(np.add, tree_map(contract, xs, ys))
|
||||||
|
|
||||||
|
|
||||||
def _safe_subtract(x, y, *, dtype):
|
def _safe_subtract(x, y, *, dtype):
|
||||||
@ -251,9 +251,9 @@ def _safe_subtract(x, y, *, dtype):
|
|||||||
return np.where(np.equal(x, y), np.array(0, dtype),
|
return np.where(np.equal(x, y), np.array(0, dtype),
|
||||||
np.subtract(x, y, dtype=dtype))
|
np.subtract(x, y, dtype=dtype))
|
||||||
|
|
||||||
add = partial(tree_multimap, lambda x, y: np.add(x, y, dtype=_dtype(x)))
|
add = partial(tree_map, lambda x, y: np.add(x, y, dtype=_dtype(x)))
|
||||||
sub = partial(tree_multimap, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
|
sub = partial(tree_map, lambda x, y: np.subtract(x, y, dtype=_dtype(x)))
|
||||||
safe_sub = partial(tree_multimap,
|
safe_sub = partial(tree_map,
|
||||||
lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
|
lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
|
||||||
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))
|
conj = partial(tree_map, lambda x: np.conj(x, dtype=_dtype(x)))
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ import operator as op
|
|||||||
from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, List,
|
from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, List,
|
||||||
Dict, Type, TypeVar, overload, TYPE_CHECKING, NamedTuple)
|
Dict, Type, TypeVar, overload, TYPE_CHECKING, NamedTuple)
|
||||||
import textwrap
|
import textwrap
|
||||||
|
import warnings
|
||||||
|
|
||||||
from jax._src.lib import pytree
|
from jax._src.lib import pytree
|
||||||
|
|
||||||
@ -179,7 +180,11 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
|
|||||||
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
|
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
|
||||||
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
|
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
|
||||||
|
|
||||||
tree_multimap = tree_map
|
def tree_multimap(*args, **kwargs):
|
||||||
|
"""Deprecated alias of :func:`jax.tree_util.tree_map`"""
|
||||||
|
warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
|
||||||
|
'instead as a drop-in replacement.', FutureWarning)
|
||||||
|
return tree_map(*args, **kwargs)
|
||||||
|
|
||||||
# TODO(mattjj,phawkins): consider removing this function
|
# TODO(mattjj,phawkins): consider removing this function
|
||||||
def _process_pytree(process_node, tree):
|
def _process_pytree(process_node, tree):
|
||||||
|
@ -421,7 +421,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
|||||||
for dim_spec in in_spec),
|
for dim_spec in in_spec),
|
||||||
dtype=tf.float32)
|
dtype=tf.float32)
|
||||||
|
|
||||||
return tree_util.tree_multimap(polymorphic_shape_to_tensorspec, polymorphic_shapes)
|
return tree_util.tree_map(polymorphic_shape_to_tensorspec, polymorphic_shapes)
|
||||||
|
|
||||||
def CountLargeTfConstants(self, tf_fun: Callable, *args,
|
def CountLargeTfConstants(self, tf_fun: Callable, *args,
|
||||||
at_least=256):
|
at_least=256):
|
||||||
|
@ -47,7 +47,8 @@ from jax._src.tree_util import (
|
|||||||
tree_flatten as tree_flatten,
|
tree_flatten as tree_flatten,
|
||||||
tree_leaves as tree_leaves,
|
tree_leaves as tree_leaves,
|
||||||
tree_map as tree_map,
|
tree_map as tree_map,
|
||||||
tree_multimap as tree_multimap,
|
# TODO(jakevdp) remove tree_multimap once deprecation is complete.
|
||||||
|
tree_multimap,
|
||||||
tree_reduce as tree_reduce,
|
tree_reduce as tree_reduce,
|
||||||
tree_structure as tree_structure,
|
tree_structure as tree_structure,
|
||||||
tree_transpose as tree_transpose,
|
tree_transpose as tree_transpose,
|
||||||
|
@ -30,7 +30,7 @@ from jax import numpy as jnp
|
|||||||
from jax import linear_util as lu
|
from jax import linear_util as lu
|
||||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||||
from jax.core import UnshapedArray, ShapedArray
|
from jax.core import UnshapedArray, ShapedArray
|
||||||
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
|
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce, tree_leaves
|
||||||
from jax.interpreters import partial_eval as pe
|
from jax.interpreters import partial_eval as pe
|
||||||
|
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
@ -147,16 +147,16 @@ def fwd_deriv(f):
|
|||||||
|
|
||||||
class CoreTest(jtu.JaxTestCase):
|
class CoreTest(jtu.JaxTestCase):
|
||||||
|
|
||||||
def test_tree_multimap(self):
|
def test_tree_map(self):
|
||||||
xs = ({'a': 1}, [2, 3])
|
xs = ({'a': 1}, [2, 3])
|
||||||
ys = ({'a': 10}, [20, 30])
|
ys = ({'a': 10}, [20, 30])
|
||||||
ys_bad = ({'a': 10, 'b': 10}, [20, 30])
|
ys_bad = ({'a': 10, 'b': 10}, [20, 30])
|
||||||
zs = ({'a': 11}, [22, 33])
|
zs = ({'a': 11}, [22, 33])
|
||||||
|
|
||||||
f = lambda x, y: x + y
|
f = lambda x, y: x + y
|
||||||
assert tree_multimap(f, xs, ys) == zs
|
assert tree_map(f, xs, ys) == zs
|
||||||
try:
|
try:
|
||||||
tree_multimap(f, xs, ys_bad)
|
tree_map(f, xs, ys_bad)
|
||||||
assert False
|
assert False
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
pass
|
pass
|
||||||
@ -170,7 +170,7 @@ class CoreTest(jtu.JaxTestCase):
|
|||||||
flat, treedef = tree_flatten(tree)
|
flat, treedef = tree_flatten(tree)
|
||||||
assert flat == [1, 2, 3, 4, 5]
|
assert flat == [1, 2, 3, 4, 5]
|
||||||
tree2 = tree_unflatten(treedef, flat)
|
tree2 = tree_unflatten(treedef, flat)
|
||||||
nodes_equal = tree_multimap(operator.eq, tree, tree2)
|
nodes_equal = tree_map(operator.eq, tree, tree2)
|
||||||
assert tree_reduce(operator.and_, nodes_equal)
|
assert tree_reduce(operator.and_, nodes_equal)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
|
@ -2570,7 +2570,7 @@ class LaxTest(jtu.JaxTestCase):
|
|||||||
operands = {'x': [np.ones(5), np.arange(5)]}
|
operands = {'x': [np.ones(5), np.arange(5)]}
|
||||||
init_values = {'x': [0., 0]}
|
init_values = {'x': [0., 0]}
|
||||||
result = lax.reduce(operands, init_values,
|
result = lax.reduce(operands, init_values,
|
||||||
lambda x, y: tree_util.tree_multimap(lax.add, x, y),
|
lambda x, y: tree_util.tree_map(lax.add, x, y),
|
||||||
[0])
|
[0])
|
||||||
self.assertDictEqual(result, {'x': [5., 10.]})
|
self.assertDictEqual(result, {'x': [5., 10.]})
|
||||||
|
|
||||||
|
@ -354,7 +354,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
|||||||
'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]),
|
'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]),
|
||||||
'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])}
|
'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])}
|
||||||
|
|
||||||
assert_allclose = partial(tree_util.tree_multimap,
|
assert_allclose = partial(tree_util.tree_map,
|
||||||
partial(self.assertAllClose, check_dtypes=False))
|
partial(self.assertAllClose, check_dtypes=False))
|
||||||
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
|
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
|
||||||
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
|
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
|
||||||
@ -372,7 +372,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
|||||||
tree_f = lambda f: partial(tree_util.tree_map, f)
|
tree_f = lambda f: partial(tree_util.tree_map, f)
|
||||||
jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
|
jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
|
||||||
np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
|
np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
|
||||||
assert_allclose = partial(tree_util.tree_multimap,
|
assert_allclose = partial(tree_util.tree_map,
|
||||||
partial(self.assertAllClose, check_dtypes=False))
|
partial(self.assertAllClose, check_dtypes=False))
|
||||||
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
|
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
|
||||||
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
|
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
|
||||||
@ -2381,7 +2381,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
|||||||
return {'a': x}
|
return {'a': x}
|
||||||
device_count = jax.device_count()
|
device_count = jax.device_count()
|
||||||
x = jnp.arange(device_count)
|
x = jnp.arange(device_count)
|
||||||
tree_util.tree_multimap(self.assertAllClose, f(x), {'a': x})
|
tree_util.tree_map(self.assertAllClose, f(x), {'a': x})
|
||||||
|
|
||||||
@parameterized.named_parameters(jtu.cases_from_list(
|
@parameterized.named_parameters(jtu.cases_from_list(
|
||||||
{"testcase_name": f"_{in_axes}_{out_axes}",
|
{"testcase_name": f"_{in_axes}_{out_axes}",
|
||||||
|
@ -242,18 +242,18 @@ class TreeTest(jtu.JaxTestCase):
|
|||||||
}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
|
}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)])
|
||||||
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])
|
self.assertEqual(out, [{"foo": 7}, (3, 4), (11, 9), None])
|
||||||
|
|
||||||
def testTreeMultimap(self):
|
def testTreeMap(self):
|
||||||
x = ((1, 2), (3, 4, 5))
|
x = ((1, 2), (3, 4, 5))
|
||||||
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
|
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
|
||||||
out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y)
|
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y)
|
||||||
self.assertEqual(out, (((1, [3]), (2, None)),
|
self.assertEqual(out, (((1, [3]), (2, None)),
|
||||||
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
|
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
|
||||||
|
|
||||||
def testTreeMultimapWithIsLeafArgument(self):
|
def testTreeMapWithIsLeafArgument(self):
|
||||||
x = ((1, 2), [3, 4, 5])
|
x = ((1, 2), [3, 4, 5])
|
||||||
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
|
y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
|
||||||
out = tree_util.tree_multimap(lambda *xs: tuple(xs), x, y,
|
out = tree_util.tree_map(lambda *xs: tuple(xs), x, y,
|
||||||
is_leaf=lambda n: isinstance(n, list))
|
is_leaf=lambda n: isinstance(n, list))
|
||||||
self.assertEqual(out, (((1, [3]), (2, None)),
|
self.assertEqual(out, (((1, [3]), (2, None)),
|
||||||
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
|
(([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))))
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user