mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Avoid top-level aliases of jax.tree_util.*
This commit is contained in:
parent
2314951669
commit
a10f0377db
@ -319,7 +319,7 @@ def pmap_simple_8_devices_100_args(state):
|
||||
|
||||
while state:
|
||||
out = f(*args)
|
||||
jax.tree_map(lambda x: x.block_until_ready(), out)
|
||||
jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
|
||||
|
||||
|
||||
def _run_sda_index_bench(state, num_devices):
|
||||
|
@ -163,7 +163,7 @@ async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
|
||||
|
||||
def run_serialization(gdas, tensorstore_specs):
|
||||
async def _run_serializer():
|
||||
future_writer = jax.tree_map(async_serialize, gdas, tensorstore_specs)
|
||||
future_writer = jax.tree_util.tree_map(async_serialize, gdas, tensorstore_specs)
|
||||
return await asyncio.gather(*future_writer)
|
||||
asyncio.run(_run_serializer())
|
||||
|
||||
@ -235,7 +235,7 @@ def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
|
||||
# Object should be created once per process.
|
||||
byte_limiter = _LimitInFlightBytes(concurrent_bytes)
|
||||
|
||||
future_gdas = jax.tree_map(
|
||||
future_gdas = jax.tree_util.tree_map(
|
||||
partial(async_deserialize, byte_limiter=byte_limiter),
|
||||
global_meshes, mesh_axes, tensorstore_specs,
|
||||
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
|
||||
@ -427,13 +427,13 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
|
||||
commit_futures = [[] for _ in range(len(tensorstore_specs))]
|
||||
|
||||
async def _run_serializer():
|
||||
future_writer = jax.tree_map(async_serialize, gdas, tensorstore_specs,
|
||||
future_writer = jax.tree_util.tree_map(async_serialize, gdas, tensorstore_specs,
|
||||
commit_futures)
|
||||
return await asyncio.gather(*future_writer)
|
||||
|
||||
asyncio.run(_run_serializer())
|
||||
|
||||
self._add_futures(jax.tree_flatten(commit_futures)[0])
|
||||
self._add_futures(jax.tree_util.tree_flatten(commit_futures)[0])
|
||||
|
||||
# Used in wait_until_finished to check on process != 0, if the checkpoint
|
||||
# has finished writing.
|
||||
|
@ -60,7 +60,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
|
||||
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1, gda2, gda3], tspecs)
|
||||
|
||||
@ -103,7 +103,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
@ -138,7 +138,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
|
||||
|
||||
serialization.run_serialization([gda1], tspecs)
|
||||
|
||||
|
@ -690,9 +690,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertNotEqual(type(m.a), list)
|
||||
self.assertNotEqual(type(m.b), tuple)
|
||||
self.assertNotEqual(type(m.c), dict)
|
||||
self.assertLen(jax.tree_leaves(m.a), 2)
|
||||
self.assertLen(jax.tree_leaves(m.b), 2)
|
||||
self.assertLen(jax.tree_leaves(m.c), 2)
|
||||
self.assertLen(jax.tree_util.tree_leaves(m.a), 2)
|
||||
self.assertLen(jax.tree_util.tree_leaves(m.b), 2)
|
||||
self.assertLen(jax.tree_util.tree_leaves(m.c), 2)
|
||||
|
||||
def test_issue_10586(self):
|
||||
|
||||
|
@ -67,9 +67,9 @@ def broadcast_one_to_all(in_tree: PyTreeDef,
|
||||
def post_pmap(x):
|
||||
return jax.device_get(x)[0]
|
||||
|
||||
in_tree = jax.tree_map(pre_pmap, in_tree)
|
||||
in_tree = jax.tree_util.tree_map(pre_pmap, in_tree)
|
||||
in_tree = jax.device_get(_psum(in_tree))
|
||||
return jax.tree_map(post_pmap, in_tree)
|
||||
return jax.tree_util.tree_map(post_pmap, in_tree)
|
||||
|
||||
|
||||
def sync_global_devices(name: str):
|
||||
@ -121,14 +121,14 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
|
||||
return out.local_data(0).to_py()
|
||||
|
||||
with jax._src.config.parallel_functions_output_gda(True):
|
||||
return jax.tree_map(_pjit, in_tree)
|
||||
return jax.tree_util.tree_map(_pjit, in_tree)
|
||||
|
||||
|
||||
def assert_equal(in_tree, fail_message: str = ''):
|
||||
"""Verifies that all the hosts have the same tree of values."""
|
||||
expected = broadcast_one_to_all(in_tree)
|
||||
if not jax.tree_util.tree_all(
|
||||
jax.tree_map(lambda *x: np.all(np.equal(*x)), in_tree, expected)):
|
||||
jax.tree_util.tree_map(lambda *x: np.all(np.equal(*x)), in_tree, expected)):
|
||||
raise AssertionError(
|
||||
f'{fail_message} Expected: {expected}; got: {in_tree}.')
|
||||
|
||||
|
@ -931,7 +931,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(
|
||||
obj.in_avals,
|
||||
((jax.ShapedArray([], expected_dtype, weak_type=True),), {}))
|
||||
self.assertEqual(obj.in_tree, jax.tree_flatten(((0,), {}))[1])
|
||||
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
|
||||
|
||||
def test_jit_lower_duck_typing(self):
|
||||
f_jit = self.jit(lambda x: 2 * x)
|
||||
|
@ -64,7 +64,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
|
||||
device = jax.local_devices()[0]
|
||||
# We must transfer the flattened data, as a tuple!!!
|
||||
flat_to_infeed, _ = jax.tree_flatten(to_infeed)
|
||||
flat_to_infeed, _ = jax.tree_util.tree_flatten(to_infeed)
|
||||
device.transfer_to_infeed(tuple(flat_to_infeed))
|
||||
self.assertAllClose(f(x), to_infeed)
|
||||
|
||||
|
@ -58,8 +58,8 @@ class JetTest(jtu.JaxTestCase):
|
||||
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
|
||||
check_dtypes=True):
|
||||
# Convert to jax arrays to ensure dtype canonicalization.
|
||||
primals = jax.tree_map(jnp.asarray, primals)
|
||||
series = jax.tree_map(jnp.asarray, series)
|
||||
primals = jax.tree_util.tree_map(jnp.asarray, primals)
|
||||
series = jax.tree_util.tree_map(jnp.asarray, series)
|
||||
|
||||
y, terms = jet(fun, primals, series)
|
||||
expected_y, expected_terms = jvp_taylor(fun, primals, series)
|
||||
@ -73,8 +73,8 @@ class JetTest(jtu.JaxTestCase):
|
||||
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5,
|
||||
check_dtypes=True):
|
||||
# Convert to jax arrays to ensure dtype canonicalization.
|
||||
primals = jax.tree_map(jnp.asarray, primals)
|
||||
series = jax.tree_map(jnp.asarray, series)
|
||||
primals = jax.tree_util.tree_map(jnp.asarray, primals)
|
||||
series = jax.tree_util.tree_map(jnp.asarray, series)
|
||||
|
||||
y, terms = jet(fun, primals, series)
|
||||
expected_y, expected_terms = jvp_taylor(fun, primals, series)
|
||||
|
@ -724,7 +724,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
for obj in [lowered, compiled]:
|
||||
self.assertTrue(obj._no_kwargs, True)
|
||||
self.assertEqual(obj.in_tree, jax.tree_flatten(((0, 0), {}))[1])
|
||||
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0, 0), {}))[1])
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def testLowerCompileWithKwargs(self):
|
||||
@ -1455,7 +1455,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
def f(tree):
|
||||
return tree
|
||||
out_tree = f((a1, (a2, (a3, a4))))
|
||||
(out1, out2, out3, out4), _ = jax.tree_flatten(out_tree)
|
||||
(out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
|
||||
|
||||
self.assertIsInstance(out1, array.Array)
|
||||
self.assertEqual(out1.shape, (8, 2))
|
||||
@ -1822,7 +1822,7 @@ class UtilTest(jtu.JaxTestCase):
|
||||
("mix_4", (pjit_lib._UNSPECIFIED, P('x'), pjit_lib._UNSPECIFIED), ValueError),
|
||||
)
|
||||
def test_all_or_non_unspecified(self, axis_resources, error=None):
|
||||
entries, _ = jax.tree_flatten(axis_resources, is_leaf=lambda x: x is None)
|
||||
entries, _ = jax.tree_util.tree_flatten(axis_resources, is_leaf=lambda x: x is None)
|
||||
if error is not None:
|
||||
with self.assertRaises(error):
|
||||
pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources')
|
||||
|
@ -195,7 +195,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
# It's a pair of: (positional args, as a tuple of their structures, kwargs).
|
||||
for obj in [lowered, compiled]:
|
||||
self.assertFalse(obj._no_kwargs)
|
||||
self.assertEqual(obj.in_tree, jax.tree_flatten(((0,), {}))[1])
|
||||
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
|
||||
self.assertEqual(obj.in_avals, ((jax.ShapedArray(x.shape, x.dtype),), {}))
|
||||
|
||||
def testLowerCompileInTreeMismatch(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user