[JAX] Handle non-contiguous GPU IDs in NCCL collectives.

Fixes https://github.com/google/jax/issues/12119

PiperOrigin-RevId: 470335156
This commit is contained in:
Peter Hawkins 2022-08-26 14:32:32 -07:00 committed by jax authors
parent 62057027bc
commit b9d7e05eda

View File

@ -109,7 +109,7 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
self.assertEqual(proc.returncode, 0)
self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
@unittest.skipIf(xla_extension_version < 88,
@unittest.skipIf(xla_extension_version < 91,
"Test requires jaxlib 0.3.17 or newer")
def test_distributed_jax_cuda_visible_devices(self):
"""Test jax_cuda_visible_devices works in distributed settings."""
@ -136,7 +136,8 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
'jax.distributed.initialize('
'f\'localhost:{os.environ["JAX_PORT"]}\', '
'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); '
'print(f\'{jax.local_device_count()},{jax.device_count()}\', end="")'
's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); '
'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); '
)
args = [sys.executable, "-c", program]
subprocesses.append(subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
@ -145,7 +146,7 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
for proc in subprocesses:
out, _ = proc.communicate()
self.assertEqual(proc.returncode, 0)
self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
self.assertEqual(out, f'{num_gpus_per_task},{num_gpus},[{num_gpus}.]')
@unittest.skipIf(
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",