# rocm_jax/tests/multiprocess_gpu_test.py

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import shutil
import subprocess
import sys
import unittest
import functools
from absl.testing import absltest
import numpy as np
import jax
from jax._src import core
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import pjit
import jax.numpy as jnp
# Used to test for mpi4py installation and skip tests if not installed
import importlib.util

try:
  import portpicker
except ImportError:
  portpicker = None

jax.config.parse_flags_with_absl()


@unittest.skipIf(not portpicker, "Test requires portpicker")
class MultiProcessGpuTest(jtu.JaxTestCase):

  def test_gpu_distributed_initialize(self):
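    """Runs one subprocess per GPU with explicit jax.distributed.initialize()."""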
    if not jtu.test_device_matches(['gpu']):
      raise unittest.SkipTest('Tests only for GPU.')

    port = portpicker.pick_unused_port()
    num_gpus = 4
    num_gpus_per_task = 1
    num_tasks = num_gpus // num_gpus_per_task

    with contextlib.ExitStack() as exit_stack:
      subprocesses = []
      for task in range(num_tasks):
        env = os.environ.copy()
        env["JAX_PORT"] = str(port)
        env["NUM_TASKS"] = str(num_tasks)
        env["TASK"] = str(task)
        if jtu.is_device_rocm():
          env["HIP_VISIBLE_DEVICES"] = ",".join(
              str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task))
        else:
          env["CUDA_VISIBLE_DEVICES"] = ",".join(
              str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task))
        args = [
            sys.executable,
            "-c",
            ('import jax, os; '
             'jax.distributed.initialize('
             'f\'localhost:{os.environ["JAX_PORT"]}\', '
             'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); '
             'print(f\'{jax.local_device_count()},{jax.device_count()}\', end="")'
            )
        ]
        proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, universal_newlines=True)
        subprocesses.append(exit_stack.enter_context(proc))

      try:
        for proc in subprocesses:
          out, _ = proc.communicate()
          self.assertEqual(proc.returncode, 0)
          self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
      finally:
        for proc in subprocesses:
          proc.kill()

  def test_distributed_jax_visible_devices(self):
    """Test jax_visible_devices works in distributed settings."""
    if not jtu.test_device_matches(['gpu']):
      raise unittest.SkipTest('Tests only for GPU.')

    port = portpicker.pick_unused_port()
    num_gpus = 4
    num_gpus_per_task = 1
    num_tasks = num_gpus // num_gpus_per_task

    with contextlib.ExitStack() as exit_stack:
      subprocesses = []
      for task in range(num_tasks):
        env = os.environ.copy()
        env["JAX_PORT"] = str(port)
        env["NUM_TASKS"] = str(num_tasks)
        env["TASK"] = str(task)
        visible_devices = ",".join(
            str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task))

        if jtu.is_device_rocm():
          program = (
              'import jax, os; '
              f'jax.config.update("jax_rocm_visible_devices", "{visible_devices}"); '
              'jax.distributed.initialize('
              'f\'localhost:{os.environ["JAX_PORT"]}\', '
              'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); '
              's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); '
              'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); '
          )
        else:
          program = (
              'import jax, os; '
              f'jax.config.update("jax_cuda_visible_devices", "{visible_devices}"); '
              'jax.distributed.initialize('
              'f\'localhost:{os.environ["JAX_PORT"]}\', '
              'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); '
              's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); '
              'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); '
          )
        args = [sys.executable, "-c", program]

        proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, universal_newlines=True)
        subprocesses.append(exit_stack.enter_context(proc))

      try:
        for proc in subprocesses:
          out, err = proc.communicate()
          self.assertEqual(proc.returncode, 0, msg=f"Process failed:\n\n{err}")
          self.assertRegex(out, f'{num_gpus_per_task},{num_gpus},\\[{num_gpus}.\\]$')
      finally:
        for proc in subprocesses:
          proc.kill()

  def test_gpu_ompi_distributed_initialize(self):
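    """Checks jax.distributed.initialize() auto-detection under mpirun."""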
    if not jtu.test_device_matches(['gpu']):
      raise unittest.SkipTest('Tests only for GPU.')
    if shutil.which('mpirun') is None:
      raise unittest.SkipTest('Tests only for MPI (mpirun not found).')

    num_gpus = 4
    num_gpus_per_task = 1

    with contextlib.ExitStack() as exit_stack:
      args = [
          'mpirun',
          '--oversubscribe',
          '--allow-run-as-root',
          '-n',
          str(num_gpus),
          sys.executable,
          '-c',
          ('import jax, os; '
           'jax.distributed.initialize(); '
           'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
          )
      ]
      env = os.environ.copy()
      # In case the job was launched via Slurm,
      # prevent OpenMPI from detecting Slurm environment
      env.pop('SLURM_JOBID', None)
      proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
      proc = exit_stack.enter_context(proc)

      try:
        out, _ = proc.communicate()
        self.assertEqual(proc.returncode, 0)
        self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
      finally:
        proc.kill()

  def test_gpu_mpi4py_distributed_initialize(self):
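    """Checks mpi4py-based detection in jax.distributed.initialize()."""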
    if not jtu.test_device_matches(['gpu']):
      raise unittest.SkipTest('Tests only for GPU.')
    if shutil.which('mpirun') is None:
      raise unittest.SkipTest('Tests only for MPI (mpirun not found).')
    if importlib.util.find_spec("mpi4py") is None:
      raise unittest.SkipTest('Test of mpi4py initialize only possible with mpi4py installed.')

    num_gpus = 4
    num_gpus_per_task = 1

    with contextlib.ExitStack() as exit_stack:
      args = [
          'mpirun',
          '--oversubscribe',
          '--allow-run-as-root',
          '-n',
          str(num_gpus),
          sys.executable,
          '-c',
          ('import jax, os; '
           'jax.distributed.initialize(spec_detection_method="mpi4py"); '
           'print(f\'{jax.local_device_count()},{jax.device_count()}\' if jax.process_index() == 0 else \'\', end="")'
          )
      ]
      env = os.environ.copy()
      # In case the job was launched via Slurm,
      # prevent OpenMPI from detecting Slurm environment
      env.pop('SLURM_JOBID', None)
      proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
      proc = exit_stack.enter_context(proc)

      try:
        out, _ = proc.communicate()
        self.assertEqual(proc.returncode, 0)
        self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')
      finally:
        proc.kill()


@unittest.skipIf(
    os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
    "Slurm environment with at least two nodes needed!")
@jtu.pytest_mark_if_available('SlurmMultiNodeGpuTest')
class SlurmMultiNodeGpuTest(jtu.JaxTestCase):

  def sorted_devices(self):
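    """Returns jax.devices() sorted by (id, host_id); skips unless 16 devices."""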
    devices = sorted(jax.devices(), key=lambda d: (d.id, d.host_id))
    if len(devices) != 16:
      raise unittest.SkipTest(
          "Test assumes that it runs on 16 devices (2 nodes)")
    return devices

  def create_2d_non_contiguous_mesh(self):
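    """Builds an 8x2 mesh with a non-contiguous device order across processes."""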
    devices = self.sorted_devices()
    device_mesh = np.array([[devices[0], devices[2]],
                            [devices[4], devices[6]],
                            [devices[1], devices[3]],
                            [devices[5], devices[7]],
                            [devices[8], devices[10]],
                            [devices[12], devices[14]],
                            [devices[9], devices[11]],
                            [devices[13], devices[15]]])
    # The mesh looks like this (the integers are process index):
    # 0  2
    # 4  6
    # 1  3
    # 5  7
    # 8  10
    # 12 14
    # 9  11
    # 13 15
    assert [d.id for d in device_mesh.flat
            ] == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
    return jax.sharding.Mesh(device_mesh, ("x", "y"))

  def test_gpu_multi_node_initialize_and_psum(self):
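    """Initializes jax.distributed from Slurm env vars and checks a global psum."""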
    # Hook up the ENV vars expected to be set already in the SLURM environment
    coordinator_address = os.environ.get("SLURM_STEP_NODELIST", None)
    if coordinator_address is not None and '[' in coordinator_address:
      coordinator_address = coordinator_address.split('[')[0] + \
                            coordinator_address.split('[')[1].split(',')[0]

    num_tasks = os.environ.get("SLURM_NPROCS", None)
    taskid = os.environ.get("SLURM_PROCID", None)
    localid = os.environ.get("SLURM_LOCALID", None)

    # Fix the port since it needs to be the same for all the processes
    port = "54321"

    print(f"coord addr:port : {coordinator_address}:{port}\nTotal tasks: "
          f"{num_tasks}\ntask id: {taskid}\nlocal id: {localid}")

    self.assertEqual(
        coordinator_address is None or num_tasks is None or taskid is None,
        False)

    # os.environ["CUDA_VISIBLE_DEVICES"] = localid #WAR for Bug:12119
    jax.config.update("jax_cuda_visible_devices", localid)

    jax.distributed.initialize(coordinator_address=f'{coordinator_address}:{port}',
                               num_processes=int(num_tasks),
                               process_id=int(taskid))

    print(f"Total devices: {jax.device_count()}, Total tasks: {int(num_tasks)}, "
          f"Devices per task: {jax.local_device_count()}")

    self.assertEqual(jax.device_count(),
                     int(num_tasks) * jax.local_device_count())

    x = jnp.ones(jax.local_device_count())
    y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
    self.assertEqual(y[0], jax.device_count())
    print(y)

  def test_gpu_multi_node_transparent_initialize_and_psum(self):
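    """Initializes jax.distributed via automatic detection and checks a psum."""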
    jax.distributed.initialize()

    print(f"Total devices: {jax.device_count()}, "
          f"Devices per task: {jax.local_device_count()}")

    self.assertEqual(jax.device_count(), int(os.environ['SLURM_NTASKS']))
    self.assertEqual(jax.local_device_count(), 1)

    x = jnp.ones(jax.local_device_count())
    y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
    self.assertEqual(y[0], jax.device_count())
    print(y)

  def test_pjit_gda_multi_input_multi_output(self):
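    """Runs pjit with three differently sharded inputs and checks each shard."""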
    jax.distributed.initialize()
    global_mesh = jtu.create_mesh((8, 2), ("x", "y"))

    global_input_shape = (16, 2)
    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    mesh_axes1 = jax.sharding.PartitionSpec("x", "y")
    gda1 = jax.make_array_from_callback(
        global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes1), cb)
    mesh_axes2 = jax.sharding.PartitionSpec("x")
    gda2 = jax.make_array_from_callback(
        global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes2), cb)
    mesh_axes3 = jax.sharding.PartitionSpec(("x", "y"))
    gda3 = jax.make_array_from_callback(
        global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes3), cb)

    with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):

      @functools.partial(
          pjit.pjit,
          out_shardings=(mesh_axes1, None, mesh_axes2))
      def f(x, y, z):
        return x @ x.T, y, z

      out1, out2, out3 = f(gda1, gda2, gda3)

      self.assertEqual(out1.shape, (16, 16))
      self.assertEqual(out1.addressable_shards[0].data.shape, (2, 8))

      expected_matrix_mul = global_input_data @ global_input_data.T
      for s in out1.addressable_shards:
        np.testing.assert_array_equal(np.asarray(s.data),
                                      expected_matrix_mul[s.index])

      self.assertEqual(out2.shape, (16, 2))
      self.assertEqual(out2.addressable_shards[0].data.shape, (16, 2))
      for s in out2.addressable_shards:
        np.testing.assert_array_equal(np.asarray(s.data), global_input_data)

      self.assertEqual(out3.shape, (16, 2))
      self.assertEqual(out3.addressable_shards[0].data.shape, (2, 2))
      for s in out3.addressable_shards:
        np.testing.assert_array_equal(np.asarray(s.data),
                                      global_input_data[s.index])

  def test_pjit_gda_non_contiguous_mesh(self):
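    """Checks shard index, replica_id, and data on a non-contiguous 1-D mesh."""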
    jax.distributed.initialize()
    devices = self.sorted_devices()
    mesh_devices = np.array(devices[0:8:2] + devices[1:8:2] + devices[8:16:2] +
                            devices[9:16:2])
    # The device order in the below mesh is:
    # [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
    # each having the following process index:
    # The process-gpu mapping is random: @sudhakarsingh27 to figure out why so
    # and the data is:
    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    global_mesh = jax.sharding.Mesh(mesh_devices, ("x",))

    global_input_shape = (16,)
    mesh_axes = jax.sharding.PartitionSpec("x")

    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda1 = jax.make_array_from_callback(
        global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), cb)

    # device_id -> (index, replica_id)
    expected_idx_rid = {
        0: ((slice(0, 1),), 0),
        1: ((slice(4, 5),), 0),
        2: ((slice(1, 2),), 0),
        3: ((slice(5, 6),), 0),
        4: ((slice(2, 3),), 0),
        5: ((slice(6, 7),), 0),
        6: ((slice(3, 4),), 0),
        7: ((slice(7, 8),), 0),
        8: ((slice(8, 9),), 0),
        9: ((slice(12, 13),), 0),
        10: ((slice(9, 10),), 0),
        11: ((slice(13, 14),), 0),
        12: ((slice(10, 11),), 0),
        13: ((slice(14, 15),), 0),
        14: ((slice(11, 12),), 0),
        15: ((slice(15, 16),), 0),
    }

    with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
      f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
      out = f(gda1)

      for s in out.addressable_shards:
        device_id = s.device.id
        expected_index = expected_idx_rid[device_id][0]
        expected_replica_id = expected_idx_rid[device_id][1]
        self.assertEqual(s.index, expected_index)
        self.assertEqual(s.replica_id, expected_replica_id)
        self.assertEqual(s.data.shape, (1,))
        np.testing.assert_array_equal(np.asarray(s.data),
                                      global_input_data[expected_index])

  def test_pjit_gda_non_contiguous_mesh_2d(self):
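    """Checks sharding of a 2-D array over the non-contiguous 8x2 mesh."""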
    jax.distributed.initialize()
    global_mesh = self.create_2d_non_contiguous_mesh()
    global_input_shape = (16, 2)
    mesh_axes = jax.sharding.PartitionSpec("x", "y")

    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda1 = jax.make_array_from_callback(
        global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), cb)

    # device_id -> (index, replica_id)
    expected_idx_rid = {
        0: ((slice(0, 2), slice(0, 1)), 0),
        1: ((slice(4, 6), slice(0, 1)), 0),
        2: ((slice(0, 2), slice(1, 2)), 0),
        3: ((slice(4, 6), slice(1, 2)), 0),
        4: ((slice(2, 4), slice(0, 1)), 0),
        5: ((slice(6, 8), slice(0, 1)), 0),
        6: ((slice(2, 4), slice(1, 2)), 0),
        7: ((slice(6, 8), slice(1, 2)), 0),
        8: ((slice(8, 10), slice(0, 1)), 0),
        9: ((slice(12, 14), slice(0, 1)), 0),
        10: ((slice(8, 10), slice(1, 2)), 0),
        11: ((slice(12, 14), slice(1, 2)), 0),
        12: ((slice(10, 12), slice(0, 1)), 0),
        13: ((slice(14, 16), slice(0, 1)), 0),
        14: ((slice(10, 12), slice(1, 2)), 0),
        15: ((slice(14, 16), slice(1, 2)), 0),
    }

    with global_mesh:
      f = pjit.pjit(lambda x: x, out_shardings=mesh_axes)
      out = f(gda1)

      for s in out.addressable_shards:
        device_id = s.device.id
        expected_index = expected_idx_rid[device_id][0]
        expected_replica_id = expected_idx_rid[device_id][1]
        self.assertEqual(s.index, expected_index)
        self.assertEqual(s.replica_id, expected_replica_id)
        self.assertEqual(s.data.shape, (2, 1))
        np.testing.assert_array_equal(np.asarray(s.data),
                                      global_input_data[expected_index])

    with global_mesh:
      f = pjit.pjit(
          lambda x: x,
          in_shardings=jax.sharding.PartitionSpec(None),
          out_shardings=mesh_axes,
      )
      # Fully replicated values allows a non-contiguous mesh.
      out = f(global_input_data)

    with global_mesh:
      f = pjit.pjit(lambda x: x, in_shardings=None, out_shardings=mesh_axes)
      # Fully replicated values allows a non-contiguous mesh.
      out = f(global_input_data)

    gda2 = jax.make_array_from_callback(
        global_input_shape, jax.sharding.NamedSharding(global_mesh, jax.sharding.PartitionSpec(None)), cb)

    with global_mesh:
      f = pjit.pjit(
          lambda x, y: (x, y),
          in_shardings=(None, None),
          out_shardings=(mesh_axes, mesh_axes),
      )
      # Fully replicated values + GDA allows a non-contiguous mesh.
      out1, out2 = f(global_input_data, gda2)

  def test_pjit_gda_non_contiguous_mesh_2d_aot(self):
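    """Lowers and compiles pjit ahead of time on the non-contiguous mesh."""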
    jax.distributed.initialize()
    global_mesh = self.create_2d_non_contiguous_mesh()
    global_input_shape = (8, 2)
    mesh_axes = jax.sharding.PartitionSpec("x", "y")

    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)
    gda1 = jax.make_array_from_callback(
        global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes),
        lambda idx: global_input_data[idx])

    with global_mesh:
      f = pjit.pjit(
          lambda x, y: (x, y),
          in_shardings=jax.sharding.PartitionSpec("x", "y"),
          out_shardings=jax.sharding.PartitionSpec("x", "y"),
      )
      inp_aval = core.ShapedArray((8, 2), jnp.int32)
      # `ShapedArray` is considered global when lowered and compiled.
      # Hence it can bypass the contiguous mesh restriction.
      compiled = f.lower(inp_aval, gda1).compile()
      out1, out2 = compiled(gda1, gda1)

      self.assertEqual(out1.shape, (8, 2))
      self.assertEqual(out2.shape, (8, 2))

  def test_pjit_gda_eval_shape(self):
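    """Checks pjit output shapes and jax.eval_shape on a 16-device mesh."""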
    jax.distributed.initialize()

    with jtu.create_mesh((16,), ("x")):

      @functools.partial(pjit.pjit,
                         in_shardings=jax.sharding.PartitionSpec(None),
                         out_shardings=jax.sharding.PartitionSpec("x"))
      def f():
        return jnp.zeros([32, 10])

      self.assertEqual(f().shape, (32, 10))
      self.assertEqual(jax.eval_shape(f).shape, (32, 10))


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())