diff --git a/tests/memories_test.py b/tests/memories_test.py
index f53967d37..5bc6e27db 100644
--- a/tests/memories_test.py
+++ b/tests/memories_test.py
@@ -24,6 +24,7 @@ from jax import lax
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src import config
+from jax._src.lib import xla_extension_version
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
@@ -64,7 +65,7 @@ def _create_inputs(shape, pspec, mem_kind=None):
 class ShardingMemoriesTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
+    if xla_extension_version < 265 and not jtu.test_device_matches(["tpu"]):
       self.skipTest("Memories do not work on CPU and GPU backends yet.")
     # TODO(b/311021572)
     if jtu.is_cloud_tpu():
@@ -72,6 +73,10 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     super().setUp()
     self.orig_memories_flag = config.enable_memories.value
     jax.config.update('jax_enable_memories', True)
+    if jtu.test_device_matches(["cpu"]):
+      self._default_memory_kind = "unpinned_host"
+    else:
+      self._default_memory_kind = "device"
 
   def tearDown(self):
     jax.config.update('jax_enable_memories', self.orig_memories_flag)
@@ -87,17 +92,17 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     if name == "named_sharding":
       mesh = jtu.create_global_mesh((1,), "x")
       ns = NamedSharding(mesh, P("x"))
-      self.assertEqual(ns.memory_kind, "device")
+      self.assertEqual(ns.memory_kind, self._default_memory_kind)
     elif name == "positional_sharding":
       ps = PositionalSharding(jax.devices())
-      self.assertEqual(ps.memory_kind, "device")
+      self.assertEqual(ps.memory_kind, self._default_memory_kind)
     elif name == "single_device_sharding":
       ss = SingleDeviceSharding(jax.devices()[0])
-      self.assertEqual(ss.memory_kind, "device")
+      self.assertEqual(ss.memory_kind, self._default_memory_kind)
     else:
       assert name == "gspmd_sharding"
       gs = GSPMDSharding.get_replicated(jax.devices())
-      self.assertEqual(gs.memory_kind, "device")
+      self.assertEqual(gs.memory_kind, self._default_memory_kind)
 
   @parameterized.named_parameters(
       ("named_sharding", "named_sharding"),
@@ -108,26 +113,26 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
   def test_wrong_memory_kind(self, name):
     if name == "named_sharding":
       with self.assertRaisesRegex(
-          ValueError, "Could not find memory addressable by device TPU.*"
+          ValueError, "Could not find memory addressable by device.*"
       ):
         mesh = jtu.create_global_mesh((8,), ("x",))
         NamedSharding(mesh, P("x"), memory_kind="hbm")
     elif name == "positional_sharding":
       with self.assertRaisesRegex(
-          ValueError, "Could not find memory addressable by device TPU.*"
+          ValueError, "Could not find memory addressable by device.*"
       ):
         PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
     elif name == "single_device_sharding":
       with self.assertRaisesRegex(
           ValueError,
-          "Could not find memory addressable by device TPU.*Device TPU.*"
+          "Could not find memory addressable by device.*Device.*"
           " can address the following memory kinds.*",
       ):
         SingleDeviceSharding(jax.devices()[0], memory_kind="host")
     else:
       assert name == "gspmd_sharding"
       with self.assertRaisesRegex(
-          ValueError, "Could not find memory addressable by device TPU.*"
+          ValueError, "Could not find memory addressable by device.*"
       ):
         GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")
 
@@ -138,11 +143,13 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
       ("gspmd_sharding", "gspmd_sharding"),
   )
   def test_correct_tpu_memory_kind(self, name):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest("TPU memory kind test.")
     if name == "named_sharding":
       mesh = jtu.create_global_mesh((8,), ("x",))
-      NamedSharding(mesh, P("x"), memory_kind="device")
+      NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
     elif name == "positional_sharding":
-      PositionalSharding(jax.devices(), memory_kind="device")
+      PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
     elif name == "single_device_sharding":
       SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
     else:
@@ -159,19 +166,19 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     if name == "named_sharding":
       mesh = jtu.create_global_mesh((8,), ("x",))
       s1 = NamedSharding(mesh, P("x"))
-      s2 = NamedSharding(mesh, P("x"), memory_kind="device")
+      s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
       self.assertEqual(s1, s2)
     elif name == "positional_sharding":
       s1 = PositionalSharding(jax.devices())
-      s2 = PositionalSharding(jax.devices(), memory_kind="device")
+      s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
       self.assertEqual(s1, s2)
     elif name == "single_device_sharding":
       s1 = SingleDeviceSharding(jax.devices()[0])
-      s2 = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
+      s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind)
       self.assertEqual(s1, s2)
     elif name == "gspmd_sharding":
       s1 = GSPMDSharding.get_replicated(jax.devices())
-      s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind="device")
+      s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind=self._default_memory_kind)
       self.assertEqual(s1, s2)
 
   def test_sharding_equivalent(self):
@@ -181,11 +188,11 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     gs1 = GSPMDSharding(
         tuple(mesh.devices.flat),
         ns1._to_xla_hlo_sharding(ndim),
-        memory_kind="device",
+        memory_kind=self._default_memory_kind,
     )
     self.assertTrue(ns1.is_equivalent_to(gs1, ndim))
 
-    ns2 = NamedSharding(mesh, P("x"), memory_kind="device")
+    ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
     gs2 = GSPMDSharding(
         tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim)
     )
@@ -193,7 +200,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
 
   def test_default_memory_kind(self):
     dev = jax.devices()[0]
-    self.assertEqual(dev.default_memory().kind, "device")
+    self.assertEqual(dev.default_memory().kind, self._default_memory_kind)
 
 
 class MemoriesComputationTest(jtu.BufferDonationTestCase):