mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
PiperOrigin-RevId: 516881635
This commit is contained in:
parent
01dcd3a3fc
commit
88584290aa
@ -17,7 +17,6 @@
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"global_device_array_visibility",
|
||||
"jax_extra_deps",
|
||||
"jax_internal_packages",
|
||||
"jax_test_util_visibility",
|
||||
@ -511,6 +510,6 @@ pytype_library(
|
||||
pytype_library(
|
||||
name = "global_device_array",
|
||||
srcs = ["experimental/global_device_array.py"],
|
||||
visibility = [":internal"] + global_device_array_visibility,
|
||||
visibility = [":internal"],
|
||||
deps = [":jax"],
|
||||
)
|
||||
|
@ -38,8 +38,6 @@ jax_internal_packages = []
|
||||
jax_test_util_visibility = []
|
||||
loops_visibility = []
|
||||
|
||||
global_device_array_visibility = []
|
||||
|
||||
def py_deps(_package):
|
||||
"""Returns the Bazel deps for Python package `package`."""
|
||||
|
||||
|
12
tests/BUILD
12
tests/BUILD
@ -96,7 +96,6 @@ py_test(
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:global_device_array",
|
||||
"//jax:test_util",
|
||||
] + py_deps("portpicker"),
|
||||
)
|
||||
@ -183,7 +182,6 @@ jax_test(
|
||||
},
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:global_device_array",
|
||||
"//jax:maps",
|
||||
],
|
||||
)
|
||||
@ -203,16 +201,6 @@ jax_test(
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
"//jax:global_device_array",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "global_device_array_test",
|
||||
srcs = ["global_device_array_test.py"],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:global_device_array",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,403 +0,0 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for GlobalDeviceArray."""
|
||||
|
||||
import math
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import safe_zip
|
||||
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.sharding import Mesh
|
||||
import jax.experimental.global_device_array as gda_lib
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray, get_shard_indices
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
|
||||
if global_data is None:
|
||||
global_data = np.arange(math.prod(global_shape)).reshape(global_shape)
|
||||
|
||||
return GlobalDeviceArray.from_callback(
|
||||
global_shape, global_mesh, mesh_axes, lambda idx: global_data[idx]), global_data
|
||||
|
||||
|
||||
class GDATest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", P("x", "y"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
|
||||
(2, 1),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0], False),
|
||||
("mesh_x", P("x"),
|
||||
((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
|
||||
(2, 2),
|
||||
[0, 1, 0, 1, 0, 1, 0, 1], False),
|
||||
("mesh_y", P("y"),
|
||||
((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
|
||||
(4, 2),
|
||||
[0, 0, 1, 1, 2, 2, 3, 3], False),
|
||||
("mesh_none_y", P(None, "y"),
|
||||
((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
|
||||
(8, 1),
|
||||
[0, 0, 1, 1, 2, 2, 3, 3], False),
|
||||
("mesh_xy", P(("x", "y")),
|
||||
((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
|
||||
(1, 2),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0], False),
|
||||
("mesh_fully_replicated", P(),
|
||||
((slice(None), slice(None)), (slice(None), slice(None))),
|
||||
(8, 2),
|
||||
[0, 1, 2, 3, 4, 5, 6, 7], True),
|
||||
)
|
||||
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
|
||||
expected_replica_ids, expected_is_fully_replicated):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 2)
|
||||
self.assertEqual(gda.size, 16)
|
||||
self.assertEqual(gda.addressable_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.addressable_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
self.assertEqual(gda.addressable_shards[1].index, expected_index[1])
|
||||
self.assertIsInstance(gda.sharding, jax.sharding.NamedSharding)
|
||||
self.assertArraysEqual(gda.addressable_data(1),
|
||||
global_input_data[expected_index[1]])
|
||||
self.assertEqual(gda.addressable_data(0).shape, expected_shard_shape)
|
||||
replica_ids = [i.replica_id for i in gda.addressable_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
self.assertListEqual([i.device.id for i in gda.addressable_shards],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
|
||||
for s in gda.addressable_shards:
|
||||
self.assertEqual(s.data.aval,
|
||||
core.ShapedArray(expected_shard_shape, s.data.dtype))
|
||||
for g, l in safe_zip(gda.global_shards, gda.addressable_shards):
|
||||
self.assertEqual(g.device, l.device)
|
||||
self.assertEqual(g.index, l.index)
|
||||
self.assertEqual(g.replica_id, l.replica_id)
|
||||
self.assertEqual(g.data.aval, l.data.aval)
|
||||
self.assertArraysEqual(g.data, l.data)
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y_z", P("x", "y", "z"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 4), slice(0, 2), slice(0, 1)), (slice(0, 4), slice(0, 2), slice(1, 2))),
|
||||
(4, 2, 1),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
("mesh_xy_z", P(("x", "y"), "z"),
|
||||
((slice(0, 2), slice(0, 2), slice(None)), (slice(0, 2), slice(2, 4), slice(None))),
|
||||
(2, 2, 2),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
("mesh_z", P("z"),
|
||||
((slice(0, 4), slice(None), slice(None)), (slice(4, 8), slice(None), slice(None))),
|
||||
(4, 4, 2),
|
||||
[0, 0, 1, 1, 2, 2, 3, 3]),
|
||||
)
|
||||
def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
|
||||
expected_replica_ids):
|
||||
global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
global_input_shape = (8, 4, 2)
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 3)
|
||||
self.assertEqual(gda.size, 64)
|
||||
self.assertEqual(gda.addressable_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.addressable_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
self.assertEqual(gda.addressable_shards[1].index, expected_index[1])
|
||||
self.assertArraysEqual(gda.addressable_data(1),
|
||||
global_input_data[expected_index[1]])
|
||||
self.assertEqual(gda.addressable_data(0).shape, expected_shard_shape)
|
||||
|
||||
replica_ids = [i.replica_id for i in gda.addressable_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x", P("x"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 2),), (slice(2, 4),)),
|
||||
(2,),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
("mesh_none", P(),
|
||||
((slice(None),), (slice(None),)),
|
||||
(16,),
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
)
|
||||
def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
|
||||
expected_replica_ids):
|
||||
global_mesh = jtu.create_global_mesh((8,), ('x'))
|
||||
global_input_shape = (16,)
|
||||
global_input_data = np.arange(math.prod(global_input_shape)).reshape(-1)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 1)
|
||||
self.assertEqual(gda.size, 16)
|
||||
self.assertEqual(gda.addressable_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.addressable_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
self.assertEqual(gda.addressable_shards[1].index, expected_index[1])
|
||||
self.assertArraysEqual(gda.addressable_data(1),
|
||||
global_input_data[expected_index[1]])
|
||||
self.assertEqual(gda.addressable_data(0).shape, expected_shard_shape)
|
||||
replica_ids = [i.replica_id for i in gda.addressable_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
|
||||
def test_gda_shape_0_1d_mesh(self):
|
||||
global_mesh = jtu.create_global_mesh((8,), ('x'))
|
||||
global_input_shape = (0,)
|
||||
mesh_axes = P(None)
|
||||
def cb(index):
|
||||
return np.array([])
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 1)
|
||||
self.assertEqual(gda.size, 0)
|
||||
for i, s in enumerate(gda.addressable_shards):
|
||||
self.assertEqual(s.index, (slice(None),))
|
||||
self.assertEqual(s.replica_id, i)
|
||||
self.assertArraysEqual(np.asarray(s.data), np.array([]))
|
||||
self.assertEqual(gda.dtype, np.float32)
|
||||
self.assertEqual(
|
||||
gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
|
||||
(0,))
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", P("x", "y"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
|
||||
(4, 1),
|
||||
[0, 0, 0, 0]),
|
||||
)
|
||||
def test_gda_subset_devices(self, mesh_axes, expected_index,
|
||||
expected_shard_shape, expected_replica_ids):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.addressable_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.addressable_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
self.assertEqual(gda.addressable_shards[1].index, expected_index[1])
|
||||
self.assertArraysEqual(gda.addressable_data(1),
|
||||
global_input_data[expected_index[1]])
|
||||
self.assertEqual(gda.addressable_data(0).shape, expected_shard_shape)
|
||||
replica_ids = [i.replica_id for i in gda.addressable_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
for g, l in safe_zip(gda.global_shards, gda.addressable_shards):
|
||||
self.assertEqual(g.device, l.device)
|
||||
self.assertEqual(g.index, l.index)
|
||||
self.assertEqual(g.replica_id, l.replica_id)
|
||||
self.assertArraysEqual(g.data, l.data)
|
||||
|
||||
def test_gda_batched_callback(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P(('x', 'y'))
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
|
||||
def cb(indices):
|
||||
self.assertEqual(len(indices), len(global_mesh.local_devices))
|
||||
return [global_input_data[index] for index in indices]
|
||||
|
||||
gda = GlobalDeviceArray.from_batched_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
expected_first_shard_value = np.array([[0, 1]])
|
||||
self.assertArraysEqual(np.asarray(gda.addressable_data(0)),
|
||||
expected_first_shard_value)
|
||||
expected_second_shard_value = np.array([[2, 3]])
|
||||
self.assertArraysEqual(np.asarray(gda.addressable_data(1)),
|
||||
expected_second_shard_value)
|
||||
|
||||
def test_gda_batched_callback_with_devices(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x')
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
|
||||
def cb(cb_inp):
|
||||
self.assertLen(cb_inp, 4)
|
||||
dbs = []
|
||||
for inp in cb_inp:
|
||||
index, devices = inp
|
||||
self.assertLen(devices, 2)
|
||||
array = global_input_data[index]
|
||||
dbs.extend([jax.device_put(array, device) for device in devices])
|
||||
return dbs
|
||||
|
||||
gda = GlobalDeviceArray.from_batched_callback_with_devices(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
|
||||
self.assertArraysEqual(np.asarray(gda.addressable_data(0)),
|
||||
expected_first_shard_value)
|
||||
expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
|
||||
self.assertArraysEqual(np.asarray(gda.addressable_data(1)),
|
||||
expected_second_shard_value)
|
||||
|
||||
def test_gda_str_repr(self):
|
||||
if jax.config.jax_array:
|
||||
self.skipTest('jax.Array repr already has a test')
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P(('x', 'y'))
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
gda = GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
self.assertEqual(str(gda),
|
||||
'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
|
||||
self.assertEqual(
|
||||
repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
|
||||
"global_mesh_shape={'x': 4, 'y': 2}, "
|
||||
"mesh_axes=PartitionSpec(('x', 'y'),))"))
|
||||
|
||||
def test_gda_equality_raises_not_implemented(self):
|
||||
if jax.config.jax_array:
|
||||
self.skipTest('jax.Array has __eq__.')
|
||||
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P(None,)
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
input_gda = GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
same_input_gda = GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
with self.assertRaisesRegex(NotImplementedError,
|
||||
'GlobalDeviceArray equality is intentionally unimplemented.'):
|
||||
input_gda == same_input_gda
|
||||
|
||||
def test_mesh_hash(self):
|
||||
global_mesh1 = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
|
||||
global_mesh2 = jtu.create_global_mesh((2, 4), ('x', 'y'))
|
||||
|
||||
global_mesh3 = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
|
||||
self.assertNotEqual(hash(global_mesh1), hash(global_mesh2))
|
||||
self.assertEqual(hash(global_mesh1), hash(global_mesh3))
|
||||
|
||||
def test_device_mismatch(self):
|
||||
devices = jax.devices()
|
||||
if len(devices) < 8:
|
||||
raise unittest.SkipTest("Test requires 8 global devices.")
|
||||
mesh_devices = np.array([[devices[0], devices[2]],
|
||||
[devices[3], devices[1]],
|
||||
[devices[4], devices[6]],
|
||||
[devices[7], devices[5]]])
|
||||
global_mesh = Mesh(mesh_devices, ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)
|
||||
|
||||
dbs = [
|
||||
jax.device_put(global_input_data[indices[d]], d)
|
||||
for d in jax.local_devices()
|
||||
]
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'The `global_mesh.local_devices` and `device_buffers` device order'):
|
||||
GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)
|
||||
|
||||
def test_gda_block_until_ready(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P(('x', 'y'))
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
self.assertIs(gda.block_until_ready(), gda)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", P("x", "y")),
|
||||
("mesh_x", P("x")),
|
||||
("mesh_y", P("y")),
|
||||
("mesh_none_y", P(None, "y")),
|
||||
("mesh_xy", P(("x", "y"))),
|
||||
("mesh_fully_replicated", P()),
|
||||
)
|
||||
def test_gda_value(self, mesh_axes):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
gda, global_data = create_gda(input_shape, global_mesh, mesh_axes)
|
||||
self.assertArraysEqual(gda._value, global_data)
|
||||
|
||||
def test_gda_delete(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
gda, _ = create_gda(input_shape, global_mesh, P("x", "y"))
|
||||
gda._check_if_deleted()
|
||||
gda.delete()
|
||||
if jax.config.jax_array:
|
||||
arr_type = 'Array'
|
||||
else:
|
||||
arr_type = 'GlobalDeviceArray'
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
f"{arr_type} has been deleted."):
|
||||
gda._check_if_deleted()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -33,7 +33,6 @@ from jax._src import distributed
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax.experimental import global_device_array
|
||||
from jax.experimental import pjit
|
||||
|
||||
try:
|
||||
@ -338,43 +337,37 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
return global_input_data[index]
|
||||
|
||||
mesh_axes1 = experimental.PartitionSpec("x", "y")
|
||||
gda1 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes1, cb)
|
||||
gda1 = jax.make_array_from_callback(
|
||||
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes1), cb)
|
||||
mesh_axes2 = experimental.PartitionSpec("x")
|
||||
gda2 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes2, cb)
|
||||
gda2 = jax.make_array_from_callback(
|
||||
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes2), cb)
|
||||
mesh_axes3 = experimental.PartitionSpec(("x", "y"))
|
||||
gda3 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes3, cb)
|
||||
gda3 = jax.make_array_from_callback(
|
||||
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes3), cb)
|
||||
|
||||
with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
|
||||
|
||||
@functools.partial(
|
||||
pjit.pjit,
|
||||
# `FROM_GDA` will be replicated for all the inputs.
|
||||
in_shardings=pjit.FROM_GDA,
|
||||
out_shardings=(mesh_axes1, None, mesh_axes2))
|
||||
def f(x, y, z):
|
||||
return x @ x.T, y, z
|
||||
|
||||
out1, out2, out3 = f(gda1, gda2, gda3)
|
||||
|
||||
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out1.shape, (16, 16))
|
||||
self.assertEqual(out1.addressable_shards[0].data.shape, (2, 8))
|
||||
self.assertDictEqual(out1.mesh.shape, {"x": 8, "y": 2})
|
||||
expected_matrix_mul = global_input_data @ global_input_data.T
|
||||
for s in out1.addressable_shards:
|
||||
np.testing.assert_array_equal(np.asarray(s.data),
|
||||
expected_matrix_mul[s.index])
|
||||
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out2.shape, (16, 2))
|
||||
self.assertEqual(out2.addressable_shards[0].data.shape, (16, 2))
|
||||
for s in out2.addressable_shards:
|
||||
np.testing.assert_array_equal(np.asarray(s.data), global_input_data)
|
||||
|
||||
self.assertIsInstance(out3, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out3.shape, (16, 2))
|
||||
self.assertEqual(out3.addressable_shards[0].data.shape, (2, 2))
|
||||
for s in out3.addressable_shards:
|
||||
@ -403,8 +396,8 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda1 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
gda1 = jax.make_array_from_callback(
|
||||
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), cb)
|
||||
|
||||
# device_id -> (index, replica_id)
|
||||
expected_idx_rid = {
|
||||
@ -454,8 +447,8 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda1 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
gda1 = jax.make_array_from_callback(
|
||||
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes), cb)
|
||||
|
||||
# device_id -> (index, replica_id)
|
||||
expected_idx_rid = {
|
||||
@ -501,16 +494,14 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
)
|
||||
# Fully replicated values allows a non-contiguous mesh.
|
||||
out = f(global_input_data)
|
||||
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
|
||||
|
||||
with global_mesh:
|
||||
f = pjit.pjit(lambda x: x, in_shardings=None, out_shardings=mesh_axes)
|
||||
# Fully replicated values allows a non-contiguous mesh.
|
||||
out = f(global_input_data)
|
||||
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
|
||||
|
||||
gda2 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, experimental.PartitionSpec(None), cb)
|
||||
gda2 = jax.make_array_from_callback(
|
||||
global_input_shape, jax.sharding.NamedSharding(global_mesh, experimental.PartitionSpec(None)), cb)
|
||||
|
||||
with global_mesh:
|
||||
f = pjit.pjit(
|
||||
@ -520,8 +511,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
)
|
||||
# Fully replicated values + GDA allows a non-contiguous mesh.
|
||||
out1, out2 = f(global_input_data, gda2)
|
||||
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
|
||||
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
|
||||
# since `GlobalDeviceArray` is going to be deprecated in the future
|
||||
@ -532,8 +521,8 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
mesh_axes = experimental.PartitionSpec("x", "y")
|
||||
global_input_data = np.arange(
|
||||
util.prod(global_input_shape)).reshape(global_input_shape)
|
||||
gda1 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes,
|
||||
gda1 = jax.make_array_from_callback(
|
||||
global_input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes),
|
||||
lambda idx: global_input_data[idx])
|
||||
|
||||
with global_mesh:
|
||||
@ -547,9 +536,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
|
||||
# Hence it can bypass the contiguous mesh restriction.
|
||||
compiled = f.lower(inp_aval, gda1).compile()
|
||||
out1, out2 = compiled(gda1, gda1)
|
||||
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out1.shape, (8, 2))
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out2.shape, (8, 2))
|
||||
|
||||
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
|
||||
|
@ -32,7 +32,7 @@ import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.config import parallel_functions_output_gda, jax_array
|
||||
from jax._src.config import jax_array
|
||||
from jax import dtypes
|
||||
from jax import stages
|
||||
from jax.errors import JAXTypeError
|
||||
@ -41,14 +41,13 @@ from jax.lax import with_sharding_constraint
|
||||
from jax import prng
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental import global_device_array
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
from jax._src import array
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import NamedSharding, GSPMDSharding
|
||||
import jax._src.pjit as pjit_lib
|
||||
from jax._src.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
|
||||
from jax._src.pjit import (pjit, pjit_p, AUTO)
|
||||
from jax._src import mesh
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import mlir
|
||||
@ -84,19 +83,6 @@ def tearDownModule():
|
||||
jtu.restore_spmd_lowering_flag()
|
||||
|
||||
|
||||
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None,
|
||||
dtype=np.float32):
|
||||
if global_data is None:
|
||||
global_data = np.arange(
|
||||
math.prod(global_shape), dtype=dtype).reshape(global_shape)
|
||||
|
||||
if isinstance(mesh_axes, Sharding):
|
||||
mesh_axes = mesh_axes.spec
|
||||
|
||||
return global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_shape, global_mesh, mesh_axes, lambda idx: global_data[idx]), global_data
|
||||
|
||||
|
||||
def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
|
||||
dtype=np.float32):
|
||||
if global_data is None:
|
||||
@ -1194,400 +1180,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertArraysEqual(result0, result1)
|
||||
self.assertArraysEqual(result1, result2)
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class GDAPjitTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
|
||||
def test_pjit_gda_single_output(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
with global_mesh:
|
||||
@partial(pjit, in_shardings=FROM_GDA, out_shardings=P('x', 'y'))
|
||||
def f(x):
|
||||
return x @ x.T
|
||||
expected_matrix_mul = input_data @ input_data.T
|
||||
|
||||
out = f(gda_obj)
|
||||
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out.shape, (8, 8))
|
||||
self.assertEqual(out.addressable_shards[0].data.shape, (2, 4))
|
||||
self.assertDictEqual(out.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
out2 = f(out)
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, ('For a non-GDA input, the corresponding resource in '
|
||||
'in_axis_resources cannot be `pjit.FROM_GDA`.')):
|
||||
f(input_data)
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_pjit_gda_multi_input_multi_output(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
mesh_axes1 = P('x', 'y')
|
||||
gda1 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes1, cb)
|
||||
mesh_axes2 = P('x')
|
||||
gda2 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes2, cb)
|
||||
mesh_axes3 = P(('x', 'y'))
|
||||
gda3 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes3, cb)
|
||||
mesh_axes4 = P(None)
|
||||
gda4 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes4, cb)
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
@partial(
|
||||
pjit,
|
||||
# `FROM_GDA` will be replicated for all the inputs.
|
||||
in_shardings=FROM_GDA,
|
||||
out_shardings=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
|
||||
def f(x, y, z, a):
|
||||
return x @ x.T, y, z, a
|
||||
out1, out2, out3, out4 = f(gda1, gda2, gda3, gda4)
|
||||
|
||||
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out1.shape, (8, 8))
|
||||
self.assertEqual(out1.addressable_shards[0].data.shape, (2, 4))
|
||||
self.assertEqual(out1.addressable_shards[0].index, (slice(0, 2), slice(0, 4)))
|
||||
self.assertEqual(out1.addressable_shards[1].index, (slice(0, 2), slice(4, 8)))
|
||||
self.assertListEqual([s.replica_id for s in out1.addressable_shards],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0])
|
||||
expected_matrix_mul = input_data @ input_data.T
|
||||
for s in out1.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out2.shape, (8, 2))
|
||||
self.assertEqual(out2.addressable_shards[0].data.shape, (8, 2))
|
||||
self.assertEqual(out2.addressable_shards[0].index, (slice(None), slice(None)))
|
||||
self.assertEqual(out2.addressable_shards[1].index, (slice(None), slice(None)))
|
||||
self.assertListEqual([s.replica_id for s in out2.addressable_shards],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7])
|
||||
for s in out2.addressable_shards:
|
||||
self.assertArraysEqual(s.data, input_data)
|
||||
|
||||
self.assertIsInstance(out3, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out3.shape, (8, 2))
|
||||
self.assertEqual(out3.addressable_shards[0].data.shape, (2, 2))
|
||||
self.assertEqual(out3.addressable_shards[0].index, (slice(0, 2), slice(None)))
|
||||
self.assertEqual(out3.addressable_shards[1].index, (slice(0, 2), slice(None)))
|
||||
self.assertListEqual([s.replica_id for s in out3.addressable_shards],
|
||||
[0, 1, 0, 1, 0, 1, 0, 1])
|
||||
for s in out3.addressable_shards:
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
|
||||
self.assertIsInstance(out4, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out4.shape, (8, 2))
|
||||
self.assertEqual(out4.addressable_shards[0].data.shape, (1, 2))
|
||||
self.assertEqual(out4.addressable_shards[0].index, (slice(0, 1), slice(None)))
|
||||
self.assertEqual(out4.addressable_shards[1].index, (slice(1, 2), slice(None)))
|
||||
self.assertListEqual([s.replica_id for s in out4.addressable_shards],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0])
|
||||
for s in out4.addressable_shards:
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_pjit_gda_mixed_inputs(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
@partial(pjit,
|
||||
in_shardings=(FROM_GDA, P('x', 'y')),
|
||||
out_shardings=(P('x', 'y'), P(('x', 'y'))))
|
||||
def f(x, y):
|
||||
return x @ x.T, y @ y.T
|
||||
expected_matrix_mul = input_data @ input_data.T
|
||||
|
||||
out1, out2 = f(gda_obj, input_data)
|
||||
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out1.shape, (8, 8))
|
||||
self.assertEqual(out1.addressable_shards[0].data.shape, (2, 4))
|
||||
self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out1.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out2.shape, (8, 8))
|
||||
self.assertEqual(out2.addressable_shards[0].data.shape, (1, 8))
|
||||
self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out2.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_pjit_gda_non_gda_inputs(self):
|
||||
input_shape = (8, 2)
|
||||
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
@partial(pjit,
|
||||
in_shardings=(None, P('x', 'y')),
|
||||
out_shardings=(P('x', 'y'), P(('x', 'y'))))
|
||||
def f(x, y):
|
||||
return x @ x.T, y @ y.T
|
||||
|
||||
expected_matrix_mul = input_data @ input_data.T
|
||||
out1, out2 = f(input_data, input_data)
|
||||
|
||||
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out1.shape, (8, 8))
|
||||
self.assertEqual(out1.addressable_shards[0].data.shape, (2, 4))
|
||||
self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out1.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out2.shape, (8, 8))
|
||||
self.assertEqual(out2.addressable_shards[0].data.shape, (1, 8))
|
||||
self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out2.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def test_pjit_gda_mesh_mismatch(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape), dtype=np.float32
|
||||
).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Pjit's mesh and GDA's mesh should be equal."):
|
||||
@partial(pjit, in_shardings=FROM_GDA, out_shardings=P('x', 'y'))
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
f(gda_obj)
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_pjit_gda_wrong_resource_for_gda_input(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x')
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape), dtype=np.float32
|
||||
).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Got an input GDA to pjit with different partitioning than specified "
|
||||
r'in the in_axis_resources argument to pjit. The partitioning must match, or '
|
||||
r'use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. '
|
||||
r"Got GDA sharding.*PartitionSpec\('x',\).*and "
|
||||
r"pjit sharding.*PartitionSpec\('x', 'y'\).*"):
|
||||
@partial(pjit, in_shardings=P('x', 'y'), out_shardings=P('x', 'y'))
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
f(gda_obj)
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_pjit_gda_caching(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_data = np.arange(
|
||||
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
@partial(pjit, in_shardings=mesh_axes, out_shardings=P('x', 'y'))
|
||||
def f(x, y):
|
||||
return x @ y.T
|
||||
|
||||
before_lower_cache = pjit_lib._pjit_lower_cached.cache_info()
|
||||
|
||||
f(gda_obj, gda_obj)
|
||||
after_lower_cache1 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertEqual(before_lower_cache.hits, after_lower_cache1.hits)
|
||||
self.assertEqual(before_lower_cache.misses + 1, after_lower_cache1.misses)
|
||||
|
||||
f(gda_obj, gda_obj)
|
||||
after_lower_cache2 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertEqual(after_lower_cache1.hits + 1, after_lower_cache2.hits)
|
||||
self.assertEqual(after_lower_cache1.misses, after_lower_cache2.misses)
|
||||
|
||||
f(input_data, input_data)
|
||||
after_lower_cache3 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertEqual(after_lower_cache2.hits, after_lower_cache3.hits)
|
||||
self.assertEqual(after_lower_cache2.misses + 1, after_lower_cache3.misses)
|
||||
|
||||
f(gda_obj, input_data)
|
||||
after_lower_cache4 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertEqual(after_lower_cache3.hits, after_lower_cache4.hits)
|
||||
self.assertEqual(after_lower_cache3.misses + 1, after_lower_cache4.misses)
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_partition_spec_mismatch_semantically_equivalent(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P(None)
|
||||
global_input_data = np.arange(
|
||||
math.prod(global_input_shape), dtype=np.float32
|
||||
).reshape(global_input_shape)
|
||||
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
@partial(pjit, in_shardings=P(None), out_shardings=P(None))
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
output_gda = f(gda_obj)
|
||||
# Ensure output_gda.mesh_axes = P() is matched with P(None).
|
||||
self.assertEqual(output_gda.mesh_axes, ())
|
||||
# P(None) is in_axis_resources.
|
||||
f(output_gda)
|
||||
|
||||
def test_from_gda_duplicates(self):
|
||||
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_gda, _ = create_gda(global_input_shape, global_mesh, mesh_axes)
|
||||
|
||||
# It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
|
||||
# pickling in_axis_resources and sending to other processes). Make sure this
|
||||
# this doesn't cause an error to avoid user confusion.
|
||||
from_gda_dup = pjit_lib._FromGdaSingleton()
|
||||
with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
|
||||
pjit(lambda x: x, in_shardings=from_gda_dup, out_shardings=None)(
|
||||
input_gda
|
||||
)
|
||||
|
||||
def test_no_recompilation_due_to_in_axis_resources(self):
|
||||
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P(None,)
|
||||
input_gda, _ = create_gda(global_input_shape, global_mesh, mesh_axes)
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
@partial(pjit, in_shardings=mesh_axes, out_shardings=mesh_axes)
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
with global_mesh:
|
||||
out_gda = f(input_gda)
|
||||
self.assertEqual(out_gda.mesh_axes, ())
|
||||
|
||||
before_cache = pjit_lib._pjit_lower_cached.cache_info()
|
||||
f(out_gda)
|
||||
after_cache = pjit_lib._pjit_lower_cached.cache_info()
|
||||
|
||||
self.assertEqual(before_cache.hits + 1, after_cache.hits)
|
||||
self.assertEqual(before_cache.misses, after_cache.misses)
|
||||
|
||||
def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
|
||||
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P(None)
|
||||
global_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
f = pjit(lambda x: x, in_shardings=mesh_axes, out_shardings=mesh_axes)
|
||||
|
||||
with global_mesh:
|
||||
out_gda = f(global_data)
|
||||
self.assertEqual(out_gda.mesh_axes, ())
|
||||
|
||||
before_cache = pjit_lib._pjit_lower_cached.cache_info()
|
||||
f(out_gda)
|
||||
after_cache = pjit_lib._pjit_lower_cached.cache_info()
|
||||
|
||||
self.assertEqual(before_cache.hits + 1, after_cache.hits)
|
||||
self.assertEqual(before_cache.misses, after_cache.misses)
|
||||
|
||||
def test_pjit_gda_aot_sharding_mismatch(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
input_gda, _ = create_gda(global_input_shape, global_mesh, P('x', 'y'))
|
||||
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_shardings=P('x'), out_shardings=P('x'))
|
||||
compiled = f.lower(core.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "GDA sharding does not match the input sharding."):
|
||||
compiled(input_gda)
|
||||
|
||||
def test_pjit_gda_same_sharding_aot(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
|
||||
g1, _ = create_gda(global_input_shape, global_mesh, P(None,))
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_shardings=P(None), out_shardings=P('x'))
|
||||
compiled = f.lower(core.ShapedArray(global_input_shape, jnp.float32)).compile()
|
||||
compiled(g1) # no error
|
||||
|
||||
@parallel_functions_output_gda(True)
|
||||
def test_globally_sharded_key_array_8x4_multi_device(self):
|
||||
input_shape = (8, 4)
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
spec = P('x', 'y')
|
||||
|
||||
seeds, _ = create_gda(input_shape, mesh, spec, dtype=np.uint32)
|
||||
|
||||
with mesh:
|
||||
@partial(pjit, in_shardings=spec, out_shardings=spec)
|
||||
def make_keys(seeds):
|
||||
make_key = partial(prng.seed_with_impl, prng.threefry_prng_impl)
|
||||
return make_key(seeds)
|
||||
|
||||
out = make_keys(seeds)
|
||||
self.assertIsInstance(out, jax.random.KeyArray)
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
@ -1601,34 +1193,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
config.update('jax_array', self.jax_array_enabled)
|
||||
super().tearDown()
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('2d_gda', (4, 2), (4, 2), ('x', 'y')),
|
||||
# TODO(b/226977360): Support 3D mesh shape for example (2, 2, 2).
|
||||
('3d_gda', (1, 4, 2), (2, 4, 8, 4), ('x', 'y', 'z')),
|
||||
('1d_gda', (8,), (8, 2), ('x')),
|
||||
)
|
||||
def test_pjit_arr_auto_sharding_gda(self, mesh_shape, global_input_shape,
|
||||
mesh_axis_names):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('GDA and Array cannot be together.')
|
||||
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
|
||||
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
out = compiled(*inputs)
|
||||
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
|
||||
self.assertArraysEqual(out._value, input_data)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('2d_array', (4, 2), (4, 2), ('x', 'y')),
|
||||
# TODO(b/226977360): Support 3D mesh shape for example (2, 2, 2).
|
||||
@ -1655,37 +1219,29 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
self.assertArraysEqual(out._value, input_data)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('gda', parallel_functions_output_gda, create_gda),
|
||||
('array', jax_array, create_array),
|
||||
)
|
||||
def test_xla_arr_sharding_mismatch(self, ctx, create_fun):
|
||||
def test_xla_arr_sharding_mismatch(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
|
||||
if not jax.config.jax_array:
|
||||
self.skipTest("Test requires jax.Array")
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (4, 2)
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
|
||||
with ctx(True):
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp).compile()
|
||||
|
||||
different_pspec = (P('y', 'x')
|
||||
if compiled.input_shardings[0][0].spec == P(('x',), ('y',))
|
||||
else P('x', 'y'))
|
||||
arr, _ = create_fun(global_input_shape, global_mesh, different_pspec,
|
||||
different_pspec = (P('y', 'x')
|
||||
if compiled.input_shardings[0][0].spec == P(('x',), ('y',))
|
||||
else P('x', 'y'))
|
||||
arr, _ = create_array(global_input_shape, global_mesh, different_pspec,
|
||||
input_data)
|
||||
if jax.config.jax_array:
|
||||
arr_type = 'Array'
|
||||
else:
|
||||
arr_type = 'GDA'
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f"{arr_type} sharding does not match the input sharding."):
|
||||
compiled(arr)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Array sharding does not match the input sharding."):
|
||||
compiled(arr)
|
||||
|
||||
def test_gda_auto_shardings_len(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
@ -1702,41 +1258,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
self.assertLen(compiled.output_shardings, 3)
|
||||
self.assertLen(compiled.input_shardings[0], 3)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('3d_gda', (1, 1, 2), ('x', 'y', 'z'), P(('x', 'y', 'z'))),
|
||||
('2d_gda', (4, 2), ('x', 'y'), P('y', 'x')),
|
||||
('1d_gda', (8,), ('x'), P('x')),
|
||||
)
|
||||
def test_pjit_arr_partial_auto_sharding_gda(
|
||||
self, mesh_shape, mesh_axis_names, pspec):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('GDA and Array cannot be together.')
|
||||
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
|
||||
global_input_shape = (8, 4)
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
|
||||
in_resource = pspec
|
||||
|
||||
with parallel_functions_output_gda(True):
|
||||
with global_mesh:
|
||||
f = pjit(
|
||||
lambda x, y: (x, y),
|
||||
in_shardings=(in_resource, AUTO),
|
||||
out_shardings=AUTO,
|
||||
)
|
||||
|
||||
inp = core.ShapedArray(input_data.shape, input_data.dtype)
|
||||
compiled = f.lower(inp, inp).compile()
|
||||
inputs = [create_gda(global_input_shape, global_mesh, ip, input_data)[0]
|
||||
for ip in compiled.input_shardings[0]]
|
||||
out1, out2 = compiled(*inputs)
|
||||
for o in [out1, out2]:
|
||||
self.assertIsInstance(o, global_device_array.GlobalDeviceArray)
|
||||
self.assertArraysEqual(o._value, input_data)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('3d_array', (1, 1, 2), ('x', 'y', 'z'), P(('x', 'y', 'z'))),
|
||||
('2d_array', (4, 2), ('x', 'y'), P('y', 'x')),
|
||||
|
@ -35,7 +35,6 @@ from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src.core import NamedShape
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import global_device_array
|
||||
from jax._src import array
|
||||
from jax._src.sharding_impls import NamedSharding
|
||||
from jax.experimental.pjit import pjit, with_sharding_constraint
|
||||
@ -46,7 +45,7 @@ from jax._src import config as jax_config
|
||||
from jax._src.nn import initializers as nn_initializers
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import unzip2, safe_zip
|
||||
from jax._src.util import unzip2
|
||||
from jax._src.lax import parallel as lax_parallel
|
||||
from jax._src.lax.parallel import pgather
|
||||
from jax.interpreters import batching, pxla
|
||||
@ -349,24 +348,26 @@ class XMapTest(XMapTestCase):
|
||||
v = jnp.arange(np.prod(vshape)).reshape(vshape)
|
||||
zxy = fxy(v)
|
||||
if config.jax_array:
|
||||
zxy_sharding_spec = global_device_array._get_sharding_spec(
|
||||
zxy.shape, zxy.sharding.mesh, zxy.sharding.spec)
|
||||
zxy_op_sharding = zxy.sharding._to_xla_op_sharding(zxy.ndim)
|
||||
self.assertListEqual(zxy_op_sharding.tile_assignment_dimensions, [1, 4])
|
||||
self.assertListEqual(zxy_op_sharding.tile_assignment_devices, [0, 1, 2, 3])
|
||||
else:
|
||||
zxy_sharding_spec = zxy.sharding_spec
|
||||
self.assertEqual(
|
||||
zxy_sharding_spec,
|
||||
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
|
||||
(pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
|
||||
self.assertEqual(
|
||||
zxy_sharding_spec,
|
||||
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
|
||||
(pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
|
||||
zyx = fyx(v)
|
||||
if config.jax_array:
|
||||
zyx_sharding_spec = global_device_array._get_sharding_spec(
|
||||
zyx.shape, zyx.sharding.mesh, zyx.sharding.spec)
|
||||
zyx_op_sharding = zyx.sharding._to_xla_op_sharding(zyx.ndim)
|
||||
self.assertListEqual(zyx_op_sharding.tile_assignment_dimensions, [1, 4])
|
||||
self.assertListEqual(zyx_op_sharding.tile_assignment_devices, [0, 2, 1, 3])
|
||||
else:
|
||||
zyx_sharding_spec = zyx.sharding_spec
|
||||
self.assertEqual(
|
||||
zyx_sharding_spec,
|
||||
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
|
||||
(pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
|
||||
self.assertEqual(
|
||||
zyx_sharding_spec,
|
||||
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
|
||||
(pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def testSkipFirstMeshDim(self):
|
||||
@ -442,15 +443,16 @@ class XMapTest(XMapTestCase):
|
||||
(1, 2, 0)), (x * 2).sum((0, 1))))
|
||||
|
||||
if config.jax_array:
|
||||
sharding_spec = global_device_array._get_sharding_spec(
|
||||
y[0].shape, y[0].sharding.mesh, y[0].sharding.spec)
|
||||
y_op_sharding = y[0].sharding._to_xla_op_sharding(y[0].ndim)
|
||||
m_size = math.prod([2] + [2] * (len(mesh) - 2))
|
||||
self.assertListEqual(y_op_sharding.tile_assignment_dimensions, [2, 1, 1, m_size])
|
||||
else:
|
||||
sharding_spec = y[0].sharding_spec
|
||||
self.assertEqual(sharding_spec.sharding,
|
||||
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
|
||||
self.assertEqual(sharding_spec.mesh_mapping,
|
||||
(pxla.Replicated(2), pxla.ShardedAxis(0)) +
|
||||
(pxla.Replicated(2),) * (len(mesh) - 2))
|
||||
self.assertEqual(sharding_spec.sharding,
|
||||
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
|
||||
self.assertEqual(sharding_spec.mesh_mapping,
|
||||
(pxla.Replicated(2), pxla.ShardedAxis(0)) +
|
||||
(pxla.Replicated(2),) * (len(mesh) - 2))
|
||||
if config.experimental_xmap_spmd_lowering:
|
||||
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
|
||||
# Make sure that there are non-partial sharding specs in the HLO
|
||||
@ -1105,167 +1107,6 @@ class NamedNNTest(XMapTestCase):
|
||||
atol=1e-4, rtol=2e-2)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class XMapGDATest(XMapTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_basic(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with jax_config.parallel_functions_output_gda(True):
|
||||
f = maps.xmap(
|
||||
lambda x: x,
|
||||
in_axes=({0: "a", 1: "b"}),
|
||||
out_axes=({0: "a", 1: "b"}),
|
||||
axis_resources={"a": "x", "b": "y"})
|
||||
|
||||
out = f(gda_obj)
|
||||
self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out.shape, (8, 2))
|
||||
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
|
||||
self.assertDictEqual(out.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out.addressable_shards:
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_xmap_gda_mixed_inputs(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x')
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with jax_config.parallel_functions_output_gda(True):
|
||||
f = maps.xmap(
|
||||
lambda x, y: (x @ x.T, y @ y.T),
|
||||
in_axes=({0: "a"}, ["c", ...]),
|
||||
out_axes=({0: "a"}, ["c", ...]),
|
||||
axis_resources={"a": "x", "c": "x"})
|
||||
|
||||
expected_matrix_mul = np.diagonal(input_data @ input_data.T)
|
||||
out1, out2 = f(gda_obj, input_data)
|
||||
|
||||
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out1.shape, (8,))
|
||||
self.assertEqual(out1.addressable_shards[0].data.shape, (2,))
|
||||
self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out1.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out2.shape, (8,))
|
||||
self.assertEqual(out2.addressable_shards[0].data.shape, (2,))
|
||||
self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out2.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
for i, j in safe_zip(out1.addressable_shards, out2.addressable_shards):
|
||||
self.assertArraysEqual(i.data, j.data)
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_xmap_gda_double_input(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
gda_obj1 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, P('x'), cb)
|
||||
gda_obj2 = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, P('y'), cb)
|
||||
|
||||
with jax_config.parallel_functions_output_gda(True):
|
||||
f = maps.xmap(
|
||||
lambda x, y: (x @ x.T, y @ y.T),
|
||||
in_axes=({0: "a"}, ["c", ...]),
|
||||
out_axes=({0: "a"}, ["c", ...]),
|
||||
axis_resources={"a": "x", "c": "y"})
|
||||
|
||||
expected_matrix_mul = np.diagonal(input_data @ input_data.T)
|
||||
out1, out2 = f(gda_obj1, gda_obj2)
|
||||
|
||||
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out1.shape, (8,))
|
||||
self.assertEqual(out1.addressable_shards[0].data.shape, (2,))
|
||||
self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out1.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
|
||||
self.assertEqual(out2.shape, (8,))
|
||||
self.assertEqual(out2.addressable_shards[0].data.shape, (4,))
|
||||
self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out2.addressable_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_xmap_gda_sharding_mismatch(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with jax_config.parallel_functions_output_gda(True):
|
||||
f = maps.xmap(
|
||||
lambda x: x @ x.T,
|
||||
in_axes=({0: "a"}),
|
||||
out_axes=({0: "a"}),
|
||||
axis_resources={"a": "x"})
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
('Got an input GDA to xmap with different partitioning than '
|
||||
'specified in xmap. The partitioning must match.')):
|
||||
f(gda_obj)
|
||||
|
||||
def test_gda_from_pjit_with_xmap_sharding_mismatch(self):
|
||||
global_mesh = jtu.create_global_mesh((8, 1), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape)).reshape(global_input_shape)
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, lambda idx: input_data[idx])
|
||||
with jax_config.parallel_functions_output_gda(True):
|
||||
with global_mesh:
|
||||
out = pjit(
|
||||
lambda x: x, in_shardings=P('x', 'y'), out_shardings=P('x', 'y')
|
||||
)(gda_obj)
|
||||
|
||||
xmap_out = maps.xmap(
|
||||
lambda x: x,
|
||||
in_axes=({0: "a", 1: "b"}),
|
||||
out_axes=({0: "a", 1: "b"}),
|
||||
axis_resources={"a": "x", "b": "y"})(out) # doesn't crash
|
||||
self.assertArraysEqual(xmap_out, input_data)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class XMapArrayTest(XMapTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user