Avoid top-level aliases of jax.tree_util.*

This commit is contained in:
Jake VanderPlas 2022-07-07 11:41:02 -07:00
parent 2314951669
commit a10f0377db
10 changed files with 25 additions and 25 deletions

View File

@ -319,7 +319,7 @@ def pmap_simple_8_devices_100_args(state):
while state:
out = f(*args)
jax.tree_map(lambda x: x.block_until_ready(), out)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
def _run_sda_index_bench(state, num_devices):

View File

@ -163,7 +163,7 @@ async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
def run_serialization(gdas, tensorstore_specs):
async def _run_serializer():
future_writer = jax.tree_map(async_serialize, gdas, tensorstore_specs)
future_writer = jax.tree_util.tree_map(async_serialize, gdas, tensorstore_specs)
return await asyncio.gather(*future_writer)
asyncio.run(_run_serializer())
@ -235,7 +235,7 @@ def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
# Object should be created once per process.
byte_limiter = _LimitInFlightBytes(concurrent_bytes)
future_gdas = jax.tree_map(
future_gdas = jax.tree_util.tree_map(
partial(async_deserialize, byte_limiter=byte_limiter),
global_meshes, mesh_axes, tensorstore_specs,
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
@ -427,13 +427,13 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
commit_futures = [[] for _ in range(len(tensorstore_specs))]
async def _run_serializer():
future_writer = jax.tree_map(async_serialize, gdas, tensorstore_specs,
future_writer = jax.tree_util.tree_map(async_serialize, gdas, tensorstore_specs,
commit_futures)
return await asyncio.gather(*future_writer)
asyncio.run(_run_serializer())
self._add_futures(jax.tree_flatten(commit_futures)[0])
self._add_futures(jax.tree_util.tree_flatten(commit_futures)[0])
# Used in wait_until_finished to check on process != 0, if the checkpoint
# has finished writing.

View File

@ -60,7 +60,7 @@ class CheckpointTest(jtu.JaxTestCase):
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization([gda1, gda2, gda3], tspecs)
@ -103,7 +103,7 @@ class CheckpointTest(jtu.JaxTestCase):
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
ckpt_paths = [str(ckpt_dir1)]
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization([gda1], tspecs)
@ -138,7 +138,7 @@ class CheckpointTest(jtu.JaxTestCase):
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
ckpt_paths = [str(ckpt_dir1)]
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
serialization.run_serialization([gda1], tspecs)

View File

@ -690,9 +690,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertNotEqual(type(m.a), list)
self.assertNotEqual(type(m.b), tuple)
self.assertNotEqual(type(m.c), dict)
self.assertLen(jax.tree_leaves(m.a), 2)
self.assertLen(jax.tree_leaves(m.b), 2)
self.assertLen(jax.tree_leaves(m.c), 2)
self.assertLen(jax.tree_util.tree_leaves(m.a), 2)
self.assertLen(jax.tree_util.tree_leaves(m.b), 2)
self.assertLen(jax.tree_util.tree_leaves(m.c), 2)
def test_issue_10586(self):

View File

@ -67,9 +67,9 @@ def broadcast_one_to_all(in_tree: PyTreeDef,
def post_pmap(x):
return jax.device_get(x)[0]
in_tree = jax.tree_map(pre_pmap, in_tree)
in_tree = jax.tree_util.tree_map(pre_pmap, in_tree)
in_tree = jax.device_get(_psum(in_tree))
return jax.tree_map(post_pmap, in_tree)
return jax.tree_util.tree_map(post_pmap, in_tree)
def sync_global_devices(name: str):
@ -121,14 +121,14 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
return out.local_data(0).to_py()
with jax._src.config.parallel_functions_output_gda(True):
return jax.tree_map(_pjit, in_tree)
return jax.tree_util.tree_map(_pjit, in_tree)
def assert_equal(in_tree, fail_message: str = ''):
"""Verifies that all the hosts have the same tree of values."""
expected = broadcast_one_to_all(in_tree)
if not jax.tree_util.tree_all(
jax.tree_map(lambda *x: np.all(np.equal(*x)), in_tree, expected)):
jax.tree_util.tree_map(lambda *x: np.all(np.equal(*x)), in_tree, expected)):
raise AssertionError(
f'{fail_message} Expected: {expected}; got: {in_tree}.')

View File

@ -931,7 +931,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertEqual(
obj.in_avals,
((jax.ShapedArray([], expected_dtype, weak_type=True),), {}))
self.assertEqual(obj.in_tree, jax.tree_flatten(((0,), {}))[1])
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
def test_jit_lower_duck_typing(self):
f_jit = self.jit(lambda x: 2 * x)

View File

@ -64,7 +64,7 @@ class InfeedTest(jtu.JaxTestCase):
device = jax.local_devices()[0]
# We must transfer the flattened data, as a tuple!!!
flat_to_infeed, _ = jax.tree_flatten(to_infeed)
flat_to_infeed, _ = jax.tree_util.tree_flatten(to_infeed)
device.transfer_to_infeed(tuple(flat_to_infeed))
self.assertAllClose(f(x), to_infeed)

View File

@ -58,8 +58,8 @@ class JetTest(jtu.JaxTestCase):
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
check_dtypes=True):
# Convert to jax arrays to ensure dtype canonicalization.
primals = jax.tree_map(jnp.asarray, primals)
series = jax.tree_map(jnp.asarray, series)
primals = jax.tree_util.tree_map(jnp.asarray, primals)
series = jax.tree_util.tree_map(jnp.asarray, series)
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)
@ -73,8 +73,8 @@ class JetTest(jtu.JaxTestCase):
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5,
check_dtypes=True):
# Convert to jax arrays to ensure dtype canonicalization.
primals = jax.tree_map(jnp.asarray, primals)
series = jax.tree_map(jnp.asarray, series)
primals = jax.tree_util.tree_map(jnp.asarray, primals)
series = jax.tree_util.tree_map(jnp.asarray, series)
y, terms = jet(fun, primals, series)
expected_y, expected_terms = jvp_taylor(fun, primals, series)

View File

@ -724,7 +724,7 @@ class PJitTest(jtu.BufferDonationTestCase):
for obj in [lowered, compiled]:
self.assertTrue(obj._no_kwargs, True)
self.assertEqual(obj.in_tree, jax.tree_flatten(((0, 0), {}))[1])
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0, 0), {}))[1])
@jtu.with_mesh([('x', 2), ('y', 2)])
def testLowerCompileWithKwargs(self):
@ -1455,7 +1455,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def f(tree):
return tree
out_tree = f((a1, (a2, (a3, a4))))
(out1, out2, out3, out4), _ = jax.tree_flatten(out_tree)
(out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
self.assertIsInstance(out1, array.Array)
self.assertEqual(out1.shape, (8, 2))
@ -1822,7 +1822,7 @@ class UtilTest(jtu.JaxTestCase):
("mix_4", (pjit_lib._UNSPECIFIED, P('x'), pjit_lib._UNSPECIFIED), ValueError),
)
def test_all_or_non_unspecified(self, axis_resources, error=None):
entries, _ = jax.tree_flatten(axis_resources, is_leaf=lambda x: x is None)
entries, _ = jax.tree_util.tree_flatten(axis_resources, is_leaf=lambda x: x is None)
if error is not None:
with self.assertRaises(error):
pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources')

View File

@ -195,7 +195,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# It's a pair of: (positional args, as a tuple of their structures, kwargs).
for obj in [lowered, compiled]:
self.assertFalse(obj._no_kwargs)
self.assertEqual(obj.in_tree, jax.tree_flatten(((0,), {}))[1])
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
self.assertEqual(obj.in_avals, ((jax.ShapedArray(x.shape, x.dtype),), {}))
def testLowerCompileInTreeMismatch(self):