mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Enable JAX memory tests for GPUs and CPUs
PjRt GPU and CPU have recently gained memory space support, with just one memory space per device, so this change enables the relevant JAX memory tests. Most tests cannot be enabled yet because they rely on `unpinned_host`, so only `ShardingMemoriesTest` is enabled for now. PiperOrigin-RevId: 633335638
This commit is contained in:
parent
72a81e58e6
commit
cd6e012326
@ -24,6 +24,7 @@ from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src import config
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
@ -64,7 +65,7 @@ def _create_inputs(shape, pspec, mem_kind=None):
|
||||
class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
if xla_extension_version < 265 and not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Memories do not work on CPU and GPU backends yet.")
|
||||
# TODO(b/311021572)
|
||||
if jtu.is_cloud_tpu():
|
||||
@ -72,6 +73,10 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
self.orig_memories_flag = config.enable_memories.value
|
||||
jax.config.update('jax_enable_memories', True)
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
self._default_memory_kind = "unpinned_host"
|
||||
else:
|
||||
self._default_memory_kind = "device"
|
||||
|
||||
def tearDown(self):
|
||||
jax.config.update('jax_enable_memories', self.orig_memories_flag)
|
||||
@ -87,17 +92,17 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
if name == "named_sharding":
|
||||
mesh = jtu.create_global_mesh((1,), "x")
|
||||
ns = NamedSharding(mesh, P("x"))
|
||||
self.assertEqual(ns.memory_kind, "device")
|
||||
self.assertEqual(ns.memory_kind, self._default_memory_kind)
|
||||
elif name == "positional_sharding":
|
||||
ps = PositionalSharding(jax.devices())
|
||||
self.assertEqual(ps.memory_kind, "device")
|
||||
self.assertEqual(ps.memory_kind, self._default_memory_kind)
|
||||
elif name == "single_device_sharding":
|
||||
ss = SingleDeviceSharding(jax.devices()[0])
|
||||
self.assertEqual(ss.memory_kind, "device")
|
||||
self.assertEqual(ss.memory_kind, self._default_memory_kind)
|
||||
else:
|
||||
assert name == "gspmd_sharding"
|
||||
gs = GSPMDSharding.get_replicated(jax.devices())
|
||||
self.assertEqual(gs.memory_kind, "device")
|
||||
self.assertEqual(gs.memory_kind, self._default_memory_kind)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("named_sharding", "named_sharding"),
|
||||
@ -108,26 +113,26 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
def test_wrong_memory_kind(self, name):
|
||||
if name == "named_sharding":
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Could not find memory addressable by device TPU.*"
|
||||
ValueError, "Could not find memory addressable by device.*"
|
||||
):
|
||||
mesh = jtu.create_global_mesh((8,), ("x",))
|
||||
NamedSharding(mesh, P("x"), memory_kind="hbm")
|
||||
elif name == "positional_sharding":
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Could not find memory addressable by device TPU.*"
|
||||
ValueError, "Could not find memory addressable by device.*"
|
||||
):
|
||||
PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
|
||||
elif name == "single_device_sharding":
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Could not find memory addressable by device TPU.*Device TPU.*"
|
||||
"Could not find memory addressable by device.*Device.*"
|
||||
" can address the following memory kinds.*",
|
||||
):
|
||||
SingleDeviceSharding(jax.devices()[0], memory_kind="host")
|
||||
else:
|
||||
assert name == "gspmd_sharding"
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Could not find memory addressable by device TPU.*"
|
||||
ValueError, "Could not find memory addressable by device.*"
|
||||
):
|
||||
GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")
|
||||
|
||||
@ -138,11 +143,13 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
("gspmd_sharding", "gspmd_sharding"),
|
||||
)
|
||||
def test_correct_tpu_memory_kind(self, name):
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("TPU memory kind test.")
|
||||
if name == "named_sharding":
|
||||
mesh = jtu.create_global_mesh((8,), ("x",))
|
||||
NamedSharding(mesh, P("x"), memory_kind="device")
|
||||
NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
|
||||
elif name == "positional_sharding":
|
||||
PositionalSharding(jax.devices(), memory_kind="device")
|
||||
PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
|
||||
elif name == "single_device_sharding":
|
||||
SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
|
||||
else:
|
||||
@ -159,19 +166,19 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
if name == "named_sharding":
|
||||
mesh = jtu.create_global_mesh((8,), ("x",))
|
||||
s1 = NamedSharding(mesh, P("x"))
|
||||
s2 = NamedSharding(mesh, P("x"), memory_kind="device")
|
||||
s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
|
||||
self.assertEqual(s1, s2)
|
||||
elif name == "positional_sharding":
|
||||
s1 = PositionalSharding(jax.devices())
|
||||
s2 = PositionalSharding(jax.devices(), memory_kind="device")
|
||||
s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
|
||||
self.assertEqual(s1, s2)
|
||||
elif name == "single_device_sharding":
|
||||
s1 = SingleDeviceSharding(jax.devices()[0])
|
||||
s2 = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
|
||||
s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind)
|
||||
self.assertEqual(s1, s2)
|
||||
elif name == "gspmd_sharding":
|
||||
s1 = GSPMDSharding.get_replicated(jax.devices())
|
||||
s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind="device")
|
||||
s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind=self._default_memory_kind)
|
||||
self.assertEqual(s1, s2)
|
||||
|
||||
def test_sharding_equivalent(self):
|
||||
@ -181,11 +188,11 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
gs1 = GSPMDSharding(
|
||||
tuple(mesh.devices.flat),
|
||||
ns1._to_xla_hlo_sharding(ndim),
|
||||
memory_kind="device",
|
||||
memory_kind=self._default_memory_kind,
|
||||
)
|
||||
self.assertTrue(ns1.is_equivalent_to(gs1, ndim))
|
||||
|
||||
ns2 = NamedSharding(mesh, P("x"), memory_kind="device")
|
||||
ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
|
||||
gs2 = GSPMDSharding(
|
||||
tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim)
|
||||
)
|
||||
@ -193,7 +200,7 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_default_memory_kind(self):
|
||||
dev = jax.devices()[0]
|
||||
self.assertEqual(dev.default_memory().kind, "device")
|
||||
self.assertEqual(dev.default_memory().kind, self._default_memory_kind)
|
||||
|
||||
|
||||
class MemoriesComputationTest(jtu.BufferDonationTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user