commit e76aa77895
Merge pull request #12437 from sudhakarsingh27:add_multi_host_pjit_tests

PiperOrigin-RevId: 476451469
@@ -17,16 +17,23 @@ import subprocess
import sys
import threading
import unittest
import functools

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

import jax
from jax import experimental
from jax.config import config
from jax._src import distributed
import jax.numpy as jnp
from jax._src.lib import xla_extension_version
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import global_device_array
from jax.experimental import maps
from jax.experimental import pjit

try:
  import portpicker
@@ -40,7 +47,6 @@ except ImportError:

config.parse_flags_with_absl()


@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):

@@ -170,6 +176,49 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
  if pytest is not None:
    pytestmark = pytest.mark.SlurmMultiNodeGpuTest

  def sorted_devices(self):
    devices = sorted(jax.devices(), key=lambda d: (d.id, d.host_id))
    if len(devices) != 16:
      raise unittest.SkipTest(
          "Test assumes that it runs on 16 devices (2 nodes)")
    return devices

  def create_2d_non_contiguous_mesh(self):
    devices = self.sorted_devices()
    device_mesh = np.array([[devices[0], devices[2]],
                            [devices[4], devices[6]],
                            [devices[1], devices[3]],
                            [devices[5], devices[7]],
                            [devices[8], devices[10]],
                            [devices[12], devices[14]],
                            [devices[9], devices[11]],
                            [devices[13], devices[15]]])
    # The mesh looks like this (the integers are device ids):
    # 0  2
    # 4  6
    # 1  3
    # 5  7
    # 8  10
    # 12 14
    # 9  11
    # 13 15
    assert [d.id for d in device_mesh.flat
           ] == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
    return maps.Mesh(device_mesh, ("x", "y"))

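  # A note on "non-contiguous" (not part of the original test): consecutive
  # device ids do not occupy adjacent mesh positions here (e.g. ids 0 and 1
  # sit two mesh rows apart), which is exactly what the assert above pins
  # down. As the comments in the tests below note, pjit restricts such
  # meshes for host-local inputs, and these tests exercise that boundary.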
  def setUp(self):
    super().setUp()
    self.xmap_spmd_lowering_enabled = jax.config.experimental_xmap_spmd_lowering
    jax.config.update("experimental_xmap_spmd_lowering", True)
    self.gda_enabled = jax.config.jax_parallel_functions_output_gda
    jax.config.update('jax_parallel_functions_output_gda', True)

  def tearDown(self):
    jax.config.update("experimental_xmap_spmd_lowering",
                      self.xmap_spmd_lowering_enabled)
    jax.config.update('jax_parallel_functions_output_gda', self.gda_enabled)
    super().tearDown()

  def test_gpu_multi_node_initialize_and_psum(self):

    # Hook up the ENV vars expected to be set already in the SLURM environment
@@ -224,5 +273,244 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
    self.assertEqual(y[0], jax.device_count())
    print(y)

  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
  # since `GlobalDeviceArray` is going to be deprecated in the future
  def test_pjit_gda_multi_input_multi_output(self):
    jax.distributed.initialize()
    global_mesh = jtu.create_global_mesh((8, 2), ("x", "y"))
    global_input_shape = (16, 2)
    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    mesh_axes1 = experimental.PartitionSpec("x", "y")
    gda1 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes1, cb)
    mesh_axes2 = experimental.PartitionSpec("x")
    gda2 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes2, cb)
    mesh_axes3 = experimental.PartitionSpec(("x", "y"))
    gda3 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes3, cb)

    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):

      @functools.partial(
          pjit.pjit,
          # A single `FROM_GDA` is applied to all of the inputs.
          in_axis_resources=pjit.FROM_GDA,
          out_axis_resources=(mesh_axes1, None, mesh_axes2))
      def f(x, y, z):
        return x @ x.T, y, z

      out1, out2, out3 = f(gda1, gda2, gda3)

      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
      self.assertEqual(out1.shape, (16, 16))
      self.assertEqual(out1.local_shards[0].data.shape, (2, 8))
      self.assertDictEqual(out1.mesh.shape, {"x": 8, "y": 2})
      expected_matrix_mul = global_input_data @ global_input_data.T
      for s in out1.local_shards:
        np.testing.assert_array_equal(np.asarray(s.data),
                                      expected_matrix_mul[s.index])

      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
      self.assertEqual(out2.shape, (16, 2))
      self.assertEqual(out2.local_shards[0].data.shape, (16, 2))
      for s in out2.local_shards:
        np.testing.assert_array_equal(np.asarray(s.data), global_input_data)

      self.assertIsInstance(out3, global_device_array.GlobalDeviceArray)
      self.assertEqual(out3.shape, (16, 2))
      self.assertEqual(out3.local_shards[0].data.shape, (2, 2))
      for s in out3.local_shards:
        np.testing.assert_array_equal(np.asarray(s.data),
                                      global_input_data[s.index])

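  # The shard shapes asserted above follow from the mesh sizes (a quick
  # sketch, not part of the original test): out1 is (16, 16) under
  # PartitionSpec("x", "y") on a {"x": 8, "y": 2} mesh, so each device holds
  # (16 // 8, 16 // 2) = (2, 8); out2 is fully replicated (spec None), so
  # each device holds the whole (16, 2); out3 is partitioned only along "x",
  # so each device holds (16 // 8, 2) = (2, 2).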
  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
  # since `GlobalDeviceArray` is going to be deprecated in the future
  def test_pjit_gda_non_contiguous_mesh(self):
    jax.distributed.initialize()
    devices = self.sorted_devices()
    mesh_devices = np.array(devices[0:8:2] + devices[1:8:2] + devices[8:16:2] +
                            devices[9:16:2])
    # The device order in the mesh below is:
    # [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
    # and the data is:
    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    # (The process-to-GPU mapping is random; @sudhakarsingh27 to figure out
    # why that is so.)
    global_mesh = maps.Mesh(mesh_devices, ("x",))
    global_input_shape = (16,)
    mesh_axes = experimental.PartitionSpec("x")
    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda1 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    # device_id -> (index, replica_id)
    expected_idx_rid = {
        0: ((slice(0, 1),), 0),
        1: ((slice(4, 5),), 0),
        2: ((slice(1, 2),), 0),
        3: ((slice(5, 6),), 0),
        4: ((slice(2, 3),), 0),
        5: ((slice(6, 7),), 0),
        6: ((slice(3, 4),), 0),
        7: ((slice(7, 8),), 0),
        8: ((slice(8, 9),), 0),
        9: ((slice(12, 13),), 0),
        10: ((slice(9, 10),), 0),
        11: ((slice(13, 14),), 0),
        12: ((slice(10, 11),), 0),
        13: ((slice(14, 15),), 0),
        14: ((slice(11, 12),), 0),
        15: ((slice(15, 16),), 0),
    }

    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
      f = pjit.pjit(lambda x: x,
                    in_axis_resources=pjit.FROM_GDA,
                    out_axis_resources=mesh_axes)
      out = f(gda1)
      for s in out.local_shards:
        device_id = s.device.id
        expected_index = expected_idx_rid[device_id][0]
        expected_replica_id = expected_idx_rid[device_id][1]
        self.assertEqual(s.index, expected_index)
        self.assertEqual(s.replica_id, expected_replica_id)
        self.assertEqual(s.data.shape, (1,))
        np.testing.assert_array_equal(np.asarray(s.data),
                                      global_input_data[expected_index])

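  # How to read expected_idx_rid above (a note, not part of the original
  # test): with PartitionSpec("x") over the 16-long "x" axis, mesh position
  # p owns slice(p, p + 1) of the input, so the table simply inverts the
  # device ordering of mesh_devices. For example, device 1 sits at mesh
  # position 4 and therefore holds slice(4, 5); nothing is replicated, hence
  # every replica_id is 0.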
  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
  # since `GlobalDeviceArray` is going to be deprecated in the future
  def test_pjit_gda_non_contiguous_mesh_2d(self):
    jax.distributed.initialize()
    global_mesh = self.create_2d_non_contiguous_mesh()
    global_input_shape = (16, 2)
    mesh_axes = experimental.PartitionSpec("x", "y")
    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda1 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    # device_id -> (index, replica_id)
    expected_idx_rid = {
        0: ((slice(0, 2), slice(0, 1)), 0),
        1: ((slice(4, 6), slice(0, 1)), 0),
        2: ((slice(0, 2), slice(1, 2)), 0),
        3: ((slice(4, 6), slice(1, 2)), 0),
        4: ((slice(2, 4), slice(0, 1)), 0),
        5: ((slice(6, 8), slice(0, 1)), 0),
        6: ((slice(2, 4), slice(1, 2)), 0),
        7: ((slice(6, 8), slice(1, 2)), 0),
        8: ((slice(8, 10), slice(0, 1)), 0),
        9: ((slice(12, 14), slice(0, 1)), 0),
        10: ((slice(8, 10), slice(1, 2)), 0),
        11: ((slice(12, 14), slice(1, 2)), 0),
        12: ((slice(10, 12), slice(0, 1)), 0),
        13: ((slice(14, 16), slice(0, 1)), 0),
        14: ((slice(10, 12), slice(1, 2)), 0),
        15: ((slice(14, 16), slice(1, 2)), 0),
    }

    with global_mesh:
      f = pjit.pjit(lambda x: x,
                    in_axis_resources=pjit.FROM_GDA,
                    out_axis_resources=mesh_axes)
      out = f(gda1)

      for s in out.local_shards:
        device_id = s.device.id
        expected_index = expected_idx_rid[device_id][0]
        expected_replica_id = expected_idx_rid[device_id][1]
        self.assertEqual(s.index, expected_index)
        self.assertEqual(s.replica_id, expected_replica_id)
        self.assertEqual(s.data.shape, (2, 1))
        np.testing.assert_array_equal(np.asarray(s.data),
                                      global_input_data[expected_index])

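  # Shard shape check (a quick sketch, not part of the original test): the
  # mesh is {"x": 8, "y": 2}, so PartitionSpec("x", "y") splits the (16, 2)
  # input into (16 // 8, 2 // 2) = (2, 1) blocks, matching the s.data.shape
  # assertion above; a device's mesh row picks its "x" slice and its mesh
  # column picks its "y" slice.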
    with global_mesh:
      f = pjit.pjit(lambda x: x,
                    in_axis_resources=experimental.PartitionSpec(None),
                    out_axis_resources=mesh_axes)
      # Fully replicated values allow a non-contiguous mesh.
      out = f(global_input_data)
      self.assertIsInstance(out, global_device_array.GlobalDeviceArray)

    with global_mesh:
      f = pjit.pjit(lambda x: x,
                    in_axis_resources=None,
                    out_axis_resources=mesh_axes)
      # Fully replicated values allow a non-contiguous mesh.
      out = f(global_input_data)
      self.assertIsInstance(out, global_device_array.GlobalDeviceArray)

    gda2 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, experimental.PartitionSpec(None), cb)

    with global_mesh:
      f = pjit.pjit(lambda x, y: (x, y),
                    in_axis_resources=(None, None),
                    out_axis_resources=(mesh_axes, mesh_axes))
      # Fully replicated values + GDA allow a non-contiguous mesh.
      out1, out2 = f(global_input_data, gda2)
      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)

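  # The three blocks above probe the same rule from different angles: a
  # plain host-local array is accepted on a non-contiguous mesh only when
  # fully replicated, whether that is spelled PartitionSpec(None) or None,
  # and a replicated GDA passed alongside a replicated array is accepted as
  # well.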
  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
  # since `GlobalDeviceArray` is going to be deprecated in the future
  def test_pjit_gda_non_contiguous_mesh_2d_aot(self):
    jax.distributed.initialize()
    global_mesh = self.create_2d_non_contiguous_mesh()
    global_input_shape = (8, 2)
    mesh_axes = experimental.PartitionSpec("x", "y")
    global_input_data = np.arange(
        util.prod(global_input_shape)).reshape(global_input_shape)
    gda1 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes,
        lambda idx: global_input_data[idx])

    with global_mesh:
      f = pjit.pjit(lambda x, y: (x, y),
                    in_axis_resources=experimental.PartitionSpec("x", "y"),
                    out_axis_resources=experimental.PartitionSpec("x", "y"))
      inp_aval = jax.ShapedArray((8, 2), jnp.int32)
      # `ShapedArray` is considered global when lowered and compiled.
      # Hence it can bypass the contiguous mesh restriction.
      compiled = f.lower(inp_aval, gda1, _global_avals=True).compile()
      out1, out2 = compiled(gda1, gda1)
      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
      self.assertEqual(out1.shape, (8, 2))
      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
      self.assertEqual(out2.shape, (8, 2))

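  # The ahead-of-time path in a nutshell (an illustrative note): lower()
  # traces f against the abstract (8, 2) aval without touching device data,
  # compile() produces an executable pinned to this mesh, and only the final
  # call runs on the GDA arguments. `_global_avals=True` is a private flag
  # at this JAX version that appears to mark the avals as global shapes
  # rather than per-host ones.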
  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
  # since `GlobalDeviceArray` is going to be deprecated in the future
  def test_pjit_gda_eval_shape(self):
    jax.distributed.initialize()

    with jtu.create_global_mesh((16,), ("x",)):

      @functools.partial(pjit.pjit,
                         in_axis_resources=experimental.PartitionSpec(None),
                         out_axis_resources=experimental.PartitionSpec("x"))
      def f():
        return jnp.zeros([32, 10])

      self.assertEqual(f().shape, (32, 10))
      self.assertEqual(jax.eval_shape(f).shape, (32, 10))

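  # jax.eval_shape(f) computes the output's shape and dtype abstractly,
  # without allocating or running anything on devices, so the last assertion
  # checks that the traced logical shape matches the shape of the sharded
  # output actually produced by f().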

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())