Enable JAX memory tests for GPUs and CPUs

PjRt GPU and CPU has recently gotten memory space support with just one memory space per device, so enabling relevant JAX memory tests. Most tests cannot be enabled yet because they rely on `unpinned_host`, so only enabling `ShardingMemoriesTest` for now.

PiperOrigin-RevId: 633335638
This commit is contained in:
Junwhan Ahn 2024-05-13 14:36:43 -07:00 committed by jax authors
parent 72a81e58e6
commit cd6e012326

View File

@ -24,6 +24,7 @@ from jax import lax
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src import config
from jax._src.lib import xla_extension_version
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
@ -64,7 +65,7 @@ def _create_inputs(shape, pspec, mem_kind=None):
class ShardingMemoriesTest(jtu.JaxTestCase):
def setUp(self):
if not jtu.test_device_matches(["tpu"]):
if xla_extension_version < 265 and not jtu.test_device_matches(["tpu"]):
self.skipTest("Memories do not work on CPU and GPU backends yet.")
# TODO(b/311021572)
if jtu.is_cloud_tpu():
@ -72,6 +73,10 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
super().setUp()
self.orig_memories_flag = config.enable_memories.value
jax.config.update('jax_enable_memories', True)
if jtu.test_device_matches(["cpu"]):
self._default_memory_kind = "unpinned_host"
else:
self._default_memory_kind = "device"
def tearDown(self):
jax.config.update('jax_enable_memories', self.orig_memories_flag)
@ -87,17 +92,17 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
if name == "named_sharding":
mesh = jtu.create_global_mesh((1,), "x")
ns = NamedSharding(mesh, P("x"))
self.assertEqual(ns.memory_kind, "device")
self.assertEqual(ns.memory_kind, self._default_memory_kind)
elif name == "positional_sharding":
ps = PositionalSharding(jax.devices())
self.assertEqual(ps.memory_kind, "device")
self.assertEqual(ps.memory_kind, self._default_memory_kind)
elif name == "single_device_sharding":
ss = SingleDeviceSharding(jax.devices()[0])
self.assertEqual(ss.memory_kind, "device")
self.assertEqual(ss.memory_kind, self._default_memory_kind)
else:
assert name == "gspmd_sharding"
gs = GSPMDSharding.get_replicated(jax.devices())
self.assertEqual(gs.memory_kind, "device")
self.assertEqual(gs.memory_kind, self._default_memory_kind)
@parameterized.named_parameters(
("named_sharding", "named_sharding"),
@ -108,26 +113,26 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
def test_wrong_memory_kind(self, name):
if name == "named_sharding":
with self.assertRaisesRegex(
ValueError, "Could not find memory addressable by device TPU.*"
ValueError, "Could not find memory addressable by device.*"
):
mesh = jtu.create_global_mesh((8,), ("x",))
NamedSharding(mesh, P("x"), memory_kind="hbm")
elif name == "positional_sharding":
with self.assertRaisesRegex(
ValueError, "Could not find memory addressable by device TPU.*"
ValueError, "Could not find memory addressable by device.*"
):
PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
elif name == "single_device_sharding":
with self.assertRaisesRegex(
ValueError,
"Could not find memory addressable by device TPU.*Device TPU.*"
"Could not find memory addressable by device.*Device.*"
" can address the following memory kinds.*",
):
SingleDeviceSharding(jax.devices()[0], memory_kind="host")
else:
assert name == "gspmd_sharding"
with self.assertRaisesRegex(
ValueError, "Could not find memory addressable by device TPU.*"
ValueError, "Could not find memory addressable by device.*"
):
GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")
@ -138,11 +143,13 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
("gspmd_sharding", "gspmd_sharding"),
)
def test_correct_tpu_memory_kind(self, name):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("TPU memory kind test.")
if name == "named_sharding":
mesh = jtu.create_global_mesh((8,), ("x",))
NamedSharding(mesh, P("x"), memory_kind="device")
NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
elif name == "positional_sharding":
PositionalSharding(jax.devices(), memory_kind="device")
PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
elif name == "single_device_sharding":
SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
else:
@ -159,19 +166,19 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
if name == "named_sharding":
mesh = jtu.create_global_mesh((8,), ("x",))
s1 = NamedSharding(mesh, P("x"))
s2 = NamedSharding(mesh, P("x"), memory_kind="device")
s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
self.assertEqual(s1, s2)
elif name == "positional_sharding":
s1 = PositionalSharding(jax.devices())
s2 = PositionalSharding(jax.devices(), memory_kind="device")
s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
self.assertEqual(s1, s2)
elif name == "single_device_sharding":
s1 = SingleDeviceSharding(jax.devices()[0])
s2 = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind)
self.assertEqual(s1, s2)
elif name == "gspmd_sharding":
s1 = GSPMDSharding.get_replicated(jax.devices())
s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind="device")
s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind=self._default_memory_kind)
self.assertEqual(s1, s2)
def test_sharding_equivalent(self):
@ -181,11 +188,11 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
gs1 = GSPMDSharding(
tuple(mesh.devices.flat),
ns1._to_xla_hlo_sharding(ndim),
memory_kind="device",
memory_kind=self._default_memory_kind,
)
self.assertTrue(ns1.is_equivalent_to(gs1, ndim))
ns2 = NamedSharding(mesh, P("x"), memory_kind="device")
ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
gs2 = GSPMDSharding(
tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim)
)
@ -193,7 +200,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
def test_default_memory_kind(self):
dev = jax.devices()[0]
self.assertEqual(dev.default_memory().kind, "device")
self.assertEqual(dev.default_memory().kind, self._default_memory_kind)
class MemoriesComputationTest(jtu.BufferDonationTestCase):