mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Small cleanup of xmap tests
This commit is contained in:
parent
bd8d4a34da
commit
692e31c924
@ -89,6 +89,14 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
|
||||
def with_mesh_from_kwargs(f):
|
||||
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
|
||||
|
||||
def with_and_without_mesh(f):
|
||||
return parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
for name, mesh, axis_resources in (
|
||||
('', (), ()),
|
||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||
))(with_mesh_from_kwargs(f))
|
||||
|
||||
|
||||
# -------------------- Itertools helpers --------------------
|
||||
|
||||
@ -207,10 +215,31 @@ def schedules(sizes: Dict[str, int]
|
||||
yield axis_resources, mesh_data
|
||||
|
||||
|
||||
class XMapTest(jtu.BufferDonationTestCase):
|
||||
class XMapTestCase(jtu.BufferDonationTestCase):
|
||||
def setUp(self):
|
||||
if jax.lib.version < (0, 1, 58):
|
||||
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
super().setUp()
|
||||
|
||||
|
||||
# A mixin that enables SPMD lowering tests
|
||||
class SPMDTestMixin:
|
||||
def setUp(self):
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise SkipTest
|
||||
super().setUp()
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
|
||||
|
||||
def tearDown(self):
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
|
||||
|
||||
|
||||
class XMapTest(XMapTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testBasic(self):
|
||||
@ -354,13 +383,7 @@ class XMapTest(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(y[0].sharding_spec.mesh_mapping,
|
||||
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
for name, mesh, axis_resources in (
|
||||
('', (), ()),
|
||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||
))
|
||||
@with_mesh_from_kwargs
|
||||
@with_and_without_mesh
|
||||
@ignore_xmap_warning()
|
||||
def testMultipleCalls(self, mesh, axis_resources):
|
||||
def f(x, y):
|
||||
@ -376,13 +399,7 @@ class XMapTest(jtu.BufferDonationTestCase):
|
||||
for i in range(10):
|
||||
self.assertAllClose(f_mapped(x, x), expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
for name, mesh, axis_resources in (
|
||||
('', (), ()),
|
||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||
))
|
||||
@with_mesh_from_kwargs
|
||||
@with_and_without_mesh
|
||||
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
|
||||
@ignore_xmap_warning()
|
||||
def testBufferDonation(self, mesh, axis_resources):
|
||||
@ -401,13 +418,7 @@ class XMapTest(jtu.BufferDonationTestCase):
|
||||
self.assertNotDeleted(y)
|
||||
self.assertDeleted(x)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
for name, mesh, axis_resources in (
|
||||
('', (), ()),
|
||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||
))
|
||||
@with_mesh_from_kwargs
|
||||
@with_and_without_mesh
|
||||
@ignore_xmap_warning()
|
||||
def testAxisSizes(self, mesh, axis_resources):
|
||||
result = xmap(lambda: lax.axis_index('i'),
|
||||
@ -500,8 +511,8 @@ class XMapTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(fm(x, y), fref(x, y))
|
||||
|
||||
|
||||
class XMapTestSPMD(XMapTest):
|
||||
"""Re-executes all tests with the SPMD partitioner enabled"""
|
||||
class XMapTestSPMD(SPMDTestMixin, XMapTest):
|
||||
"""Re-executes all basic tests with the SPMD partitioner enabled"""
|
||||
|
||||
skipped_tests = {
|
||||
"NestedMesh", # Nesting xmap calls is not supported in the SPMD lowering yet
|
||||
@ -510,25 +521,13 @@ class XMapTestSPMD(XMapTest):
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise SkipTest
|
||||
for skipped_name in self.skipped_tests:
|
||||
if skipped_name in self._testMethodName:
|
||||
raise SkipTest
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
|
||||
|
||||
def tearDown(self):
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
|
||||
super().setUp()
|
||||
|
||||
|
||||
class NamedNumPyTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
class NamedNumPyTest(XMapTestCase):
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_{reduction.__name__}_axes={axes}_i={mapped_axis}",
|
||||
@ -554,12 +553,7 @@ class NamedNumPyTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ref_red(x), xmap_red(x))
|
||||
|
||||
|
||||
class NamedRandomTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if jax.lib.version < (0, 1, 58):
|
||||
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
class NamedRandomTest(XMapTestCase):
|
||||
|
||||
@curry
|
||||
def parameterize_by_sampler(extra, f, subset):
|
||||
@ -606,12 +600,7 @@ class NamedRandomTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(sample({}), sample(dict(axis_resources)))
|
||||
|
||||
|
||||
class NamedNNTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if jax.lib.version < (0, 1, 58):
|
||||
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
class NamedNNTest(XMapTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testOneHot(self):
|
||||
@ -768,12 +757,7 @@ def schedules_from_pdot_spec(
|
||||
yield from schedules(logical_sizes)
|
||||
|
||||
|
||||
class PDotTests(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
super().setUp()
|
||||
class PDotTests(XMapTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('r1', 2)])
|
||||
|
Loading…
x
Reference in New Issue
Block a user