mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[JAX] Handle non-contiguous GPU IDs in NCCL collectives.
Fixes https://github.com/google/jax/issues/12119
PiperOrigin-RevId: 470335156
This commit is contained in:
parent
62057027bc
commit
b9d7e05eda
@@ -109,7 +109,7 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
|
||||
self.assertEqual(proc.returncode, 0)
|
||||
self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
|
||||
|
||||
-@unittest.skipIf(xla_extension_version < 88,
|
||||
+@unittest.skipIf(xla_extension_version < 91,
|
||||
"Test requires jaxlib 0.3.17 or newer")
|
||||
def test_distributed_jax_cuda_visible_devices(self):
|
||||
"""Test jax_cuda_visible_devices works in distributed settings."""
|
||||
@@ -136,7 +136,8 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
|
||||
'jax.distributed.initialize('
|
||||
'f\'localhost:{os.environ["JAX_PORT"]}\', '
|
||||
'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); '
|
||||
-'print(f\'{jax.local_device_count()},{jax.device_count()}\', end="")'
|
||||
+'s = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); '
|
||||
+'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); '
|
||||
)
|
||||
args = [sys.executable, "-c", program]
|
||||
subprocesses.append(subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
|
||||
@@ -145,7 +146,7 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
|
||||
for proc in subprocesses:
|
||||
out, _ = proc.communicate()
|
||||
self.assertEqual(proc.returncode, 0)
|
||||
-self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
|
||||
+self.assertEqual(out, f'{num_gpus_per_task},{num_gpus},[{num_gpus}.]')
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
|
||||
|
Loading…
x
Reference in New Issue
Block a user