mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[XLA:GPU] Fix bug in all-to-all for complex data types.
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier. Fixes https://github.com/google/jax/issues/18122 PiperOrigin-RevId: 573955671
This commit is contained in:
parent
5919c1f33c
commit
89b5449882
@ -637,6 +637,9 @@ jax_test(
|
||||
"tpu": 30,
|
||||
},
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:internal_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
@ -47,6 +47,7 @@ from jax._src import sharding_impls
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lax import parallel
|
||||
@ -572,14 +573,20 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
ans = f(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_split={split_axis}_concat={concat_axis}",
|
||||
"split_axis": split_axis, "concat_axis": concat_axis}
|
||||
for split_axis, concat_axis in it.product(range(2), range(2)))
|
||||
def testAllToAll(self, split_axis, concat_axis):
|
||||
@jtu.sample_product(
|
||||
split_axis=list(range(2)),
|
||||
concat_axis=list(range(2)),
|
||||
dtype=lax_test_util.all_dtypes,
|
||||
)
|
||||
def testAllToAll(self, split_axis, concat_axis, dtype):
|
||||
if xla_extension_version < 207 and jnp.issubdtype(
|
||||
dtype, jnp.complexfloating
|
||||
):
|
||||
raise unittest.SkipTest('Test requires jaxlib 0.4.19')
|
||||
pmap_in_axis = 0
|
||||
shape = (jax.device_count(),) * 3
|
||||
x = np.arange(math.prod(shape)).reshape(shape)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
|
||||
@partial(self.pmap, axis_name='i')
|
||||
def f(x):
|
||||
@ -871,7 +878,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
g = jit(lambda z: z + y)
|
||||
self.assertAllClose(g(7), y + 7)
|
||||
|
||||
|
||||
# Tests edge cases in lax._reshape_sharded_device_array
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_in={in_shape}_out={out_shape}"
|
||||
@ -1883,7 +1889,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertIn(f"mhlo.num_replicas = {2}", hlo)
|
||||
self.assertIn("mhlo.num_partitions = 1", hlo)
|
||||
|
||||
|
||||
def testPsumZeroCotangents(self):
|
||||
# https://github.com/google/jax/issues/3651
|
||||
def loss(params, meta_params):
|
||||
|
Loading…
x
Reference in New Issue
Block a user