Remove the GDA tests from JAX, since GDA is deprecated. Every removed GDA test has a corresponding jax.Array test.

PiperOrigin-RevId: 516881635
This commit is contained in:
Yash Katariya 2023-03-15 11:28:25 -07:00 committed by jax authors
parent 01dcd3a3fc
commit 88584290aa
7 changed files with 54 additions and 1123 deletions

View File

@ -17,7 +17,6 @@
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load(
"//jaxlib:jax.bzl",
"global_device_array_visibility",
"jax_extra_deps",
"jax_internal_packages",
"jax_test_util_visibility",
@ -511,6 +510,6 @@ pytype_library(
pytype_library(
name = "global_device_array",
srcs = ["experimental/global_device_array.py"],
visibility = [":internal"] + global_device_array_visibility,
visibility = [":internal"],
deps = [":jax"],
)

View File

@ -38,8 +38,6 @@ jax_internal_packages = []
jax_test_util_visibility = []
loops_visibility = []
global_device_array_visibility = []
def py_deps(_package):
"""Returns the Bazel deps for Python package `package`."""

View File

@ -96,7 +96,6 @@ py_test(
tags = ["manual"],
deps = [
"//jax",
"//jax:global_device_array",
"//jax:test_util",
] + py_deps("portpicker"),
)
@ -183,7 +182,6 @@ jax_test(
},
tags = ["multiaccelerator"],
deps = [
"//jax:global_device_array",
"//jax:maps",
],
)
@ -203,16 +201,6 @@ jax_test(
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
"//jax:global_device_array",
],
)
jax_test(
name = "global_device_array_test",
srcs = ["global_device_array_test.py"],
tags = ["multiaccelerator"],
deps = [
"//jax:global_device_array",
],
)

View File

@ -1,403 +0,0 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GlobalDeviceArray."""
import math
import unittest
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import core
from jax._src import test_util as jtu
from jax._src.util import safe_zip
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
import jax.experimental.global_device_array as gda_lib
from jax.experimental.global_device_array import GlobalDeviceArray, get_shard_indices
from jax.config import config
config.parse_flags_with_absl()
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
  """Builds a GlobalDeviceArray together with the host data backing it.

  Args:
    global_shape: shape of the full (unsharded) array.
    global_mesh: mesh forwarded to `GlobalDeviceArray.from_callback`.
    mesh_axes: partitioning forwarded to `GlobalDeviceArray.from_callback`.
    global_data: optional host array the shards are sliced from. When it is
      None, an `np.arange` over `math.prod(global_shape)` elements reshaped
      to `global_shape` is used.

  Returns:
    A `(gda, global_data)` tuple.
  """
  if global_data is None:
    num_elements = math.prod(global_shape)
    global_data = np.arange(num_elements).reshape(global_shape)

  def _slice_host_data(idx):
    return global_data[idx]

  gda = GlobalDeviceArray.from_callback(
      global_shape, global_mesh, mesh_axes, _slice_host_data)
  return gda, global_data
class GDATest(jtu.JaxTestCase):
@parameterized.named_parameters(
    ("mesh_x_y", P("x", "y"),
     # Only the first two shard indices are listed, for convenience; together
     # with the shard shape and replica ids they should be unique enough.
     ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
     (2, 1),
     [0, 0, 0, 0, 0, 0, 0, 0], False),
    ("mesh_x", P("x"),
     ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
     (2, 2),
     [0, 1, 0, 1, 0, 1, 0, 1], False),
    ("mesh_y", P("y"),
     ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
     (4, 2),
     [0, 0, 1, 1, 2, 2, 3, 3], False),
    ("mesh_none_y", P(None, "y"),
     ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
     (8, 1),
     [0, 0, 1, 1, 2, 2, 3, 3], False),
    ("mesh_xy", P(("x", "y")),
     ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
     (1, 2),
     [0, 0, 0, 0, 0, 0, 0, 0], False),
    ("mesh_fully_replicated", P(),
     ((slice(None), slice(None)), (slice(None), slice(None))),
     (8, 2),
     [0, 1, 2, 3, 4, 5, 6, 7], True),
)
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids, expected_is_fully_replicated):
  """Checks the shard layout of an (8, 2) GDA on a (4, 2) 'x'/'y' mesh."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  host_data = np.arange(math.prod(shape)).reshape(shape)

  gda = GlobalDeviceArray.from_callback(
      shape, mesh, mesh_axes, lambda index: host_data[index])

  # Global metadata.
  self.assertEqual(gda.ndim, 2)
  self.assertEqual(gda.size, 16)

  # The first two addressable shards carry the expected index and data.
  first_index = expected_index[0]
  second_index = expected_index[1]
  self.assertEqual(gda.addressable_shards[0].index, first_index)
  self.assertArraysEqual(gda.addressable_data(0), host_data[first_index])
  self.assertEqual(gda.addressable_shards[1].index, second_index)
  self.assertIsInstance(gda.sharding, jax.sharding.NamedSharding)
  self.assertArraysEqual(gda.addressable_data(1), host_data[second_index])
  self.assertEqual(gda.addressable_data(0).shape, expected_shard_shape)

  # Replica ids, device order and replication flag.
  self.assertListEqual(
      [shard.replica_id for shard in gda.addressable_shards],
      expected_replica_ids)
  self.assertListEqual(
      [shard.device.id for shard in gda.addressable_shards],
      [0, 1, 2, 3, 4, 5, 6, 7])
  self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)

  # Each shard's abstract value matches the expected shard shape.
  for shard in gda.addressable_shards:
    expected_aval = core.ShapedArray(expected_shard_shape, shard.data.dtype)
    self.assertEqual(shard.data.aval, expected_aval)

  # Global and addressable shards must agree pairwise.
  shard_pairs = safe_zip(gda.global_shards, gda.addressable_shards)
  for global_shard, local_shard in shard_pairs:
    self.assertEqual(global_shard.device, local_shard.device)
    self.assertEqual(global_shard.index, local_shard.index)
    self.assertEqual(global_shard.replica_id, local_shard.replica_id)
    self.assertEqual(global_shard.data.aval, local_shard.data.aval)
    self.assertArraysEqual(global_shard.data, local_shard.data)
@parameterized.named_parameters(
    ("mesh_x_y_z", P("x", "y", "z"),
     # Only the first two shard indices are listed, for convenience; together
     # with the shard shape and replica ids they should be unique enough.
     ((slice(0, 4), slice(0, 2), slice(0, 1)),
      (slice(0, 4), slice(0, 2), slice(1, 2))),
     (4, 2, 1),
     [0, 0, 0, 0, 0, 0, 0, 0]),
    ("mesh_xy_z", P(("x", "y"), "z"),
     ((slice(0, 2), slice(0, 2), slice(None)),
      (slice(0, 2), slice(2, 4), slice(None))),
     (2, 2, 2),
     [0, 0, 0, 0, 0, 0, 0, 0]),
    ("mesh_z", P("z"),
     ((slice(0, 4), slice(None), slice(None)),
      (slice(4, 8), slice(None), slice(None))),
     (4, 4, 2),
     [0, 0, 1, 1, 2, 2, 3, 3]),
)
def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids):
  """Checks the shard layout of an (8, 4, 2) GDA on a (2, 2, 2) mesh."""
  mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
  shape = (8, 4, 2)
  host_data = np.arange(math.prod(shape)).reshape(shape)

  gda = GlobalDeviceArray.from_callback(
      shape, mesh, mesh_axes, lambda index: host_data[index])

  self.assertEqual(gda.ndim, 3)
  self.assertEqual(gda.size, 64)

  first_index = expected_index[0]
  second_index = expected_index[1]
  self.assertEqual(gda.addressable_shards[0].index, first_index)
  self.assertArraysEqual(gda.addressable_data(0), host_data[first_index])
  self.assertEqual(gda.addressable_shards[1].index, second_index)
  self.assertArraysEqual(gda.addressable_data(1), host_data[second_index])
  self.assertEqual(gda.addressable_data(0).shape, expected_shard_shape)

  self.assertListEqual(
      [shard.replica_id for shard in gda.addressable_shards],
      expected_replica_ids)
@parameterized.named_parameters(
    ("mesh_x", P("x"),
     # Only the first two shard indices are listed, for convenience; together
     # with the shard shape and replica ids they should be unique enough.
     ((slice(0, 2),), (slice(2, 4),)),
     (2,),
     [0, 0, 0, 0, 0, 0, 0, 0]),
    ("mesh_none", P(),
     ((slice(None),), (slice(None),)),
     (16,),
     [0, 1, 2, 3, 4, 5, 6, 7]),
)
def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids):
  """Checks the shard layout of a (16,) GDA on a one-axis mesh of size 8.

  Args:
    mesh_axes: partition spec handed to `GlobalDeviceArray.from_callback`.
    expected_index: expected indices of the first two addressable shards.
    expected_shard_shape: expected shape of the first addressable shard.
    expected_replica_ids: expected replica id of every addressable shard.
  """
  # `('x')` is just the string 'x'; pass a real one-element tuple of axis
  # names so the call matches the multi-axis tests in this file.
  global_mesh = jtu.create_global_mesh((8,), ('x',))
  global_input_shape = (16,)
  global_input_data = np.arange(math.prod(global_input_shape)).reshape(-1)

  def cb(index):
    return global_input_data[index]

  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.ndim, 1)
  self.assertEqual(gda.size, 16)
  self.assertEqual(gda.addressable_shards[0].index, expected_index[0])
  self.assertArraysEqual(gda.addressable_data(0),
                         global_input_data[expected_index[0]])
  self.assertEqual(gda.addressable_shards[1].index, expected_index[1])
  self.assertArraysEqual(gda.addressable_data(1),
                         global_input_data[expected_index[1]])
  self.assertEqual(gda.addressable_data(0).shape, expected_shard_shape)
  replica_ids = [i.replica_id for i in gda.addressable_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)
def test_gda_shape_0_1d_mesh(self):
  """Checks a zero-sized (0,) GDA created with P(None) on a one-axis mesh."""
  # `('x')` is just the string 'x'; pass a real one-element tuple of axis
  # names so the call matches the multi-axis tests in this file.
  global_mesh = jtu.create_global_mesh((8,), ('x',))
  global_input_shape = (0,)
  mesh_axes = P(None)

  def cb(index):
    # The callback ignores the requested index and always returns an empty
    # array.
    return np.array([])

  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.ndim, 1)
  self.assertEqual(gda.size, 0)
  for i, s in enumerate(gda.addressable_shards):
    self.assertEqual(s.index, (slice(None),))
    self.assertEqual(s.replica_id, i)
    self.assertArraysEqual(np.asarray(s.data), np.array([]))

  self.assertEqual(gda.dtype, np.float32)
  self.assertEqual(
      gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
      (0,))
@parameterized.named_parameters(
    ("mesh_x_y", P("x", "y"),
     # Only the first two shard indices are listed, for convenience; together
     # with the shard shape and replica ids they should be unique enough.
     ((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
     (4, 1),
     [0, 0, 0, 0]),
)
def test_gda_subset_devices(self, mesh_axes, expected_index,
                            expected_shard_shape, expected_replica_ids):
  """Checks the shard layout of an (8, 2) GDA on a (2, 2) mesh."""
  mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
  shape = (8, 2)
  host_data = np.arange(math.prod(shape)).reshape(shape)

  gda = GlobalDeviceArray.from_callback(
      shape, mesh, mesh_axes, lambda index: host_data[index])

  first_index = expected_index[0]
  second_index = expected_index[1]
  self.assertEqual(gda.addressable_shards[0].index, first_index)
  self.assertArraysEqual(gda.addressable_data(0), host_data[first_index])
  self.assertEqual(gda.addressable_shards[1].index, second_index)
  self.assertArraysEqual(gda.addressable_data(1), host_data[second_index])
  self.assertEqual(gda.addressable_data(0).shape, expected_shard_shape)

  self.assertListEqual(
      [shard.replica_id for shard in gda.addressable_shards],
      expected_replica_ids)

  # Global and addressable shards must agree pairwise.
  shard_pairs = safe_zip(gda.global_shards, gda.addressable_shards)
  for global_shard, local_shard in shard_pairs:
    self.assertEqual(global_shard.device, local_shard.device)
    self.assertEqual(global_shard.index, local_shard.index)
    self.assertEqual(global_shard.replica_id, local_shard.replica_id)
    self.assertArraysEqual(global_shard.data, local_shard.data)
def test_gda_batched_callback(self):
  """Checks `from_batched_callback` on an (8, 2) array split over ('x', 'y')."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P(('x', 'y'))
  host_data = np.arange(math.prod(shape)).reshape(shape)

  def batched_cb(indices):
    # The callback receives one index per local device of the mesh.
    self.assertEqual(len(indices), len(mesh.local_devices))
    return [host_data[index] for index in indices]

  gda = GlobalDeviceArray.from_batched_callback(shape, mesh, spec, batched_cb)

  first_shard = np.asarray(gda.addressable_data(0))
  self.assertArraysEqual(first_shard, np.array([[0, 1]]))
  second_shard = np.asarray(gda.addressable_data(1))
  self.assertArraysEqual(second_shard, np.array([[2, 3]]))
def test_gda_batched_callback_with_devices(self):
  """Checks `from_batched_callback_with_devices` with P('x') on a (4, 2) mesh."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P('x')
  host_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

  def batched_cb(cb_inp):
    # Four (index, devices) pairs are expected, each naming two devices.
    self.assertLen(cb_inp, 4)
    buffers = []
    for index, devices in cb_inp:
      self.assertLen(devices, 2)
      shard = host_data[index]
      for device in devices:
        buffers.append(jax.device_put(shard, device))
    return buffers

  gda = GlobalDeviceArray.from_batched_callback_with_devices(
      shape, mesh, spec, batched_cb)

  # The first two addressable shards are expected to hold the same rows.
  expected_first = np.array([[0, 1], [2, 3]], dtype=np.float32)
  self.assertArraysEqual(np.asarray(gda.addressable_data(0)), expected_first)
  expected_second = np.array([[0, 1], [2, 3]], dtype=np.float32)
  self.assertArraysEqual(np.asarray(gda.addressable_data(1)), expected_second)
def test_gda_str_repr(self):
  """Checks `str()` and `repr()` of a GDA; skipped when jax_array is enabled."""
  if jax.config.jax_array:
    self.skipTest('jax.Array repr already has a test')

  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P(('x', 'y'))
  host_data = np.arange(math.prod(shape)).reshape(shape)

  gda = GlobalDeviceArray.from_callback(
      shape, mesh, spec, lambda index: host_data[index])

  expected_str = 'GlobalDeviceArray(shape=(8, 2), dtype=int32)'
  expected_repr = ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                   "global_mesh_shape={'x': 4, 'y': 2}, "
                   "mesh_axes=PartitionSpec(('x', 'y'),))")
  self.assertEqual(str(gda), expected_str)
  self.assertEqual(repr(gda), expected_repr)
def test_gda_equality_raises_not_implemented(self):
  """Checks that `==` between two GDAs raises NotImplementedError."""
  if jax.config.jax_array:
    self.skipTest('jax.Array has __eq__.')

  mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P(None,)
  host_data = np.arange(math.prod(shape)).reshape(shape)

  def make_gda():
    return GlobalDeviceArray.from_callback(
        shape, mesh, spec, lambda index: host_data[index])

  input_gda = make_gda()
  same_input_gda = make_gda()

  with self.assertRaisesRegex(
      NotImplementedError,
      'GlobalDeviceArray equality is intentionally unimplemented.'):
    # The comparison itself is the operation under test.
    input_gda == same_input_gda
def test_mesh_hash(self):
  """Checks that mesh hashes differ across shapes and match for equal shapes."""
  axis_names = ('x', 'y')
  mesh_4x2 = jtu.create_global_mesh((4, 2), axis_names)
  mesh_2x4 = jtu.create_global_mesh((2, 4), axis_names)
  mesh_4x2_again = jtu.create_global_mesh((4, 2), axis_names)

  self.assertNotEqual(hash(mesh_4x2), hash(mesh_2x4))
  self.assertEqual(hash(mesh_4x2), hash(mesh_4x2_again))
def test_device_mismatch(self):
  """Checks that buffers ordered differently from the mesh devices are rejected."""
  all_devices = jax.devices()
  if len(all_devices) < 8:
    self.skipTest("Test requires 8 global devices.")

  # Deliberately permuted device order for the (4, 2) mesh.
  permuted = np.array([[all_devices[0], all_devices[2]],
                       [all_devices[3], all_devices[1]],
                       [all_devices[4], all_devices[6]],
                       [all_devices[7], all_devices[5]]])
  mesh = Mesh(permuted, ('x', 'y'))
  shape = (8, 2)
  spec = P('x', 'y')
  host_data = np.arange(math.prod(shape)).reshape(shape)

  # Buffers are built in `jax.local_devices()` order rather than mesh order.
  shard_indices = get_shard_indices(shape, mesh, spec)
  buffers = []
  for device in jax.local_devices():
    buffers.append(jax.device_put(host_data[shard_indices[device]], device))

  with self.assertRaisesRegex(
      ValueError,
      'The `global_mesh.local_devices` and `device_buffers` device order'):
    GlobalDeviceArray(shape, mesh, spec, buffers)
def test_gda_block_until_ready(self):
  """Checks that `block_until_ready()` returns the same GDA object."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P(('x', 'y'))
  host_data = np.arange(math.prod(shape)).reshape(shape)

  gda = GlobalDeviceArray.from_callback(
      shape, mesh, spec, lambda index: host_data[index])

  ready = gda.block_until_ready()
  self.assertIs(ready, gda)
@parameterized.named_parameters(
    ("mesh_x_y", P("x", "y")),
    ("mesh_x", P("x")),
    ("mesh_y", P("y")),
    ("mesh_none_y", P(None, "y")),
    ("mesh_xy", P(("x", "y"))),
    ("mesh_fully_replicated", P()),
)
def test_gda_value(self, mesh_axes):
  """Checks that `gda._value` equals the host data for each partitioning."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  gda, host_data = create_gda((8, 2), mesh, mesh_axes)
  self.assertArraysEqual(gda._value, host_data)
def test_gda_delete(self):
  """Checks that `_check_if_deleted()` raises once the GDA has been deleted."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  gda, _ = create_gda((8, 2), mesh, P("x", "y"))

  # Must not raise before deletion.
  gda._check_if_deleted()
  gda.delete()

  # The type name in the error message depends on the jax_array flag.
  arr_type = 'Array' if jax.config.jax_array else 'GlobalDeviceArray'
  with self.assertRaisesRegex(RuntimeError, f"{arr_type} has been deleted."):
    gda._check_if_deleted()
# Run the suite with JAX's test loader when this file is executed directly.
if __name__ == '__main__':
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)

View File

@ -33,7 +33,6 @@ from jax._src import distributed
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import global_device_array
from jax.experimental import pjit
try:
@ -338,43 +337,37 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
return global_input_data[index]
mesh_axes1 = experimental.PartitionSpec("x", "y")
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes1, cb)
gda1 = jax.make_array_from_callback(
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes1), cb)
mesh_axes2 = experimental.PartitionSpec("x")
gda2 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes2, cb)
gda2 = jax.make_array_from_callback(
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes2), cb)
mesh_axes3 = experimental.PartitionSpec(("x", "y"))
gda3 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes3, cb)
gda3 = jax.make_array_from_callback(
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes3), cb)
with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
@functools.partial(
pjit.pjit,
# `FROM_GDA` will be replicated for all the inputs.
in_shardings=pjit.FROM_GDA,
out_shardings=(mesh_axes1, None, mesh_axes2))
def f(x, y, z):
return x @ x.T, y, z
out1, out2, out3 = f(gda1, gda2, gda3)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (16, 16))
self.assertEqual(out1.addressable_shards[0].data.shape, (2, 8))
self.assertDictEqual(out1.mesh.shape, {"x": 8, "y": 2})
expected_matrix_mul = global_input_data @ global_input_data.T
for s in out1.addressable_shards:
np.testing.assert_array_equal(np.asarray(s.data),
expected_matrix_mul[s.index])
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (16, 2))
self.assertEqual(out2.addressable_shards[0].data.shape, (16, 2))
for s in out2.addressable_shards:
np.testing.assert_array_equal(np.asarray(s.data), global_input_data)
self.assertIsInstance(out3, global_device_array.GlobalDeviceArray)
self.assertEqual(out3.shape, (16, 2))
self.assertEqual(out3.addressable_shards[0].data.shape, (2, 2))
for s in out3.addressable_shards:
@ -403,8 +396,8 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
def cb(index):
return global_input_data[index]
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)
gda1 = jax.make_array_from_callback(
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), cb)
# device_id -> (index, replica_id)
expected_idx_rid = {
@ -454,8 +447,8 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
def cb(index):
return global_input_data[index]
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)
gda1 = jax.make_array_from_callback(
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), cb)
# device_id -> (index, replica_id)
expected_idx_rid = {
@ -501,16 +494,14 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
)
# Fully replicated values allows a non-contiguous mesh.
out = f(global_input_data)
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
with global_mesh:
f = pjit.pjit(lambda x: x, in_shardings=None, out_shardings=mesh_axes)
# Fully replicated values allows a non-contiguous mesh.
out = f(global_input_data)
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
gda2 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, experimental.PartitionSpec(None), cb)
gda2 = jax.make_array_from_callback(
global_input_shape, jax.sharding.NamedSharding(global_mesh, experimental.PartitionSpec(None)), cb)
with global_mesh:
f = pjit.pjit(
@ -520,8 +511,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
)
# Fully replicated values + GDA allows a non-contiguous mesh.
out1, out2 = f(global_input_data, gda2)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
# since `GlobalDeviceArray` is going to be deprecated in the future
@ -532,8 +521,8 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
mesh_axes = experimental.PartitionSpec("x", "y")
global_input_data = np.arange(
util.prod(global_input_shape)).reshape(global_input_shape)
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes,
gda1 = jax.make_array_from_callback(
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes),
lambda idx: global_input_data[idx])
with global_mesh:
@ -547,9 +536,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
# Hence it can bypass the contiguous mesh restriction.
compiled = f.lower(inp_aval, gda1).compile()
out1, out2 = compiled(gda1, gda1)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (8, 2))
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
self.assertEqual(out2.shape, (8, 2))
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`

View File

@ -32,7 +32,7 @@ import jax.numpy as jnp
from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.config import parallel_functions_output_gda, jax_array
from jax._src.config import jax_array
from jax import dtypes
from jax import stages
from jax.errors import JAXTypeError
@ -41,14 +41,13 @@ from jax.lax import with_sharding_constraint
from jax import prng
from jax.sharding import PartitionSpec as P
from jax.experimental.maps import xmap
from jax.experimental import global_device_array
from jax.experimental import multihost_utils
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import Sharding
from jax._src.sharding_impls import NamedSharding, GSPMDSharding
import jax._src.pjit as pjit_lib
from jax._src.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
from jax._src.pjit import (pjit, pjit_p, AUTO)
from jax._src import mesh
from jax._src.interpreters import pxla
from jax.interpreters import mlir
@ -84,19 +83,6 @@ def tearDownModule():
jtu.restore_spmd_lowering_flag()
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None,
               dtype=np.float32):
  """Builds a GlobalDeviceArray together with the host data backing it.

  Args:
    global_shape: shape of the full (unsharded) array.
    global_mesh: mesh forwarded to `GlobalDeviceArray.from_callback`.
    mesh_axes: a partition spec, or a `Sharding` whose `.spec` is used.
    global_data: optional host array the shards are sliced from. When it is
      None, an `np.arange` of `dtype` reshaped to `global_shape` is used.
    dtype: dtype of the generated data; only used when `global_data` is None.

  Returns:
    A `(gda, global_data)` tuple.
  """
  if global_data is None:
    num_elements = math.prod(global_shape)
    global_data = np.arange(num_elements, dtype=dtype).reshape(global_shape)

  # Accept a Sharding object by unwrapping its partition spec.
  spec = mesh_axes.spec if isinstance(mesh_axes, Sharding) else mesh_axes

  def _slice_host_data(idx):
    return global_data[idx]

  gda = global_device_array.GlobalDeviceArray.from_callback(
      global_shape, global_mesh, spec, _slice_host_data)
  return gda, global_data
def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
dtype=np.float32):
if global_data is None:
@ -1194,400 +1180,6 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertArraysEqual(result0, result1)
self.assertArraysEqual(result1, result2)
@jtu.pytest_mark_if_available('multiaccelerator')
class GDAPjitTest(jtu.JaxTestCase):
def setUp(self):
  """Runs the base setup, then skips every test when jax_array is enabled."""
  super().setUp()
  array_enabled = config.jax_array
  if array_enabled:
    self.skipTest('GDA and Array cannot be enabled together.')
def test_pjit_gda_single_output(self):
  """Checks pjit with `FROM_GDA` on a single GDA input and a single output."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P('x', 'y')
  host_data = np.arange(math.prod(shape)).reshape(shape)

  gda_in = global_device_array.GlobalDeviceArray.from_callback(
      shape, mesh, spec, lambda index: host_data[index])

  with parallel_functions_output_gda(True), mesh:
    @partial(pjit, in_shardings=FROM_GDA, out_shardings=P('x', 'y'))
    def f(x):
      return x @ x.T

    expected = host_data @ host_data.T

    result = f(gda_in)
    self.assertIsInstance(result, global_device_array.GlobalDeviceArray)
    self.assertEqual(result.shape, (8, 8))
    self.assertEqual(result.addressable_shards[0].data.shape, (2, 4))
    self.assertDictEqual(result.mesh.shape, {'x': 4, 'y': 2})
    for shard in result.addressable_shards:
      self.assertArraysEqual(shard.data, expected[shard.index])

    # Feeding the GDA output back in must produce a GDA again.
    chained = f(result)
    self.assertIsInstance(chained, global_device_array.GlobalDeviceArray)

    # A plain numpy input is rejected when the in sharding is FROM_GDA.
    with self.assertRaisesRegex(
        ValueError, ('For a non-GDA input, the corresponding resource in '
                     'in_axis_resources cannot be `pjit.FROM_GDA`.')):
      f(host_data)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_multi_input_multi_output(self):
  """Checks pjit with four GDA inputs and four differently sharded outputs."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  host_data = np.arange(math.prod(shape)).reshape(shape)

  def make_gda(spec):
    return global_device_array.GlobalDeviceArray.from_callback(
        shape, mesh, spec, lambda index: host_data[index])

  mesh_axes1 = P('x', 'y')
  gda1 = make_gda(mesh_axes1)
  mesh_axes2 = P('x')
  gda2 = make_gda(mesh_axes2)
  mesh_axes3 = P(('x', 'y'))
  gda3 = make_gda(mesh_axes3)
  mesh_axes4 = P(None)
  gda4 = make_gda(mesh_axes4)

  gda_type = global_device_array.GlobalDeviceArray

  with parallel_functions_output_gda(True):
    @partial(
        pjit,
        # `FROM_GDA` will be replicated for all the inputs.
        in_shardings=FROM_GDA,
        out_shardings=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
    def f(x, y, z, a):
      return x @ x.T, y, z, a

    out1, out2, out3, out4 = f(gda1, gda2, gda3, gda4)

    # Output 1: the matmul result, sharded with mesh_axes1.
    shards1 = out1.addressable_shards
    self.assertIsInstance(out1, gda_type)
    self.assertEqual(out1.shape, (8, 8))
    self.assertEqual(shards1[0].data.shape, (2, 4))
    self.assertEqual(shards1[0].index, (slice(0, 2), slice(0, 4)))
    self.assertEqual(shards1[1].index, (slice(0, 2), slice(4, 8)))
    self.assertListEqual([shard.replica_id for shard in shards1],
                         [0, 0, 0, 0, 0, 0, 0, 0])
    expected_matmul = host_data @ host_data.T
    for shard in shards1:
      self.assertArraysEqual(shard.data, expected_matmul[shard.index])

    # Output 2: the second input, resharded with mesh_axes4 (replicated).
    shards2 = out2.addressable_shards
    self.assertIsInstance(out2, gda_type)
    self.assertEqual(out2.shape, (8, 2))
    self.assertEqual(shards2[0].data.shape, (8, 2))
    self.assertEqual(shards2[0].index, (slice(None), slice(None)))
    self.assertEqual(shards2[1].index, (slice(None), slice(None)))
    self.assertListEqual([shard.replica_id for shard in shards2],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    for shard in shards2:
      self.assertArraysEqual(shard.data, host_data)

    # Output 3: the third input, resharded with mesh_axes2.
    shards3 = out3.addressable_shards
    self.assertIsInstance(out3, gda_type)
    self.assertEqual(out3.shape, (8, 2))
    self.assertEqual(shards3[0].data.shape, (2, 2))
    self.assertEqual(shards3[0].index, (slice(0, 2), slice(None)))
    self.assertEqual(shards3[1].index, (slice(0, 2), slice(None)))
    self.assertListEqual([shard.replica_id for shard in shards3],
                         [0, 1, 0, 1, 0, 1, 0, 1])
    for shard in shards3:
      self.assertArraysEqual(shard.data, host_data[shard.index])

    # Output 4: the fourth input, resharded with mesh_axes3.
    shards4 = out4.addressable_shards
    self.assertIsInstance(out4, gda_type)
    self.assertEqual(out4.shape, (8, 2))
    self.assertEqual(shards4[0].data.shape, (1, 2))
    self.assertEqual(shards4[0].index, (slice(0, 1), slice(None)))
    self.assertEqual(shards4[1].index, (slice(1, 2), slice(None)))
    self.assertListEqual([shard.replica_id for shard in shards4],
                         [0, 0, 0, 0, 0, 0, 0, 0])
    for shard in shards4:
      self.assertArraysEqual(shard.data, host_data[shard.index])
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_mixed_inputs(self):
  """Checks pjit with one GDA input (`FROM_GDA`) and one numpy input."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P('x', 'y')
  host_data = np.arange(math.prod(shape)).reshape(shape)

  gda_in = global_device_array.GlobalDeviceArray.from_callback(
      shape, mesh, spec, lambda index: host_data[index])

  with parallel_functions_output_gda(True):
    @partial(pjit,
             in_shardings=(FROM_GDA, P('x', 'y')),
             out_shardings=(P('x', 'y'), P(('x', 'y'))))
    def f(x, y):
      return x @ x.T, y @ y.T

    expected = host_data @ host_data.T
    out1, out2 = f(gda_in, host_data)

    # First output is sharded over 'x' and 'y' separately.
    self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
    self.assertEqual(out1.shape, (8, 8))
    self.assertEqual(out1.addressable_shards[0].data.shape, (2, 4))
    self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
    for shard in out1.addressable_shards:
      self.assertArraysEqual(shard.data, expected[shard.index])

    # Second output is sharded over the combined ('x', 'y') axes.
    self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
    self.assertEqual(out2.shape, (8, 8))
    self.assertEqual(out2.addressable_shards[0].data.shape, (1, 8))
    self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
    for shard in out2.addressable_shards:
      self.assertArraysEqual(shard.data, expected[shard.index])
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_non_gda_inputs(self):
  """Checks that pjit on numpy inputs returns GDAs when GDA output is enabled."""
  shape = (8, 2)
  host_data = np.arange(math.prod(shape)).reshape(shape)

  with parallel_functions_output_gda(True):
    @partial(pjit,
             in_shardings=(None, P('x', 'y')),
             out_shardings=(P('x', 'y'), P(('x', 'y'))))
    def f(x, y):
      return x @ x.T, y @ y.T

    expected = host_data @ host_data.T
    out1, out2 = f(host_data, host_data)

    # First output is sharded over 'x' and 'y' separately.
    self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
    self.assertEqual(out1.shape, (8, 8))
    self.assertEqual(out1.addressable_shards[0].data.shape, (2, 4))
    self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
    for shard in out1.addressable_shards:
      self.assertArraysEqual(shard.data, expected[shard.index])

    # Second output is sharded over the combined ('x', 'y') axes.
    self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
    self.assertEqual(out2.shape, (8, 8))
    self.assertEqual(out2.addressable_shards[0].data.shape, (1, 8))
    self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
    for shard in out2.addressable_shards:
      self.assertArraysEqual(shard.data, expected[shard.index])
@jtu.with_mesh([('x', 2), ('y', 2)])
def test_pjit_gda_mesh_mismatch(self):
  """Checks that a GDA built on a (4, 2) mesh is rejected under a (2, 2) mesh."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P('x', 'y')
  host_data = np.arange(
      math.prod(shape), dtype=np.float32).reshape(shape)

  gda_in = global_device_array.GlobalDeviceArray.from_callback(
      shape, mesh, spec, lambda index: host_data[index])

  with self.assertRaisesRegex(ValueError,
                              "Pjit's mesh and GDA's mesh should be equal."):
    @partial(pjit, in_shardings=FROM_GDA, out_shardings=P('x', 'y'))
    def f(x):
      return x

    f(gda_in)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_wrong_resource_for_gda_input(self):
  """Checks the error raised when in_shardings disagrees with the GDA's spec."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  spec = P('x')
  host_data = np.arange(
      math.prod(shape), dtype=np.float32).reshape(shape)

  gda_in = global_device_array.GlobalDeviceArray.from_callback(
      shape, mesh, spec, lambda index: host_data[index])

  expected_error = (
      r"Got an input GDA to pjit with different partitioning than specified "
      r'in the in_axis_resources argument to pjit. The partitioning must match, or '
      r'use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. '
      r"Got GDA sharding.*PartitionSpec\('x',\).*and "
      r"pjit sharding.*PartitionSpec\('x', 'y'\).*")
  with self.assertRaisesRegex(ValueError, expected_error):
    @partial(pjit, in_shardings=P('x', 'y'), out_shardings=P('x', 'y'))
    def f(x):
      return x

    f(gda_in)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_caching(self):
  """Checks `_pjit_lower_cached` hit/miss counts across GDA and numpy calls."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  mesh_axes = P('x', 'y')
  host_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

  gda_in = global_device_array.GlobalDeviceArray.from_callback(
      shape, mesh, mesh_axes, lambda index: host_data[index])

  @partial(pjit, in_shardings=mesh_axes, out_shardings=P('x', 'y'))
  def f(x, y):
    return x @ y.T

  stats_before = pjit_lib._pjit_lower_cached.cache_info()

  # First call with GDA inputs: one new miss, no new hit.
  f(gda_in, gda_in)
  stats_first = pjit_lib._pjit_lower_cached.cache_info()
  self.assertEqual(stats_before.hits, stats_first.hits)
  self.assertEqual(stats_before.misses + 1, stats_first.misses)

  # Same call again: one new hit, no new miss.
  f(gda_in, gda_in)
  stats_second = pjit_lib._pjit_lower_cached.cache_info()
  self.assertEqual(stats_first.hits + 1, stats_second.hits)
  self.assertEqual(stats_first.misses, stats_second.misses)

  # Numpy inputs: one new miss.
  f(host_data, host_data)
  stats_third = pjit_lib._pjit_lower_cached.cache_info()
  self.assertEqual(stats_second.hits, stats_third.hits)
  self.assertEqual(stats_second.misses + 1, stats_third.misses)

  # Mixed GDA and numpy inputs: one more new miss.
  f(gda_in, host_data)
  stats_fourth = pjit_lib._pjit_lower_cached.cache_info()
  self.assertEqual(stats_third.hits, stats_fourth.hits)
  self.assertEqual(stats_third.misses + 1, stats_fourth.misses)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_partition_spec_mismatch_semantically_equivalent(self):
  """An output GDA with empty mesh_axes is accepted by a P(None) pjit input."""
  shape = (8, 2)
  replicated = P(None)
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

  with parallel_functions_output_gda(True):
    gda_obj = global_device_array.GlobalDeviceArray.from_callback(
        shape, global_mesh, replicated, lambda index: data[index])

    @partial(pjit, in_shardings=P(None), out_shardings=P(None))
    def f(x):
      return x

    output_gda = f(gda_obj)
    # The output reports empty mesh_axes even though P(None) was requested.
    self.assertEqual(output_gda.mesh_axes, ())
    # Feeding it back in, against the P(None) input spec, must not raise.
    f(output_gda)
def test_from_gda_duplicates(self):
  """pjit accepts a second `_FromGdaSingleton` instance as in_shardings."""
  shape = (8, 2)
  global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
  input_gda, _ = create_gda(shape, global_mesh, P('x', 'y'))

  # More than one FROM_GDA singleton can occasionally exist (e.g. when
  # in_axis_resources is pickled and sent to other processes). A duplicate
  # must not cause an error, to avoid confusing users.
  duplicate_from_gda = pjit_lib._FromGdaSingleton()
  identity = lambda x: x
  with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
    jitted = pjit(identity, in_shardings=duplicate_from_gda,
                  out_shardings=None)
    jitted(input_gda)
def test_no_recompilation_due_to_in_axis_resources(self):
  """Re-feeding the output GDA hits the lowering cache instead of missing."""
  shape = (8, 2)
  spec = P(None,)
  global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
  input_gda, _ = create_gda(shape, global_mesh, spec)

  with parallel_functions_output_gda(True):

    @partial(pjit, in_shardings=spec, out_shardings=spec)
    def f(x):
      return x

    with global_mesh:
      out_gda = f(input_gda)
      self.assertEqual(out_gda.mesh_axes, ())

      # The second call takes the output GDA; only the hit counter may move.
      stats_before = pjit_lib._pjit_lower_cached.cache_info()
      f(out_gda)
      stats_after = pjit_lib._pjit_lower_cached.cache_info()

      self.assertEqual(stats_before.hits + 1, stats_after.hits)
      self.assertEqual(stats_before.misses, stats_after.misses)
def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
  """A NumPy input followed by its output GDA reuses the cached lowering."""
  shape = (8, 2)
  spec = P(None)
  global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
  host_data = np.arange(math.prod(shape)).reshape(shape)

  with parallel_functions_output_gda(True):
    identity = pjit(lambda x: x, in_shardings=spec, out_shardings=spec)

    with global_mesh:
      out_gda = identity(host_data)
      self.assertEqual(out_gda.mesh_axes, ())

      # Calling again with the GDA output: expect one hit and no new miss.
      stats_before = pjit_lib._pjit_lower_cached.cache_info()
      identity(out_gda)
      stats_after = pjit_lib._pjit_lower_cached.cache_info()

      self.assertEqual(stats_before.hits + 1, stats_after.hits)
      self.assertEqual(stats_before.misses, stats_after.misses)
def test_pjit_gda_aot_sharding_mismatch(self):
  """A compiled P('x') pjit rejects a GDA that is sharded as P('x', 'y')."""
  shape = (8, 2)
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  input_gda, _ = create_gda(shape, global_mesh, P('x', 'y'))

  with global_mesh:
    identity = pjit(lambda x: x, in_shardings=P('x'), out_shardings=P('x'))
    # Lower and compile ahead of time from an abstract value of this shape.
    abstract_input = core.ShapedArray(shape, jnp.float32)
    compiled = identity.lower(abstract_input).compile()
    expected_error = "GDA sharding does not match the input sharding."
    with self.assertRaisesRegex(ValueError, expected_error):
      compiled(input_gda)
def test_pjit_gda_same_sharding_aot(self):
  """A compiled P(None) pjit runs on a GDA created with P(None,)."""
  shape = (8, 2)
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  replicated_gda, _ = create_gda(shape, global_mesh, P(None,))

  with global_mesh:
    identity = pjit(lambda x: x, in_shardings=P(None), out_shardings=P('x'))
    abstract_input = core.ShapedArray(shape, jnp.float32)
    compiled = identity.lower(abstract_input).compile()
    # Must complete without raising.
    compiled(replicated_gda)
@parallel_functions_output_gda(True)
def test_globally_sharded_key_array_8x4_multi_device(self):
  """Seeding keys from a sharded uint32 GDA under pjit yields a KeyArray."""
  input_shape = (8, 4)
  spec = P('x', 'y')
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  seeds, _ = create_gda(input_shape, mesh, spec, dtype=np.uint32)

  with mesh:

    @partial(pjit, in_shardings=spec, out_shardings=spec)
    def make_keys(seeds):
      # Seed with the threefry PRNG implementation.
      return partial(prng.seed_with_impl, prng.threefry_prng_impl)(seeds)

    keys = make_keys(seeds)
    self.assertIsInstance(keys, jax.random.KeyArray)
    self.assertEqual(keys.shape, input_shape)
    # Only checks that extracting the raw array does not crash.
    keys.unsafe_raw_array()
@jtu.pytest_mark_if_available('multiaccelerator')
class AutoShardingPjitTest(jtu.JaxTestCase):
@ -1601,34 +1193,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
config.update('jax_array', self.jax_array_enabled)
super().tearDown()
@parameterized.named_parameters(
    ('2d_gda', (4, 2), (4, 2), ('x', 'y')),
    # TODO(b/226977360): Support 3D mesh shape for example (2, 2, 2).
    ('3d_gda', (1, 4, 2), (2, 4, 8, 4), ('x', 'y', 'z')),
    ('1d_gda', (8,), (8, 2), ('x')),
)
def test_pjit_arr_auto_sharding_gda(self, mesh_shape, global_input_shape,
                                    mesh_axis_names):
  """An AUTO-sharded identity pjit round-trips GDA inputs unchanged."""
  if xla_bridge.get_backend().runtime_type == 'stream_executor':
    raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
  if config.jax_array:
    raise unittest.SkipTest('GDA and Array cannot be together.')

  global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
  num_elements = math.prod(global_input_shape)
  input_data = np.arange(num_elements, dtype=np.float32).reshape(
      global_input_shape)

  with parallel_functions_output_gda(True), global_mesh:
    identity = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
    abstract_input = core.ShapedArray(input_data.shape, input_data.dtype)
    compiled = identity.lower(abstract_input).compile()

    # Build one GDA per input sharding reported by the compiled executable.
    inputs = []
    for input_sharding in compiled.input_shardings[0]:
      gda, _ = create_gda(
          global_input_shape, global_mesh, input_sharding, input_data)[:2]
      inputs.append(gda)

    out = compiled(*inputs)
    self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
    self.assertArraysEqual(out._value, input_data)
@parameterized.named_parameters(
('2d_array', (4, 2), (4, 2), ('x', 'y')),
# TODO(b/226977360): Support 3D mesh shape for example (2, 2, 2).
@ -1655,37 +1219,29 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out._value, input_data)
@parameterized.named_parameters(
('gda', parallel_functions_output_gda, create_gda),
('array', jax_array, create_array),
)
def test_xla_arr_sharding_mismatch(self, ctx, create_fun):
def test_xla_arr_sharding_mismatch(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
if not jax.config.jax_array:
self.skipTest("Test requires jax.Array")
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (4, 2)
input_data = np.arange(
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with ctx(True):
with global_mesh:
f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
with global_mesh:
f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
different_pspec = (P('y', 'x')
if compiled.input_shardings[0][0].spec == P(('x',), ('y',))
else P('x', 'y'))
arr, _ = create_fun(global_input_shape, global_mesh, different_pspec,
different_pspec = (P('y', 'x')
if compiled.input_shardings[0][0].spec == P(('x',), ('y',))
else P('x', 'y'))
arr, _ = create_array(global_input_shape, global_mesh, different_pspec,
input_data)
if jax.config.jax_array:
arr_type = 'Array'
else:
arr_type = 'GDA'
with self.assertRaisesRegex(
ValueError,
f"{arr_type} sharding does not match the input sharding."):
compiled(arr)
with self.assertRaisesRegex(
ValueError, "Array sharding does not match the input sharding."):
compiled(arr)
def test_gda_auto_shardings_len(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
@ -1702,41 +1258,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
self.assertLen(compiled.output_shardings, 3)
self.assertLen(compiled.input_shardings[0], 3)
@parameterized.named_parameters(
    ('3d_gda', (1, 1, 2), ('x', 'y', 'z'), P(('x', 'y', 'z'))),
    ('2d_gda', (4, 2), ('x', 'y'), P('y', 'x')),
    ('1d_gda', (8,), ('x'), P('x')),
)
def test_pjit_arr_partial_auto_sharding_gda(
    self, mesh_shape, mesh_axis_names, pspec):
  """Mixes one explicit input spec with one AUTO input and checks both outputs."""
  if xla_bridge.get_backend().runtime_type == 'stream_executor':
    raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
  if config.jax_array:
    raise unittest.SkipTest('GDA and Array cannot be together.')

  global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
  global_input_shape = (8, 4)
  num_elements = math.prod(global_input_shape)
  input_data = np.arange(num_elements, dtype=np.float32).reshape(
      global_input_shape)

  with parallel_functions_output_gda(True), global_mesh:
    # The first argument uses the parameterized spec; the second is AUTO.
    passthrough = pjit(
        lambda x, y: (x, y),
        in_shardings=(pspec, AUTO),
        out_shardings=AUTO,
    )
    abstract_input = core.ShapedArray(input_data.shape, input_data.dtype)
    compiled = passthrough.lower(abstract_input, abstract_input).compile()

    # Build one GDA per input sharding reported by the compiled executable.
    inputs = []
    for input_sharding in compiled.input_shardings[0]:
      created = create_gda(
          global_input_shape, global_mesh, input_sharding, input_data)
      inputs.append(created[0])

    for output in compiled(*inputs):
      self.assertIsInstance(output, global_device_array.GlobalDeviceArray)
      self.assertArraysEqual(output._value, input_data)
@parameterized.named_parameters(
('3d_array', (1, 1, 2), ('x', 'y', 'z'), P(('x', 'y', 'z'))),
('2d_array', (4, 2), ('x', 'y'), P('y', 'x')),

View File

@ -35,7 +35,6 @@ from jax import lax
from jax._src import core
from jax._src.core import NamedShape
from jax.experimental import maps
from jax.experimental import global_device_array
from jax._src import array
from jax._src.sharding_impls import NamedSharding
from jax.experimental.pjit import pjit, with_sharding_constraint
@ -46,7 +45,7 @@ from jax._src import config as jax_config
from jax._src.nn import initializers as nn_initializers
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.util import unzip2, safe_zip
from jax._src.util import unzip2
from jax._src.lax import parallel as lax_parallel
from jax._src.lax.parallel import pgather
from jax.interpreters import batching, pxla
@ -349,24 +348,26 @@ class XMapTest(XMapTestCase):
v = jnp.arange(np.prod(vshape)).reshape(vshape)
zxy = fxy(v)
if config.jax_array:
zxy_sharding_spec = global_device_array._get_sharding_spec(
zxy.shape, zxy.sharding.mesh, zxy.sharding.spec)
zxy_op_sharding = zxy.sharding._to_xla_op_sharding(zxy.ndim)
self.assertListEqual(zxy_op_sharding.tile_assignment_dimensions, [1, 4])
self.assertListEqual(zxy_op_sharding.tile_assignment_devices, [0, 1, 2, 3])
else:
zxy_sharding_spec = zxy.sharding_spec
self.assertEqual(
zxy_sharding_spec,
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
self.assertEqual(
zxy_sharding_spec,
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
zyx = fyx(v)
if config.jax_array:
zyx_sharding_spec = global_device_array._get_sharding_spec(
zyx.shape, zyx.sharding.mesh, zyx.sharding.spec)
zyx_op_sharding = zyx.sharding._to_xla_op_sharding(zyx.ndim)
self.assertListEqual(zyx_op_sharding.tile_assignment_dimensions, [1, 4])
self.assertListEqual(zyx_op_sharding.tile_assignment_devices, [0, 2, 1, 3])
else:
zyx_sharding_spec = zyx.sharding_spec
self.assertEqual(
zyx_sharding_spec,
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
self.assertEqual(
zyx_sharding_spec,
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
@jtu.with_mesh([('x', 2), ('y', 2)])
def testSkipFirstMeshDim(self):
@ -442,15 +443,16 @@ class XMapTest(XMapTestCase):
(1, 2, 0)), (x * 2).sum((0, 1))))
if config.jax_array:
sharding_spec = global_device_array._get_sharding_spec(
y[0].shape, y[0].sharding.mesh, y[0].sharding.spec)
y_op_sharding = y[0].sharding._to_xla_op_sharding(y[0].ndim)
m_size = math.prod([2] + [2] * (len(mesh) - 2))
self.assertListEqual(y_op_sharding.tile_assignment_dimensions, [2, 1, 1, m_size])
else:
sharding_spec = y[0].sharding_spec
self.assertEqual(sharding_spec.sharding,
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
self.assertEqual(sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)) +
(pxla.Replicated(2),) * (len(mesh) - 2))
self.assertEqual(sharding_spec.sharding,
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
self.assertEqual(sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)) +
(pxla.Replicated(2),) * (len(mesh) - 2))
if config.experimental_xmap_spmd_lowering:
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
# Make sure that there are non-partial sharding specs in the HLO
@ -1105,167 +1107,6 @@ class NamedNNTest(XMapTestCase):
atol=1e-4, rtol=2e-2)
@jtu.pytest_mark_if_available('multiaccelerator')
class XMapGDATest(XMapTestCase):
  """xmap tests taking GlobalDeviceArray inputs; skipped when jax_array is on."""

  def setUp(self):
    super().setUp()
    if config.jax_array:
      self.skipTest('GDA and Array cannot be enabled together.')

  @staticmethod
  def _arange_data(shape):
    # Integer ramp reshaped to `shape`, used as the global input.
    return np.arange(math.prod(shape)).reshape(shape)

  @staticmethod
  def _gda_from_data(shape, mesh, mesh_axes, data):
    # Builds a GDA whose shards are slices of `data`.
    return global_device_array.GlobalDeviceArray.from_callback(
        shape, mesh, mesh_axes, lambda index: data[index])

  def _check_gda(self, out, shape, shard_shape, expected):
    # Shared output checks: type, global shape, first shard shape, mesh shape
    # and per-shard contents against `expected`.
    self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
    self.assertEqual(out.shape, shape)
    self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
    self.assertDictEqual(out.mesh.shape, {'x': 4, 'y': 2})
    for shard in out.addressable_shards:
      self.assertArraysEqual(shard.data, expected[shard.index])

  @jtu.with_mesh([('x', 4), ('y', 2)])
  def test_basic(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    input_data = self._arange_data(global_input_shape)
    gda_obj = self._gda_from_data(
        global_input_shape, global_mesh, P('x', 'y'), input_data)

    with jax_config.parallel_functions_output_gda(True):
      identity = maps.xmap(
          lambda x: x,
          in_axes=({0: "a", 1: "b"}),
          out_axes=({0: "a", 1: "b"}),
          axis_resources={"a": "x", "b": "y"})

      out = identity(gda_obj)
      self._check_gda(out, (8, 2), (2, 1), input_data)

  @jtu.with_mesh([('x', 4), ('y', 2)])
  def test_xmap_gda_mixed_inputs(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    input_data = self._arange_data(global_input_shape)
    gda_obj = self._gda_from_data(
        global_input_shape, global_mesh, P('x'), input_data)

    with jax_config.parallel_functions_output_gda(True):
      f = maps.xmap(
          lambda x, y: (x @ x.T, y @ y.T),
          in_axes=({0: "a"}, ["c", ...]),
          out_axes=({0: "a"}, ["c", ...]),
          axis_resources={"a": "x", "c": "x"})

      expected_matrix_mul = np.diagonal(input_data @ input_data.T)
      # One GDA input and one NumPy input.
      out1, out2 = f(gda_obj, input_data)

      for out in (out1, out2):
        self._check_gda(out, (8,), (2,), expected_matrix_mul)

      for shard1, shard2 in safe_zip(out1.addressable_shards,
                                     out2.addressable_shards):
        self.assertArraysEqual(shard1.data, shard2.data)

  @jtu.with_mesh([('x', 4), ('y', 2)])
  def test_xmap_gda_double_input(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    input_data = self._arange_data(global_input_shape)
    gda_obj1 = self._gda_from_data(
        global_input_shape, global_mesh, P('x'), input_data)
    gda_obj2 = self._gda_from_data(
        global_input_shape, global_mesh, P('y'), input_data)

    with jax_config.parallel_functions_output_gda(True):
      f = maps.xmap(
          lambda x, y: (x @ x.T, y @ y.T),
          in_axes=({0: "a"}, ["c", ...]),
          out_axes=({0: "a"}, ["c", ...]),
          axis_resources={"a": "x", "c": "y"})

      expected_matrix_mul = np.diagonal(input_data @ input_data.T)
      out1, out2 = f(gda_obj1, gda_obj2)

      # The two outputs are expected to differ only in their shard shape.
      for out, shard_shape in ((out1, (2,)), (out2, (4,))):
        self._check_gda(out, (8,), shard_shape, expected_matrix_mul)

  @jtu.with_mesh([('x', 4), ('y', 2)])
  def test_xmap_gda_sharding_mismatch(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    input_data = self._arange_data(global_input_shape)
    gda_obj = self._gda_from_data(
        global_input_shape, global_mesh, P('x', 'y'), input_data)

    with jax_config.parallel_functions_output_gda(True):
      f = maps.xmap(
          lambda x: x @ x.T,
          in_axes=({0: "a"}),
          out_axes=({0: "a"}),
          axis_resources={"a": "x"})

      expected_error = (
          'Got an input GDA to xmap with different partitioning than '
          'specified in xmap. The partitioning must match.')
      with self.assertRaisesRegex(ValueError, expected_error):
        f(gda_obj)

  def test_gda_from_pjit_with_xmap_sharding_mismatch(self):
    global_mesh = jtu.create_global_mesh((8, 1), ('x', 'y'))
    global_input_shape = (8, 2)
    input_data = self._arange_data(global_input_shape)
    gda_obj = self._gda_from_data(
        global_input_shape, global_mesh, P('x', 'y'), input_data)

    with jax_config.parallel_functions_output_gda(True), global_mesh:
      pjit_out = pjit(
          lambda x: x, in_shardings=P('x', 'y'), out_shardings=P('x', 'y')
      )(gda_obj)

      # Passing the pjit output straight into xmap must not crash.
      xmap_out = maps.xmap(
          lambda x: x,
          in_axes=({0: "a", 1: "b"}),
          out_axes=({0: "a", 1: "b"}),
          axis_resources={"a": "x", "b": "y"})(pjit_out)
      self.assertArraysEqual(xmap_out, input_data)
@jtu.pytest_mark_if_available('multiaccelerator')
class XMapArrayTest(XMapTestCase):