[XLA:GPU] Fix bug in all-to-all for complex data types.

The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.

Fixes https://github.com/google/jax/issues/18122

PiperOrigin-RevId: 573955671
This commit is contained in:
Peter Hawkins 2023-10-16 16:01:34 -07:00 committed by jax authors
parent 5919c1f33c
commit 89b5449882
2 changed files with 16 additions and 8 deletions

View File

@ -637,6 +637,9 @@ jax_test(
"tpu": 30,
},
tags = ["multiaccelerator"],
deps = [
"//jax:internal_test_util",
],
)
jax_test(

View File

@ -47,6 +47,7 @@ from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.internal_test_util import lax_test_util
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lax import parallel
@ -572,14 +573,20 @@ class PythonPmapTest(jtu.JaxTestCase):
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": f"_split={split_axis}_concat={concat_axis}",
"split_axis": split_axis, "concat_axis": concat_axis}
for split_axis, concat_axis in it.product(range(2), range(2)))
def testAllToAll(self, split_axis, concat_axis):
@jtu.sample_product(
split_axis=list(range(2)),
concat_axis=list(range(2)),
dtype=lax_test_util.all_dtypes,
)
def testAllToAll(self, split_axis, concat_axis, dtype):
if xla_extension_version < 207 and jnp.issubdtype(
dtype, jnp.complexfloating
):
raise unittest.SkipTest('Test requires jaxlib 0.4.19')
pmap_in_axis = 0
shape = (jax.device_count(),) * 3
x = np.arange(math.prod(shape)).reshape(shape)
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
@partial(self.pmap, axis_name='i')
def f(x):
@ -871,7 +878,6 @@ class PythonPmapTest(jtu.JaxTestCase):
g = jit(lambda z: z + y)
self.assertAllClose(g(7), y + 7)
# Tests edge cases in lax._reshape_sharded_device_array
@parameterized.named_parameters(
{"testcase_name": f"_in={in_shape}_out={out_shape}"
@ -1883,7 +1889,6 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIn(f"mhlo.num_replicas = {2}", hlo)
self.assertIn("mhlo.num_partitions = 1", hlo)
def testPsumZeroCotangents(self):
# https://github.com/google/jax/issues/3651
def loss(params, meta_params):