tests: access tree utilities via jax.tree.*

This commit is contained in:
Jake VanderPlas 2024-02-26 14:17:18 -08:00
parent 57e34e1a2c
commit cddee4654c
24 changed files with 87 additions and 102 deletions

View File

@ -1042,7 +1042,7 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertEqual(
obj.in_avals,
((core.ShapedArray([], expected_dtype, weak_type=True),), {}))
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
self.assertEqual(obj.in_tree, jax.tree.flatten(((0,), {}))[1])
def test_jit_lower_duck_typing(self):
f_jit = jit(lambda x: 2 * x)
@ -2490,7 +2490,7 @@ class APITest(jtu.JaxTestCase):
x = (jnp.ones(2), jnp.ones(2))
y = 3.
out_shape = api.eval_shape(fun, x, y)
out_shape = tree_util.tree_map(np.shape, out_shape)
out_shape = jax.tree.map(np.shape, out_shape)
self.assertEqual(out_shape, {'hi': (2,)})
@ -3004,7 +3004,7 @@ class APITest(jtu.JaxTestCase):
ValueError,
"vmap in_axes specification must be a tree prefix of the corresponding "
r"value, got specification \(\[0\],\) for value tree "
+ re.escape(f"{tree_util.tree_structure((value_tree,))}."),
+ re.escape(f"{jax.tree.structure((value_tree,))}."),
lambda: api.vmap(lambda x: x, in_axes=([0],))(value_tree)
)
@ -7013,8 +7013,8 @@ class CustomJVPTest(jtu.JaxTestCase):
"must produce primal and tangent outputs "
"with equal container (pytree) structures, but got "
"{} and {} respectively.".format(
tree_util.tree_structure((1,)),
tree_util.tree_structure([1, 2]))
jax.tree.structure((1,)),
jax.tree.structure([1, 2]))
),
lambda: api.jvp(f, (2.,), (1.,)))
@ -7729,9 +7729,9 @@ class CustomJVPTest(jtu.JaxTestCase):
def _vmap(fun):
def _fun(*args):
args = tree_util.tree_map(_pack, args)
args = jax.tree.map(_pack, args)
out = jax.vmap(fun)(*args)
out = tree_util.tree_map(_unpack, out)
out = jax.tree.map(_unpack, out)
return out
return _fun
@ -8242,8 +8242,8 @@ class CustomVJPTest(jtu.JaxTestCase):
"and in particular must produce a tuple of length equal to the "
"number of arguments to the primal function, but got VJP output "
"structure {} for primal input structure {}.".format(
tree_util.tree_structure((1, 1)),
tree_util.tree_structure((1,)))
jax.tree.structure((1, 1)),
jax.tree.structure((1,)))
),
lambda: api.grad(f)(2.))
@ -9017,9 +9017,9 @@ class CustomVJPTest(jtu.JaxTestCase):
def _vmap(fun):
def _fun(*args):
args = tree_util.tree_map(_pack, args)
args = jax.tree.map(_pack, args)
out = jax.vmap(fun)(*args)
out = tree_util.tree_map(_unpack, out)
out = jax.tree.map(_unpack, out)
return out
return _fun
@ -9281,7 +9281,7 @@ def custom_transpose(example_out):
return _custom_transpose(out_type, example_out)
return partial(
_custom_transpose,
tree_util.tree_map(
jax.tree.map(
lambda x: core.get_aval(x).at_least_vspace(), example_out))
@ -10139,13 +10139,13 @@ class CustomVmapTest(jtu.JaxTestCase):
f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs))
def test_tree(self):
tree_sin = partial(tree_util.tree_map, jnp.sin)
tree_cos = partial(tree_util.tree_map, jnp.cos)
tree_sin = partial(jax.tree.map, jnp.sin)
tree_cos = partial(jax.tree.map, jnp.cos)
x, xs = jnp.array(1.), jnp.arange(3)
x = (x, [x + 1, x + 2], [x + 3], x + 4)
xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4)
in_batched_ref = tree_util.tree_map(lambda _: True, x)
in_batched_ref = jax.tree.map(lambda _: True, x)
@jax.custom_batching.custom_vmap
def f(xs): return tree_sin(xs)
@ -10153,7 +10153,7 @@ class CustomVmapTest(jtu.JaxTestCase):
@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(in_batched, [in_batched_ref])
sz, = {z.shape[0] for z in tree_util.tree_leaves(xs)}
sz, = {z.shape[0] for z in jax.tree.leaves(xs)}
self.assertEqual(axis_size, sz)
return tree_cos(xs), in_batched[0]
@ -10163,13 +10163,13 @@ class CustomVmapTest(jtu.JaxTestCase):
self.assertAllClose(ys, tree_cos(xs))
def test_tree_with_nones(self):
tree_sin = partial(tree_util.tree_map, jnp.sin)
tree_cos = partial(tree_util.tree_map, jnp.cos)
tree_sin = partial(jax.tree.map, jnp.sin)
tree_cos = partial(jax.tree.map, jnp.cos)
x, xs = jnp.array(1.), jnp.arange(3)
x = (x, [x + 1, None], [x + 3], None)
xs = (xs, [xs + 1, None], [xs + 3], None)
in_batched_ref = tree_util.tree_map(lambda _: True, x)
in_batched_ref = jax.tree.map(lambda _: True, x)
@jax.custom_batching.custom_vmap
def f(xs): return tree_sin(xs)
@ -10177,7 +10177,7 @@ class CustomVmapTest(jtu.JaxTestCase):
@f.def_vmap
def rule(axis_size, in_batched, xs):
self.assertEqual(in_batched, [in_batched_ref])
sz, = {z.shape[0] for z in tree_util.tree_leaves(xs)}
sz, = {z.shape[0] for z in jax.tree.leaves(xs)}
self.assertEqual(axis_size, sz)
return tree_cos(xs), in_batched[0]

View File

@ -28,7 +28,6 @@ from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs
from jax import config
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce
from jax._src import core
from jax._src import linear_util as lu
@ -49,17 +48,17 @@ def call(f, *args):
@util.curry
def core_call(f, *args):
args, in_tree = tree_flatten(args)
args, in_tree = jax.tree.flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out = core.call_p.bind(f, *args)
return tree_unflatten(out_tree(), out)
return jax.tree.unflatten(out_tree(), out)
@util.curry
def core_closed_call(f, *args):
args, in_tree = tree_flatten(args)
args, in_tree = jax.tree.flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out = core.closed_call_p.bind(f, *args)
return tree_unflatten(out_tree(), out)
return jax.tree.unflatten(out_tree(), out)
def simple_fun(x, y):
return jnp.sin(x * y)
@ -175,24 +174,24 @@ class CoreTest(jtu.JaxTestCase):
zs = ({'a': 11}, [22, 33])
f = lambda x, y: x + y
assert tree_map(f, xs, ys) == zs
assert jax.tree.map(f, xs, ys) == zs
try:
tree_map(f, xs, ys_bad)
jax.tree.map(f, xs, ys_bad)
assert False
except (TypeError, ValueError):
pass
def test_tree_flatten(self):
flat, _ = tree_flatten(({'a': 1}, [2, 3], 4))
flat, _ = jax.tree.flatten(({'a': 1}, [2, 3], 4))
assert flat == [1, 2, 3, 4]
def test_tree_unflatten(self):
tree = [(1, 2), {"roy": (3, [4, 5, ()])}]
flat, treedef = tree_flatten(tree)
flat, treedef = jax.tree.flatten(tree)
assert flat == [1, 2, 3, 4, 5]
tree2 = tree_unflatten(treedef, flat)
nodes_equal = tree_map(operator.eq, tree, tree2)
assert tree_reduce(operator.and_, nodes_equal)
tree2 = jax.tree.unflatten(treedef, flat)
nodes_equal = jax.tree.map(operator.eq, tree, tree2)
assert jax.tree.reduce(operator.and_, nodes_equal)
@jtu.sample_product(
dtype=[*jtu.dtypes.all, object, [('i', 'i4'), ('f', 'f4')]]

View File

@ -25,7 +25,6 @@ import jax
from jax import lax
from jax.ad_checkpoint import checkpoint
from jax._src import test_util as jtu
from jax import tree_util
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
@ -160,7 +159,7 @@ class CustomLinearSolveTest(jtu.JaxTestCase):
# vmap test
c = rng.randn(3, 2)
expected = jnp.linalg.solve(a, c)
expected_aux = tree_util.tree_map(partial(np.repeat, repeats=2), array_aux)
expected_aux = jax.tree.map(partial(np.repeat, repeats=2), array_aux)
actual_vmap, vmap_aux = jax.vmap(linear_solve_aux, (None, 1), -1)(a, c)
self.assertAllClose(expected, actual_vmap)
@ -473,7 +472,7 @@ class CustomLinearSolveTest(jtu.JaxTestCase):
return mv(b), aux
def solve_aux(x):
matvec = lambda y: tree_util.tree_map(partial(jnp.dot, A), y)
matvec = lambda y: jax.tree.map(partial(jnp.dot, A), y)
return lax.custom_linear_solve(matvec, (x, x), solve, solve, symmetric=True, has_aux=True)
rng = self.rng()

View File

@ -22,7 +22,6 @@ import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
from jax import tree_util
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
@ -227,7 +226,7 @@ class CustomRootTest(jtu.JaxTestCase):
expected_fwd_val = expected_fwd(a, b)
self.assertAllClose(fwd_val, expected_fwd_val, rtol={np.float32: 5E-6, np.float64: 5E-12})
jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))
jtu.check_close(fwd_aux, jax.tree.map(jnp.zeros_like, fwd_aux))
def test_custom_root_errors(self):
with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):

View File

@ -155,7 +155,7 @@ class PrimitiveTest(jtu.JaxTestCase):
if device.platform in skip_run_on_platforms:
logging.info("Skipping running on %s", device)
continue
device_args = jax.tree_util.tree_map(
device_args = jax.tree.map(
lambda x: jax.device_put(x, device), args
)
logging.info("Running harness natively on %s", device)

View File

@ -25,7 +25,6 @@ from absl.testing import absltest
import jax
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax.experimental import export
from jax.experimental.export import _export
from jax.experimental import pjit
@ -185,7 +184,7 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual("my_fun", exp.fun_name)
self.assertEqual((export.default_lowering_platform(),),
exp.lowering_platforms)
self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
@ -201,9 +200,9 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertEqual(exp.lowering_platforms, ("cpu",))
args = ((a, b),)
kwargs = dict(a=a, b=b)
self.assertEqual(exp.in_tree, tree_util.tree_flatten((args, kwargs))[1])
self.assertEqual(exp.in_tree, jax.tree.flatten((args, kwargs))[1])
self.assertEqual(exp.in_avals, (a_aval, b_aval, a_aval, b_aval))
self.assertEqual(exp.out_tree, tree_util.tree_flatten(f(*args, **kwargs))[1])
self.assertEqual(exp.out_tree, jax.tree.flatten(f(*args, **kwargs))[1])
self.assertEqual(exp.out_avals, (a_aval, b_aval, a_aval, b_aval, a_aval, b_aval))
def test_basic(self):

View File

@ -15,7 +15,6 @@
# Helpers for writing JAX filecheck tests.
import jax
import jax.tree_util as tree_util
import numpy as np
def print_ir(*prototypes):
@ -23,8 +22,8 @@ def print_ir(*prototypes):
"""Prints the MLIR IR that results from lowering `f`.
The arguments to `f` are taken to be arrays shaped like `prototypes`."""
inputs = tree_util.tree_map(np.array, prototypes)
flat_inputs, _ = tree_util.tree_flatten(inputs)
inputs = jax.tree.map(np.array, prototypes)
flat_inputs, _ = jax.tree.flatten(inputs)
shape_strs = " ".join([f"{x.dtype.name}[{','.join(map(str, x.shape))}]"
for x in flat_inputs])
name = f.func.__name__ if hasattr(f, "func") else f.__name__

View File

@ -34,7 +34,6 @@ from jax import config
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax.experimental import host_callback as hcb
from jax.experimental import pjit
from jax.sharding import PartitionSpec as P
@ -887,7 +886,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
# making the Jaxpr does not print anything
hcb.barrier_wait()
treedef = tree_util.tree_structure(arg)
treedef = jax.tree.structure(arg)
assertMultiLineStrippedEqual(
self, f"""
{{ lambda ; a:f32[]. let
@ -1027,7 +1026,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
return res
ct_dtype = core.primal_dtype_to_tangent_dtype(res_dtype)
return np.ones(np.shape(res), dtype=ct_dtype)
cts = tree_util.tree_map(make_ct, res_f_of_args)
cts = jax.tree.map(make_ct, res_f_of_args)
def f_vjp(args, cts):
res, pullback = jax.vjp(f, *args)
return pullback(cts)

View File

@ -72,7 +72,7 @@ class InfeedTest(jtu.JaxTestCase):
device = jax.local_devices()[0]
# We must transfer the flattened data, as a tuple!!!
flat_to_infeed, _ = jax.tree_util.tree_flatten(to_infeed)
flat_to_infeed, _ = jax.tree.flatten(to_infeed)
device.transfer_to_infeed(tuple(flat_to_infeed))
self.assertAllClose(f(x), to_infeed)

View File

@ -58,8 +58,8 @@ class JetTest(jtu.JaxTestCase):
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
check_dtypes=True):
# Convert to jax arrays to ensure dtype canonicalization.
primals = jax.tree_util.tree_map(jnp.asarray, primals)
series = jax.tree_util.tree_map(jnp.asarray, series)
primals = jax.tree.map(jnp.asarray, primals)
series = jax.tree.map(jnp.asarray, series)
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)
@ -73,8 +73,8 @@ class JetTest(jtu.JaxTestCase):
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5,
check_dtypes=True):
# Convert to jax arrays to ensure dtype canonicalization.
primals = jax.tree_util.tree_map(jnp.asarray, primals)
series = jax.tree_util.tree_map(jnp.asarray, series)
primals = jax.tree.map(jnp.asarray, primals)
series = jax.tree.map(jnp.asarray, series)
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)

View File

@ -289,8 +289,8 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def testWhileTypeErrors(self):
"""Test typing error messages for while."""
tuple_treedef = tree_util.tree_structure((1., 1.))
leaf_treedef = tree_util.tree_structure(0.)
tuple_treedef = jax.tree.structure((1., 1.))
leaf_treedef = jax.tree.structure(0.)
with self.assertRaisesRegex(TypeError,
re.escape(f"cond_fun must return a boolean scalar, but got pytree {tuple_treedef}.")):
lax.while_loop(lambda c: (1., 1.), lambda c: c, 0.)
@ -970,7 +970,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.)
with self.assertRaisesRegex(TypeError,
re.escape("true_fun and false_fun output must have same type structure, "
f"got {tree_util.tree_structure(2.)} and {tree_util.tree_structure((3., 3.))}.")):
f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")):
lax.cond(True, lambda top: 2., lambda fop: (3., 3.), 1.)
with self.assertRaisesRegex(
TypeError,
@ -998,7 +998,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.switch(0, [], 1.)
with self.assertRaisesRegex(TypeError,
re.escape("branch 0 and 1 outputs must have same type structure, "
f"got {tree_util.tree_structure(2.)} and {tree_util.tree_structure((3., 3.))}.")):
f"got {jax.tree.structure(2.)} and {jax.tree.structure((3., 3.))}.")):
lax.switch(1, [lambda _: 2., lambda _: (3., 3.)], 1.)
with self.assertRaisesRegex(
TypeError,

View File

@ -42,7 +42,6 @@ import jax.ops
from jax import lax
from jax import numpy as jnp
from jax.sharding import SingleDeviceSharding
from jax import tree_util
from jax.test_util import check_grads
from jax._src import array
@ -971,7 +970,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
)
def testPad(self, shape, dtype, mode, pad_width, constant_values):
if np.issubdtype(dtype, np.unsignedinteger):
constant_values = tree_util.tree_map(abs, constant_values)
constant_values = jax.tree.map(abs, constant_values)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if constant_values is None:

View File

@ -29,7 +29,6 @@ import jax.dtypes
from jax import numpy as jnp
from jax import lax
from jax import scipy as jsp
from jax.tree_util import tree_map
from jax._src.scipy import special as lsp_special_internal
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
@ -125,7 +124,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
return_sign=return_sign, b=scale_array)
if dtype == np.int32:
res = tree_map(lambda x: x.astype('float32'), res)
res = jax.tree.map(lambda x: x.astype('float32'), res)
return res
def lax_fun(array_to_reduce, scale_array):
@ -138,7 +137,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
return_sign=return_sign)
if dtype == np.int32:
res = tree_map(lambda x: x.astype('float32'), res)
res = jax.tree.map(lambda x: x.astype('float32'), res)
return res
def lax_fun(array_to_reduce):

View File

@ -31,7 +31,6 @@ from jax._src import core
from jax import lax
import jax.numpy as jnp
from jax.test_util import check_grads
from jax import tree_util
import jax.util
from jax.interpreters import batching
@ -2586,7 +2585,7 @@ class LaxTest(jtu.JaxTestCase):
operands = {'x': [np.ones(5), np.arange(5)]}
init_values = {'x': [0., 0]}
result = lax.reduce(operands, init_values,
lambda x, y: tree_util.tree_map(lax.add, x, y),
lambda x, y: jax.tree.map(lax.add, x, y),
[0])
self.assertDictEqual(result, {'x': [5., 10]})

View File

@ -21,7 +21,6 @@ import jax
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.tree_util import tree_map
import scipy.integrate as osp_integrate
@ -63,7 +62,7 @@ class ODETest(jtu.JaxTestCase):
def test_pytree_state(self):
"""Test calling odeint with y(t) values that are pytrees."""
def dynamics(y, _t):
return tree_map(jnp.negative, y)
return jax.tree.map(jnp.negative, y)
y0 = (np.array(-0.1), np.array([[[0.1]]]))
ts = np.linspace(0., 1., 11)

View File

@ -19,10 +19,10 @@ import functools
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from jax import jit, grad, jacfwd, jacrev
from jax import tree_util
from jax import lax
from jax.example_libraries import optimizers
@ -41,8 +41,8 @@ class OptimizerTests(jtu.JaxTestCase):
opt_state = init_fun(x0)
self.assertAllClose(x0, get_params(opt_state))
opt_state2 = update_fun(0, grad(loss)(x0), opt_state) # doesn't crash
self.assertEqual(tree_util.tree_structure(opt_state),
tree_util.tree_structure(opt_state2))
self.assertEqual(jax.tree.structure(opt_state),
jax.tree.structure(opt_state2))
@jtu.skip_on_devices('gpu')
def _CheckRun(self, optimizer, loss, x0, num_steps, *args, **kwargs):

View File

@ -24,7 +24,6 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import random
from jax import tree_util
from jax._src import test_util as jtu
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
@ -479,7 +478,7 @@ class SplashAttentionTest(AttentionTest):
custom_type="vanilla",
attn_logits_soft_cap=attn_logits_soft_cap)
o_ref, attn_vjp_ref = jax.vjp(attn_ref, q, k, v, segment_ids)
q32, k32, v32 = tree_util.tree_map(lambda x: x.astype(jnp.float32),
q32, k32, v32 = jax.tree.map(lambda x: x.astype(jnp.float32),
(q, k, v))
o_custom = attn_custom(q32, k32, v32, segment_ids)
_, attn_vjp = jax.vjp(attn_custom, q32, k32, v32, segment_ids)
@ -582,7 +581,7 @@ class SplashAttentionTest(AttentionTest):
attn_logits_soft_cap=attn_logits_soft_cap,
)
o, attn_vjp = jax.vjp(attn, q, k, v, segment_ids)
q32, k32, v32 = tree_util.tree_map(
q32, k32, v32 = jax.tree.map(
lambda x: x.astype(jnp.float32), (q, k, v)
)
o_ref, (logsumexp,) = attn_ref(

View File

@ -1000,7 +1000,7 @@ class PJitTest(jtu.BufferDonationTestCase):
for obj in [lowered, compiled]:
self.assertFalse(obj._no_kwargs)
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0, 0), {}))[1])
self.assertEqual(obj.in_tree, jax.tree.flatten(((0, 0), {}))[1])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testLowerCompileWithKwargs(self):
@ -1799,7 +1799,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def f(tree):
return tree
out_tree = f((a1 @ a1.T, (a2, (a3 * 2, a4))))
(out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
(out1, out2, out3, out4), _ = jax.tree.flatten(out_tree)
self.assertIsInstance(out1, array.ArrayImpl)
self.assertEqual(out1.shape, (8, 8))
@ -4281,7 +4281,7 @@ class UtilTest(jtu.JaxTestCase):
("mix_4", (UNSPECIFIED, P('x'), UNSPECIFIED), ValueError),
)
def test_all_or_non_unspecified(self, axis_resources, error=None):
entries, _ = jax.tree_util.tree_flatten(axis_resources, is_leaf=lambda x: x is None)
entries, _ = jax.tree.flatten(axis_resources, is_leaf=lambda x: x is None)
if error is not None:
with self.assertRaises(error):
sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources')

View File

@ -36,7 +36,6 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax import lax
from jax import random
from jax import tree_util
from jax.ad_checkpoint import checkpoint as new_checkpoint
import jax.numpy as jnp
from jax._src import api as src_api
@ -205,7 +204,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# It's a pair of: (positional args, as a tuple of their structures, kwargs).
for obj in [lowered, compiled]:
self.assertFalse(obj._no_kwargs)
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
self.assertEqual(obj.in_tree, jax.tree.flatten(((0,), {}))[1])
self.assertEqual(obj.in_avals, ((core.ShapedArray(x.shape, x.dtype),), {}))
def testLowerCompileInTreeMismatch(self):
@ -524,7 +523,7 @@ class PythonPmapTest(jtu.JaxTestCase):
n = lax.psum(1, axis_name)
return lax.ppermute(x, axis_name, [(i, (i + 1) % n) for i in range(n)])
tree_f = lambda f: partial(tree_util.tree_map, f)
tree_f = lambda f: partial(jax.tree.map, f)
jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
np_transpose = tree_f(np.transpose)
@ -535,7 +534,7 @@ class PythonPmapTest(jtu.JaxTestCase):
'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]),
'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])}
assert_allclose = partial(tree_util.tree_map,
assert_allclose = partial(jax.tree.map,
partial(self.assertAllClose, check_dtypes=False))
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
@ -550,10 +549,10 @@ class PythonPmapTest(jtu.JaxTestCase):
'b': np.arange(2 * n * n, 3 * n * n, dtype=np.int32).reshape([n, n]),
'c': np.arange(4 * n * n, 5 * n * n, dtype=np.float32).reshape([n, n]),
'd': np.arange(6 * n * n, 7 * n * n, dtype=np.int32).reshape([n, n])}
tree_f = lambda f: partial(tree_util.tree_map, f)
tree_f = lambda f: partial(jax.tree.map, f)
jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
assert_allclose = partial(tree_util.tree_map,
assert_allclose = partial(jax.tree.map,
partial(self.assertAllClose, check_dtypes=False))
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
@ -1488,7 +1487,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@vmap
def s(keys):
keys = tree_util.tree_map(
keys = jax.tree.map(
lambda x: jnp.broadcast_to(x, (N_DEVICES,) + x.shape),
keys)
return g(keys)
@ -2677,7 +2676,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
return {'a': x}
device_count = jax.device_count()
x = jnp.arange(device_count)
tree_util.tree_map(self.assertAllClose, f(x), {'a': x})
jax.tree.map(self.assertAllClose, f(x), {'a': x})
@jtu.sample_product(
in_axes=all_bdims((3, 4), (3, 1), (1, 4), pmap=True),

View File

@ -31,7 +31,6 @@ import jax
from jax import lax
from jax import numpy as jnp
from jax import random
from jax import tree_util
from jax._src import config
from jax._src import core
from jax._src import deprecations
@ -1114,7 +1113,7 @@ class KeyArrayTest(jtu.JaxTestCase):
def f(_, state):
return state
def _f_fwd(_, state):
return tree_util.tree_map(lambda x: x.value, state), None
return jax.tree.map(lambda x: x.value, state), None
def _f_bwd(_, state_bar):
self.assertTrue(state_bar[1].dtype == dtypes.float0)
self.assertIsInstance(state_bar[1], jax.custom_derivatives.SymbolicZero)
@ -1173,12 +1172,12 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
like = lambda keys: jnp.ones(keys.shape)
out_key = func(*args)
self.assertIsInstance(out_key, prng_internal.PRNGKeyArray)
out_like_key = func(*tree_util.tree_map(like, args))
out_like_key = func(*jax.tree.map(like, args))
self.assertIsInstance(out_like_key, jax.Array)
self.assertEqual(out_key.shape, out_like_key.shape)
def check_against_reference(self, key_func, arr_func, *key_args):
out_arr = arr_func(*tree_util.tree_map(lambda x: random.key_data(x),
out_arr = arr_func(*jax.tree.map(lambda x: random.key_data(x),
key_args))
self.assertIsInstance(out_arr, jax.Array)

View File

@ -23,7 +23,7 @@ import scipy.stats as osp_stats
import scipy.version
import jax
from jax._src import dtypes, test_util as jtu, tree_util
from jax._src import dtypes, test_util as jtu
from jax.scipy import stats as lsp_stats
from jax.scipy.special import expit
@ -1517,9 +1517,9 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
dataset = rng((3, 15), dtype)
x = rng((3, 12), dtype)
kde = lsp_stats.gaussian_kde(dataset)
leaves, treedef = tree_util.tree_flatten(kde)
kde2 = tree_util.tree_unflatten(treedef, leaves)
tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
leaves, treedef = jax.tree.flatten(kde)
kde2 = jax.tree.unflatten(treedef, leaves)
jax.tree.map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
@jtu.sample_product(

View File

@ -46,7 +46,6 @@ from jax import random
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import tree_util
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
@ -1270,7 +1269,7 @@ class PolyHarness(Harness):
f_jax = self.dyn_fun
args = self.dyn_args_maker(tst.rng())
args = tree_util.tree_map(jnp.array, args)
args = jax.tree.map(jnp.array, args)
args_specs = export.symbolic_args_specs(args, self.polymorphic_shapes,
symbolic_constraints=self.symbolic_constraints)

View File

@ -1613,7 +1613,7 @@ def shmap_reference(
args_shards = [x[indexer(idx)] for x, indexer in zip(args, getters)]
assert all(x.shape == r.shape for x, r in zip(args_shards, body_in_types))
out_shards = f(*args_shards)
assert jax.tree_util.tree_all(jax.tree.map(lambda y, r: y.shape == r.shape,
assert jax.tree.all(jax.tree.map(lambda y, r: y.shape == r.shape,
out_shards, body_out_types))
outs = jax.tree.map(lambda y, out, indexer: out.at[indexer(idx)].set(y),
out_shards, outs, putters)
@ -1863,12 +1863,12 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
args_slice = args_slicer(args, bdims)
expected_slices = [f(*args_slice(i)) for i in range(5)]
treedef = tree_util.tree_structure(ans)
treedef = jax.tree.structure(ans)
if tree_util.treedef_is_strict_leaf(treedef):
expected = jnp.stack(expected_slices)
else:
slices = map(jnp.stack, zip(*expected_slices))
expected = tree_util.tree_unflatten(treedef, slices)
expected = jax.tree.unflatten(treedef, slices)
tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
@ -1902,12 +1902,12 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
args_slice = args_slicer((xs, *closed_over_args), (0, *closed_over_bdims))
expected_slices = [f(*args_slice(i)) for i in range(5)]
treedef = tree_util.tree_structure(ans)
treedef = jax.tree.structure(ans)
if tree_util.treedef_is_strict_leaf(treedef):
expected = jnp.stack(expected_slices)
else:
slices = map(jnp.stack, zip(*expected_slices))
expected = tree_util.tree_unflatten(treedef, slices)
expected = jax.tree.unflatten(treedef, slices)
tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)

View File

@ -36,7 +36,6 @@ from jax.experimental.sparse import _lowerings
from jax._src import xla_bridge
from jax._src.lib import gpu_sparse
from jax import jit
from jax import tree_util
from jax import vmap
from jax._src import test_util as jtu
from jax.interpreters import mlir
@ -735,9 +734,9 @@ class SparseObjectTest(sptu.SparseTestCase):
sparse_format = cls.__name__.lower()
M = sparse.empty((2, 4), sparse_format=sparse_format)
self.assertIsInstance(M, cls)
buffers, tree = tree_util.tree_flatten(M)
buffers, tree = jax.tree.flatten(M)
self.assertTrue(all(isinstance(buffer, jax.Array) for buffer in buffers))
M_out = tree_util.tree_unflatten(tree, buffers)
M_out = jax.tree.unflatten(tree, buffers)
self.assertEqual(M.dtype, M_out.dtype)
self.assertEqual(M.shape, M_out.shape)
self.assertEqual(M.nse, M_out.nse)
@ -877,7 +876,7 @@ class SparseObjectTest(sptu.SparseTestCase):
def test_todense_ad(self, Obj, shape=(3,), dtype=np.float32):
M_dense = jnp.array([1., 2., 3.])
M = M_dense if Obj is jnp.array else Obj.fromdense(M_dense)
bufs, tree = tree_util.tree_flatten(M)
bufs, tree = jax.tree.flatten(M)
jac = jnp.eye(M.shape[0], dtype=M.dtype)
jac1 = jax.jacfwd(lambda *bufs: sparse.todense_p.bind(*bufs, tree=tree))(*bufs)
jac2 = jax.jacrev(lambda *bufs: sparse.todense_p.bind(*bufs, tree=tree))(*bufs)