Merge pull request #12437 from sudhakarsingh27:add_multi_host_pjit_tests

PiperOrigin-RevId: 476451469
jax authors 2022-09-23 13:38:59 -07:00
commit e76aa77895


@@ -17,16 +17,23 @@ import subprocess
import sys
import threading
import unittest
import functools
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import experimental
from jax.config import config
from jax._src import distributed
import jax.numpy as jnp
from jax._src.lib import xla_extension_version
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import global_device_array
from jax.experimental import maps
from jax.experimental import pjit
try:
import portpicker
@@ -40,7 +47,6 @@ except ImportError:
config.parse_flags_with_absl()
@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):
@@ -170,6 +176,49 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
if pytest is not None:
pytestmark = pytest.mark.SlurmMultiNodeGpuTest
def sorted_devices(self):
devices = sorted(jax.devices(), key=lambda d: (d.id, d.host_id))
if len(devices) != 16:
raise unittest.SkipTest(
"Test assumes that it runs on 16 devices (2 nodes)")
return devices
def create_2d_non_contiguous_mesh(self):
devices = self.sorted_devices()
device_mesh = np.array([[devices[0], devices[2]],
[devices[4], devices[6]],
[devices[1], devices[3]],
[devices[5], devices[7]],
[devices[8], devices[10]],
[devices[12], devices[14]],
[devices[9], devices[11]],
[devices[13], devices[15]]])
    # The mesh looks like this (the integers are process indices):
# 0 2
# 4 6
# 1 3
# 5 7
# 8 10
# 12 14
# 9 11
# 13 15
    assert [d.id for d in device_mesh.flat] == [
        0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
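    # The flattened device ids are deliberately out of order; that is what
    # makes this mesh "non-contiguous" from pjit's point of view.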
return maps.Mesh(device_mesh, ("x", "y"))
def setUp(self):
super().setUp()
self.xmap_spmd_lowering_enabled = jax.config.experimental_xmap_spmd_lowering
jax.config.update("experimental_xmap_spmd_lowering", True)
self.gda_enabled = jax.config.jax_parallel_functions_output_gda
jax.config.update('jax_parallel_functions_output_gda', True)
def tearDown(self):
jax.config.update("experimental_xmap_spmd_lowering",
self.xmap_spmd_lowering_enabled)
jax.config.update('jax_parallel_functions_output_gda', self.gda_enabled)
super().tearDown()
def test_gpu_multi_node_initialize_and_psum(self):
    # Hook up the env vars that the SLURM environment is expected to have set already.
@ -224,5 +273,244 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
self.assertEqual(y[0], jax.device_count())
print(y)
  # TODO(sudhakarsingh27): Change or remove this test in favor of `Array`,
  # since `GlobalDeviceArray` is going to be deprecated.
def test_pjit_gda_multi_input_multi_output(self):
jax.distributed.initialize()
global_mesh = jtu.create_global_mesh((8, 2), ("x", "y"))
global_input_shape = (16, 2)
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
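    # `from_callback` invokes `cb` once per local device with that device's
    # index (a tuple of slices into the global shape) and places the returned
    # slice on that device.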
mesh_axes1 = experimental.PartitionSpec("x", "y")
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes1, cb)
mesh_axes2 = experimental.PartitionSpec("x")
gda2 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes2, cb)
mesh_axes3 = experimental.PartitionSpec(("x", "y"))
gda3 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes3, cb)
with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
@functools.partial(
pjit.pjit,
          # The single `FROM_GDA` is broadcast to all three inputs, so each
          # input's sharding is taken from the GDA itself.
in_axis_resources=pjit.FROM_GDA,
out_axis_resources=(mesh_axes1, None, mesh_axes2))
def f(x, y, z):
return x @ x.T, y, z
out1, out2, out3 = f(gda1, gda2, gda3)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (16, 16))
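      # (16, 16) sharded over the ("x", "y") = (8, 2) mesh gives
      # (16 / 8, 16 / 2) = (2, 8) per shard.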
self.assertEqual(out1.local_shards[0].data.shape, (2, 8))
self.assertDictEqual(out1.mesh.shape, {"x": 8, "y": 2})
expected_matrix_mul = global_input_data @ global_input_data.T
for s in out1.local_shards:
np.testing.assert_array_equal(np.asarray(s.data),
expected_matrix_mul[s.index])
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (16, 2))
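      # out2 used out_axis_resources=None, i.e. fully replicated: every local
      # shard holds the entire (16, 2) array.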
self.assertEqual(out2.local_shards[0].data.shape, (16, 2))
for s in out2.local_shards:
np.testing.assert_array_equal(np.asarray(s.data), global_input_data)
self.assertIsInstance(out3, global_device_array.GlobalDeviceArray)
self.assertEqual(out3.shape, (16, 2))
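      # out3 is sharded 8 ways along "x" only, so each shard is
      # (16 / 8, 2) = (2, 2).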
self.assertEqual(out3.local_shards[0].data.shape, (2, 2))
for s in out3.local_shards:
np.testing.assert_array_equal(np.asarray(s.data),
global_input_data[s.index])
  # TODO(sudhakarsingh27): Change or remove this test in favor of `Array`,
  # since `GlobalDeviceArray` is going to be deprecated.
def test_pjit_gda_non_contiguous_mesh(self):
jax.distributed.initialize()
devices = self.sorted_devices()
mesh_devices = np.array(devices[0:8:2] + devices[1:8:2] + devices[8:16:2] +
devices[9:16:2])
    # The device order in the mesh below is:
    # [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
    # The process-GPU mapping is arbitrary: @sudhakarsingh27 to figure out why.
    # The global data is:
    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
global_mesh = maps.Mesh(mesh_devices, ("x",))
global_input_shape = (16,)
mesh_axes = experimental.PartitionSpec("x")
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)
# device_id -> (index, replica_id)
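    # For example, device 1 sits at position 4 along "x" in the permuted mesh,
    # so it holds global_input_data[4:5].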
expected_idx_rid = {
0: ((slice(0, 1),), 0),
1: ((slice(4, 5),), 0),
2: ((slice(1, 2),), 0),
3: ((slice(5, 6),), 0),
4: ((slice(2, 3),), 0),
5: ((slice(6, 7),), 0),
6: ((slice(3, 4),), 0),
7: ((slice(7, 8),), 0),
8: ((slice(8, 9),), 0),
9: ((slice(12, 13),), 0),
10: ((slice(9, 10),), 0),
11: ((slice(13, 14),), 0),
12: ((slice(10, 11),), 0),
13: ((slice(14, 15),), 0),
14: ((slice(11, 12),), 0),
15: ((slice(15, 16),), 0),
}
with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
f = pjit.pjit(lambda x: x,
in_axis_resources=pjit.FROM_GDA,
out_axis_resources=mesh_axes)
out = f(gda1)
for s in out.local_shards:
device_id = s.device.id
expected_index = expected_idx_rid[device_id][0]
expected_replica_id = expected_idx_rid[device_id][1]
self.assertEqual(s.index, expected_index)
self.assertEqual(s.replica_id, expected_replica_id)
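      # (16,) split 16 ways along "x" leaves exactly one element per device.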
self.assertEqual(s.data.shape, (1,))
np.testing.assert_array_equal(np.asarray(s.data),
global_input_data[expected_index])
  # TODO(sudhakarsingh27): Change or remove this test in favor of `Array`,
  # since `GlobalDeviceArray` is going to be deprecated.
def test_pjit_gda_non_contiguous_mesh_2d(self):
jax.distributed.initialize()
global_mesh = self.create_2d_non_contiguous_mesh()
global_input_shape = (16, 2)
mesh_axes = experimental.PartitionSpec("x", "y")
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)
# device_id -> (index, replica_id)
expected_idx_rid = {
0: ((slice(0, 2), slice(0, 1)), 0),
1: ((slice(4, 6), slice(0, 1)), 0),
2: ((slice(0, 2), slice(1, 2)), 0),
3: ((slice(4, 6), slice(1, 2)), 0),
4: ((slice(2, 4), slice(0, 1)), 0),
5: ((slice(6, 8), slice(0, 1)), 0),
6: ((slice(2, 4), slice(1, 2)), 0),
7: ((slice(6, 8), slice(1, 2)), 0),
8: ((slice(8, 10), slice(0, 1)), 0),
9: ((slice(12, 14), slice(0, 1)), 0),
10: ((slice(8, 10), slice(1, 2)), 0),
11: ((slice(12, 14), slice(1, 2)), 0),
12: ((slice(10, 12), slice(0, 1)), 0),
13: ((slice(14, 16), slice(0, 1)), 0),
14: ((slice(10, 12), slice(1, 2)), 0),
15: ((slice(14, 16), slice(1, 2)), 0),
}
with global_mesh:
f = pjit.pjit(lambda x: x,
in_axis_resources=pjit.FROM_GDA,
out_axis_resources=mesh_axes)
out = f(gda1)
for s in out.local_shards:
device_id = s.device.id
expected_index = expected_idx_rid[device_id][0]
expected_replica_id = expected_idx_rid[device_id][1]
self.assertEqual(s.index, expected_index)
self.assertEqual(s.replica_id, expected_replica_id)
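      # (16, 2) over the ("x", "y") = (8, 2) mesh gives
      # (16 / 8, 2 / 2) = (2, 1) per shard.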
self.assertEqual(s.data.shape, (2, 1))
np.testing.assert_array_equal(np.asarray(s.data),
global_input_data[expected_index])
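    # The cases below check that non-GDA inputs can still use this
    # non-contiguous mesh as long as they are fully replicated.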
with global_mesh:
f = pjit.pjit(lambda x: x,
in_axis_resources=experimental.PartitionSpec(None),
out_axis_resources=mesh_axes)
      # Fully replicated values allow a non-contiguous mesh.
out = f(global_input_data)
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
with global_mesh:
f = pjit.pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=mesh_axes)
      # Fully replicated values allow a non-contiguous mesh.
out = f(global_input_data)
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
gda2 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, experimental.PartitionSpec(None), cb)
with global_mesh:
f = pjit.pjit(lambda x, y: (x, y),
in_axis_resources=(None, None),
out_axis_resources=(mesh_axes, mesh_axes))
      # Fully replicated values + GDA allow a non-contiguous mesh.
out1, out2 = f(global_input_data, gda2)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
  # TODO(sudhakarsingh27): Change or remove this test in favor of `Array`,
  # since `GlobalDeviceArray` is going to be deprecated.
def test_pjit_gda_non_contiguous_mesh_2d_aot(self):
jax.distributed.initialize()
global_mesh = self.create_2d_non_contiguous_mesh()
global_input_shape = (8, 2)
mesh_axes = experimental.PartitionSpec("x", "y")
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes,
lambda idx: global_input_data[idx])
with global_mesh:
f = pjit.pjit(lambda x, y: (x, y),
in_axis_resources=experimental.PartitionSpec("x", "y"),
out_axis_resources=experimental.PartitionSpec("x", "y"))
inp_aval = jax.ShapedArray((8, 2), jnp.int32)
# `ShapedArray` is considered global when lowered and compiled.
# Hence it can bypass the contiguous mesh restriction.
compiled = f.lower(inp_aval, gda1, _global_avals=True).compile()
out1, out2 = compiled(gda1, gda1)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (8, 2))
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (8, 2))
  # TODO(sudhakarsingh27): Change or remove this test in favor of `Array`,
  # since `GlobalDeviceArray` is going to be deprecated.
def test_pjit_gda_eval_shape(self):
jax.distributed.initialize()
    with jtu.create_global_mesh((16,), ("x",)):
@functools.partial(pjit.pjit,
in_axis_resources=experimental.PartitionSpec(None),
out_axis_resources=experimental.PartitionSpec("x"))
def f():
return jnp.zeros([32, 10])
self.assertEqual(f().shape, (32, 10))
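      # `jax.eval_shape` traces `f` abstractly, so it should report the same
      # global shape without allocating any device buffers.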
self.assertEqual(jax.eval_shape(f).shape, (32, 10))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())