From cd6e012326963ffca4c69552349ca9a6b08af800 Mon Sep 17 00:00:00 2001
From: Junwhan Ahn
Date: Mon, 13 May 2024 14:36:43 -0700
Subject: [PATCH] Enable JAX memory tests for GPUs and CPUs

PjRt GPU and CPU have recently gained memory space support, with just one
memory space per device, so this change enables the relevant JAX memory
tests. Most tests cannot be enabled yet because they rely on
`unpinned_host`, so only `ShardingMemoriesTest` is enabled for now.

PiperOrigin-RevId: 633335638
---
 tests/memories_test.py | 43 ++++++++++++++++++++++++------------------
 1 file changed, 25 insertions(+), 18 deletions(-)

diff --git a/tests/memories_test.py b/tests/memories_test.py
index f53967d37..5bc6e27db 100644
--- a/tests/memories_test.py
+++ b/tests/memories_test.py
@@ -24,6 +24,7 @@ from jax import lax
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src import config
+from jax._src.lib import xla_extension_version
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
@@ -64,7 +65,7 @@ def _create_inputs(shape, pspec, mem_kind=None):
 class ShardingMemoriesTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
+    if xla_extension_version < 265 and not jtu.test_device_matches(["tpu"]):
       self.skipTest("Memories do not work on CPU and GPU backends yet.")
     # TODO(b/311021572)
     if jtu.is_cloud_tpu():
@@ -72,6 +73,10 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     super().setUp()
     self.orig_memories_flag = config.enable_memories.value
     jax.config.update('jax_enable_memories', True)
+    if jtu.test_device_matches(["cpu"]):
+      self._default_memory_kind = "unpinned_host"
+    else:
+      self._default_memory_kind = "device"
 
   def tearDown(self):
     jax.config.update('jax_enable_memories', self.orig_memories_flag)
@@ -87,17 +92,17 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     if name == "named_sharding":
       mesh = jtu.create_global_mesh((1,), "x")
       ns = NamedSharding(mesh, P("x"))
-      self.assertEqual(ns.memory_kind, "device")
+      self.assertEqual(ns.memory_kind, self._default_memory_kind)
     elif name == "positional_sharding":
       ps = PositionalSharding(jax.devices())
-      self.assertEqual(ps.memory_kind, "device")
+      self.assertEqual(ps.memory_kind, self._default_memory_kind)
     elif name == "single_device_sharding":
       ss = SingleDeviceSharding(jax.devices()[0])
-      self.assertEqual(ss.memory_kind, "device")
+      self.assertEqual(ss.memory_kind, self._default_memory_kind)
     else:
       assert name == "gspmd_sharding"
       gs = GSPMDSharding.get_replicated(jax.devices())
-      self.assertEqual(gs.memory_kind, "device")
+      self.assertEqual(gs.memory_kind, self._default_memory_kind)
 
   @parameterized.named_parameters(
       ("named_sharding", "named_sharding"),
@@ -108,26 +113,26 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
   def test_wrong_memory_kind(self, name):
     if name == "named_sharding":
       with self.assertRaisesRegex(
-          ValueError, "Could not find memory addressable by device TPU.*"
+          ValueError, "Could not find memory addressable by device.*"
      ):
        mesh = jtu.create_global_mesh((8,), ("x",))
        NamedSharding(mesh, P("x"), memory_kind="hbm")
     elif name == "positional_sharding":
       with self.assertRaisesRegex(
-          ValueError, "Could not find memory addressable by device TPU.*"
+          ValueError, "Could not find memory addressable by device.*"
      ):
        PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
     elif name == "single_device_sharding":
       with self.assertRaisesRegex(
           ValueError,
-          "Could not find memory addressable by device TPU.*Device TPU.*"
+          "Could not find memory addressable by device.*Device.*"
           " can address the following memory kinds.*",
      ):
        SingleDeviceSharding(jax.devices()[0], memory_kind="host")
     else:
       assert name == "gspmd_sharding"
       with self.assertRaisesRegex(
-          ValueError, "Could not find memory addressable by device TPU.*"
+          ValueError, "Could not find memory addressable by device.*"
      ):
        GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")
 
@@ -138,11 +143,13 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
       ("gspmd_sharding", "gspmd_sharding"),
   )
   def test_correct_tpu_memory_kind(self, name):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest("TPU memory kind test.")
     if name == "named_sharding":
       mesh = jtu.create_global_mesh((8,), ("x",))
-      NamedSharding(mesh, P("x"), memory_kind="device")
+      NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
     elif name == "positional_sharding":
-      PositionalSharding(jax.devices(), memory_kind="device")
+      PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
     elif name == "single_device_sharding":
       SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
     else:
@@ -159,19 +166,19 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     if name == "named_sharding":
       mesh = jtu.create_global_mesh((8,), ("x",))
       s1 = NamedSharding(mesh, P("x"))
-      s2 = NamedSharding(mesh, P("x"), memory_kind="device")
+      s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
       self.assertEqual(s1, s2)
     elif name == "positional_sharding":
       s1 = PositionalSharding(jax.devices())
-      s2 = PositionalSharding(jax.devices(), memory_kind="device")
+      s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
       self.assertEqual(s1, s2)
     elif name == "single_device_sharding":
       s1 = SingleDeviceSharding(jax.devices()[0])
-      s2 = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
+      s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind)
       self.assertEqual(s1, s2)
     elif name == "gspmd_sharding":
       s1 = GSPMDSharding.get_replicated(jax.devices())
-      s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind="device")
+      s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind=self._default_memory_kind)
       self.assertEqual(s1, s2)
 
   def test_sharding_equivalent(self):
@@ -181,11 +188,11 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     gs1 = GSPMDSharding(
         tuple(mesh.devices.flat),
         ns1._to_xla_hlo_sharding(ndim),
-        memory_kind="device",
+        memory_kind=self._default_memory_kind,
     )
     self.assertTrue(ns1.is_equivalent_to(gs1, ndim))
 
-    ns2 = NamedSharding(mesh, P("x"), memory_kind="device")
+    ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
     gs2 = GSPMDSharding(
         tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim)
     )
@@ -193,7 +200,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
 
   def test_default_memory_kind(self):
     dev = jax.devices()[0]
-    self.assertEqual(dev.default_memory().kind, "device")
+    self.assertEqual(dev.default_memory().kind, self._default_memory_kind)
 
 
 class MemoriesComputationTest(jtu.BufferDonationTestCase):
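
For context, here is a minimal sketch (not part of the patch) of the behavior the
enabled tests exercise on CPU and GPU backends: each device now reports a default
memory space, and a sharding built without an explicit `memory_kind` compares equal
to one built with the device's default kind. It uses only public JAX APIs that also
appear in the test (`default_memory()`, `NamedSharding`); `addressable_memories()`
and the exact kind strings ("device" on TPU/GPU, "unpinned_host" on CPU in this
revision) are assumptions that may vary across jax/jaxlib versions.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

dev = jax.devices()[0]
# Default memory space is now reported on CPU/GPU as well as TPU.
print(dev.platform, dev.default_memory().kind)
# Memory kinds this device can address; per the commit description, the CPU and
# GPU backends currently expose just one memory space per device.
print([m.kind for m in dev.addressable_memories()])

# A sharding with no explicit memory_kind falls back to the device's default
# memory, so these two shardings compare equal (what ShardingMemoriesTest asserts).
mesh = Mesh(np.array(jax.devices()), ("x",))
implicit = NamedSharding(mesh, P("x"))
explicit = NamedSharding(mesh, P("x"), memory_kind=dev.default_memory().kind)
assert implicit == explicit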