Small cleanup of xmap tests

This commit is contained in:
Adam Paszke 2021-02-08 12:18:33 +00:00
parent bd8d4a34da
commit 692e31c924

View File

@ -89,6 +89,14 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
def with_mesh_from_kwargs(f):
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
def with_and_without_mesh(f):
return parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('', (), ()),
('Mesh', (('x', 2),), (('i', 'x'),))
))(with_mesh_from_kwargs(f))
# -------------------- Itertools helpers --------------------
@ -207,10 +215,31 @@ def schedules(sizes: Dict[str, int]
yield axis_resources, mesh_data
class XMapTest(jtu.BufferDonationTestCase):
class XMapTestCase(jtu.BufferDonationTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
# A mixin that enables SPMD lowering tests
class SPMDTestMixin:
def setUp(self):
if jtu.device_under_test() != "tpu":
raise SkipTest
super().setUp()
jax.experimental.maps.make_xmap_callable.cache_clear()
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
def tearDown(self):
jax.experimental.maps.make_xmap_callable.cache_clear()
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
class XMapTest(XMapTestCase):
@ignore_xmap_warning()
def testBasic(self):
@ -354,13 +383,7 @@ class XMapTest(jtu.BufferDonationTestCase):
self.assertEqual(y[0].sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('', (), ()),
('Mesh', (('x', 2),), (('i', 'x'),))
))
@with_mesh_from_kwargs
@with_and_without_mesh
@ignore_xmap_warning()
def testMultipleCalls(self, mesh, axis_resources):
def f(x, y):
@ -376,13 +399,7 @@ class XMapTest(jtu.BufferDonationTestCase):
for i in range(10):
self.assertAllClose(f_mapped(x, x), expected)
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('', (), ()),
('Mesh', (('x', 2),), (('i', 'x'),))
))
@with_mesh_from_kwargs
@with_and_without_mesh
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
@ignore_xmap_warning()
def testBufferDonation(self, mesh, axis_resources):
@ -401,13 +418,7 @@ class XMapTest(jtu.BufferDonationTestCase):
self.assertNotDeleted(y)
self.assertDeleted(x)
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('', (), ()),
('Mesh', (('x', 2),), (('i', 'x'),))
))
@with_mesh_from_kwargs
@with_and_without_mesh
@ignore_xmap_warning()
def testAxisSizes(self, mesh, axis_resources):
result = xmap(lambda: lax.axis_index('i'),
@ -500,8 +511,8 @@ class XMapTest(jtu.BufferDonationTestCase):
self.assertAllClose(fm(x, y), fref(x, y))
class XMapTestSPMD(XMapTest):
"""Re-executes all tests with the SPMD partitioner enabled"""
class XMapTestSPMD(SPMDTestMixin, XMapTest):
"""Re-executes all basic tests with the SPMD partitioner enabled"""
skipped_tests = {
"NestedMesh", # Nesting xmap calls is not supported in the SPMD lowering yet
@ -510,25 +521,13 @@ class XMapTestSPMD(XMapTest):
}
def setUp(self):
super().setUp()
if jtu.device_under_test() != "tpu":
raise SkipTest
for skipped_name in self.skipped_tests:
if skipped_name in self._testMethodName:
raise SkipTest
jax.experimental.maps.make_xmap_callable.cache_clear()
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
def tearDown(self):
jax.experimental.maps.make_xmap_callable.cache_clear()
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
super().setUp()
class NamedNumPyTest(jtu.JaxTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
class NamedNumPyTest(XMapTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{reduction.__name__}_axes={axes}_i={mapped_axis}",
@ -554,12 +553,7 @@ class NamedNumPyTest(jtu.JaxTestCase):
self.assertAllClose(ref_red(x), xmap_red(x))
class NamedRandomTest(jtu.JaxTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
class NamedRandomTest(XMapTestCase):
@curry
def parameterize_by_sampler(extra, f, subset):
@ -606,12 +600,7 @@ class NamedRandomTest(jtu.JaxTestCase):
self.assertAllClose(sample({}), sample(dict(axis_resources)))
class NamedNNTest(jtu.JaxTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
class NamedNNTest(XMapTestCase):
@ignore_xmap_warning()
def testOneHot(self):
@ -768,12 +757,7 @@ def schedules_from_pdot_spec(
yield from schedules(logical_sizes)
class PDotTests(jtu.JaxTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
class PDotTests(XMapTestCase):
@ignore_xmap_warning()
@with_mesh([('r1', 2)])