rocm_jax/tests/global_device_array_test.py
2022-03-07 08:59:23 -08:00

362 lines
15 KiB
Python

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GlobalDeviceArray."""
from absl.testing import absltest
from absl.testing import parameterized
import unittest
import numpy as np
import jax
from jax import core
from jax._src import test_util as jtu
from jax._src.util import prod, safe_zip
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
import jax.experimental.global_device_array as gda_lib
from jax.experimental.global_device_array import GlobalDeviceArray, get_shard_indices
from jax.config import config
# NOTE(review): presumably wires JAX's config flags into absl's flag parsing so
# they can be set on the test command line — confirm against `jax.config`.
config.parse_flags_with_absl()
class GDATest(jtu.JaxTestCase):
  """Tests for constructing and inspecting `GlobalDeviceArray` objects.

  Most tests build a global mesh with `jtu.create_global_mesh`, create a
  `GlobalDeviceArray` from a callback that slices a NumPy array, and then
  check the resulting shards (index, data, shape, replica id).
  """

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices, but for convenience only the first 2 are
       # checked. The indices + shard_shape + replica_id should be unique
       # enough.
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x", P("x"),
       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
       (2, 2),
       [0, 1, 0, 1, 0, 1, 0, 1], False),
      ("mesh_y", P("y"),
       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
       (4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_none_y", P(None, "y"),
       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
       (8, 1),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_xy", P(("x", "y")),
       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
       (1, 2),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_fully_replicated", P(),
       ((slice(None), slice(None)), (slice(None), slice(None))),
       (8, 2),
       [0, 1, 2, 3, 4, 5, 6, 7], True),
  )
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids, expected_is_fully_replicated):
    """Checks the shards of an (8, 2) array on a (4, 2) mesh.

    Args:
      mesh_axes: PartitionSpec handed to `GlobalDeviceArray.from_callback`.
      expected_index: expected indices of the first two local shards.
      expected_shard_shape: expected shape of a single shard.
      expected_replica_ids: expected `replica_id` of every local shard, in
        local-shard order.
      expected_is_fully_replicated: expected value of `is_fully_replicated`.
    """
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 2)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.mesh_axes, mesh_axes)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in gda.local_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
    for s in gda.local_shards:
      self.assertEqual(s.data.aval,
                       core.ShapedArray(expected_shard_shape, s.data.dtype))
    # Global and local shards are compared pairwise, field by field.
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)

  @parameterized.named_parameters(
      ("mesh_x_y_z", P("x", "y", "z"),
       # There are more slices, but for convenience only the first 2 are
       # checked. The indices + shard_shape + replica_id should be unique
       # enough.
       ((slice(0, 4), slice(0, 2), slice(0, 1)),
        (slice(0, 4), slice(0, 2), slice(1, 2))),
       (4, 2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_xy_z", P(("x", "y"), "z"),
       ((slice(0, 2), slice(0, 2), slice(None)),
        (slice(0, 2), slice(2, 4), slice(None))),
       (2, 2, 2),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_z", P("z"),
       ((slice(0, 4), slice(None), slice(None)),
        (slice(4, 8), slice(None), slice(None))),
       (4, 4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3]),
  )
  def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids):
    """Checks the shards of an (8, 4, 2) array on a (2, 2, 2) mesh.

    Args:
      mesh_axes: PartitionSpec handed to `GlobalDeviceArray.from_callback`.
      expected_index: expected indices of the first two local shards.
      expected_shard_shape: expected shape of a single shard.
      expected_replica_ids: expected `replica_id` of every local shard, in
        local-shard order.
    """
    global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
    global_input_shape = (8, 4, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 3)
    self.assertEqual(gda.size, 64)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  @parameterized.named_parameters(
      ("mesh_x", P("x"),
       # There are more slices, but for convenience only the first 2 are
       # checked. The indices + shard_shape + replica_id should be unique
       # enough.
       ((slice(0, 2),), (slice(2, 4),)),
       (2,),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_none", P(),
       ((slice(None),), (slice(None),)),
       (16,),
       [0, 1, 2, 3, 4, 5, 6, 7]),
  )
  def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids):
    """Checks the shards of a (16,) array on an (8,) mesh.

    Args:
      mesh_axes: PartitionSpec handed to `GlobalDeviceArray.from_callback`.
      expected_index: expected indices of the first two local shards.
      expected_shard_shape: expected shape of a single shard.
      expected_replica_ids: expected `replica_id` of every local shard, in
        local-shard order.
    """
    # Axis names are passed as a real 1-tuple; `('x')` is just the string 'x'.
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (16,)
    global_input_data = np.arange(prod(global_input_shape)).reshape(-1)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  def test_gda_shape_0_1d_mesh(self):
    """Checks a zero-sized (0,) array with `P(None)` on an (8,) mesh."""
    # Axis names are passed as a real 1-tuple; `('x')` is just the string 'x'.
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (0,)
    mesh_axes = P(None)

    def cb(index):
      return np.array([])

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 0)
    for i, s in enumerate(gda.local_shards):
      self.assertEqual(s.index, (slice(None),))
      self.assertEqual(s.replica_id, i)
      self.assertArraysEqual(s.data.to_py(), np.array([]))
    self.assertEqual(gda.dtype, np.float32)
    self.assertEqual(
        gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
        (0,))

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices, but for convenience only the first 2 are
       # checked. The indices + shard_shape + replica_id should be unique
       # enough.
       ((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
       (4, 1),
       [0, 0, 0, 0]),
  )
  def test_gda_subset_devices(self, mesh_axes, expected_index,
                              expected_shard_shape, expected_replica_ids):
    """Checks the shards of an (8, 2) array on a (2, 2) mesh (4 devices).

    Args:
      mesh_axes: PartitionSpec handed to `GlobalDeviceArray.from_callback`.
      expected_index: expected indices of the first two local shards.
      expected_shard_shape: expected shape of a single shard.
      expected_replica_ids: expected `replica_id` of every local shard, in
        local-shard order.
    """
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    # Global and local shards are compared pairwise, field by field.
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertArraysEqual(g.data, l.data)

  def test_gda_batched_callback(self):
    """Checks `from_batched_callback` with one index per local device."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(indices):
      # The callback must receive exactly one index per local device.
      self.assertEqual(len(indices), len(global_mesh.local_devices))
      return [global_input_data[index] for index in indices]

    gda = GlobalDeviceArray.from_batched_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1]])
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[2, 3]])
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_batched_callback_with_devices(self):
    """Checks `from_batched_callback_with_devices` with `P('x')` sharding.

    The callback is expected to receive 4 `(index, devices)` pairs with 2
    devices each, and returns one device buffer per device.
    """
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x')
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

    def cb(cb_inp):
      self.assertLen(cb_inp, 4)
      dbs = []
      for inp in cb_inp:
        index, devices = inp
        self.assertLen(devices, 2)
        array = global_input_data[index]
        dbs.extend([jax.device_put(array, device) for device in devices])
      return dbs

    gda = GlobalDeviceArray.from_batched_callback_with_devices(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    # The first two local shards are expected to hold identical data.
    expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_str_repr(self):
    """Checks the exact `str()` and `repr()` output of a GlobalDeviceArray."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    self.assertEqual(str(gda),
                     'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
    self.assertEqual(
        repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                    "global_mesh_shape={'x': 4, 'y': 2}, "
                    "mesh_axes=PartitionSpec(('x', 'y'),))"))

  def test_gda_equality_raises_not_implemented(self):
    """Checks that `==` between two GDAs raises `NotImplementedError`."""
    global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(None,)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    same_input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    with self.assertRaisesRegex(
        NotImplementedError,
        'GlobalDeviceArray equality is intentionally unimplemented.'):
      # The comparison is evaluated only for the exception it raises.
      input_gda == same_input_gda

  def test_mesh_hash(self):
    """Checks that mesh hashes differ by shape and match for equal shapes."""
    global_mesh1 = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_mesh2 = jtu.create_global_mesh((2, 4), ('x', 'y'))
    global_mesh3 = jtu.create_global_mesh((4, 2), ('x', 'y'))
    self.assertNotEqual(hash(global_mesh1), hash(global_mesh2))
    self.assertEqual(hash(global_mesh1), hash(global_mesh3))

  def test_device_mismatch(self):
    """Checks that mis-ordered device buffers are rejected with ValueError.

    The mesh is built from a permuted device order, while the buffers are
    created by iterating `jax.local_devices()`.
    """
    devices = jax.devices()
    if len(devices) < 8:
      raise unittest.SkipTest("Test requires 8 global devices.")
    mesh_devices = np.array([[devices[0], devices[2]],
                             [devices[3], devices[1]],
                             [devices[4], devices[6]],
                             [devices[7], devices[5]]])
    global_mesh = Mesh(mesh_devices, ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)

    dbs = [
        jax.device_put(global_input_data[indices[d]], d)
        for d in jax.local_devices()
    ]

    with self.assertRaisesRegex(
        ValueError,
        'The `global_mesh.local_devices` and `device_buffers` device order'):
      GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)

  def test_gda_block_until_ready(self):
    """Checks that `block_until_ready()` returns the same GDA object."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(gda.block_until_ready(), gda)
if __name__ == '__main__':
  # Run the tests in this module through absl, using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)