From 89b54498825ed329b682acf2deb02b9cd42515e6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 16 Oct 2023 16:01:34 -0700 Subject: [PATCH] [XLA:GPU] Fix bug in all-to-all for complex data types. The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier. Fixes https://github.com/google/jax/issues/18122 PiperOrigin-RevId: 573955671 --- tests/BUILD | 3 +++ tests/pmap_test.py | 21 +++++++++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 51e0f769a..9b50e0c78 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -637,6 +637,9 @@ jax_test( "tpu": 30, }, tags = ["multiaccelerator"], + deps = [ + "//jax:internal_test_util", + ], ) jax_test( diff --git a/tests/pmap_test.py b/tests/pmap_test.py index afbb0254e..2ae99c850 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -47,6 +47,7 @@ from jax._src import sharding_impls from jax._src import sharding_specs from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.internal_test_util import lax_test_util from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lax import parallel @@ -572,14 +573,20 @@ class PythonPmapTest(jtu.JaxTestCase): ans = f(x) self.assertAllClose(ans, expected, check_dtypes=False) - @parameterized.named_parameters( - {"testcase_name": f"_split={split_axis}_concat={concat_axis}", - "split_axis": split_axis, "concat_axis": concat_axis} - for split_axis, concat_axis in it.product(range(2), range(2))) - def testAllToAll(self, split_axis, concat_axis): + @jtu.sample_product( + split_axis=list(range(2)), + concat_axis=list(range(2)), + dtype=lax_test_util.all_dtypes, + ) + def testAllToAll(self, split_axis, concat_axis, dtype): + if xla_extension_version < 207 and jnp.issubdtype( + dtype, jnp.complexfloating + ): + raise unittest.SkipTest('Test requires jaxlib 0.4.19') pmap_in_axis = 0 shape = (jax.device_count(),) * 3 - x = np.arange(math.prod(shape)).reshape(shape) + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) @partial(self.pmap, axis_name='i') def f(x): @@ -871,7 +878,6 @@ class PythonPmapTest(jtu.JaxTestCase): g = jit(lambda z: z + y) self.assertAllClose(g(7), y + 7) - # Tests edge cases in lax._reshape_sharded_device_array @parameterized.named_parameters( {"testcase_name": f"_in={in_shape}_out={out_shape}" @@ -1883,7 +1889,6 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertIn(f"mhlo.num_replicas = {2}", hlo) self.assertIn("mhlo.num_partitions = 1", hlo) - def testPsumZeroCotangents(self): # https://github.com/google/jax/issues/3651 def loss(params, meta_params):