Make mesh_axes on GDA strict: only PartitionSpecs are accepted, for consistency with pjit.

PiperOrigin-RevId: 432957496
This commit is contained in:
Yash Katariya 2022-03-07 08:58:41 -08:00 committed by jax authors
parent 17f11e05e0
commit 99a103723c
7 changed files with 72 additions and 73 deletions

View File

@ -21,6 +21,7 @@ import jax
from jax._src import test_util as jtu
from jax._src import util
from jax.config import config
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization import serialization
from jax.experimental.maps import Mesh
@ -43,7 +44,7 @@ class CheckpointTest(jtu.JaxTestCase):
def test_checkpointing(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
mesh_axes = P('x', 'y')
num = util.prod(global_input_shape)
# First GDA
@ -66,7 +67,7 @@ class CheckpointTest(jtu.JaxTestCase):
def cb3(index):
return np.array([])
global_mesh1d = create_global_mesh((8,), ('x',))
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, [None], cb3)
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
@ -76,7 +77,7 @@ class CheckpointTest(jtu.JaxTestCase):
m1, m2, m3 = serialization.run_deserialization(
[global_mesh, global_mesh, global_mesh1d],
[mesh_axes, ['x'], [None]],
[mesh_axes, P('x'), P(None)],
tspecs)
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
@ -109,7 +110,7 @@ class CheckpointTest(jtu.JaxTestCase):
def cb1(index):
return global_input_data1[index]
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
['x', 'y'], cb1)
P('x', 'y'), cb1)
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
ckpt_paths = [str(ckpt_dir1)]
@ -119,7 +120,7 @@ class CheckpointTest(jtu.JaxTestCase):
m1, = serialization.run_deserialization(
[create_global_mesh((4, 2), ('x', 'y'))],
[['x', 'y']],
[P('x', 'y')],
tspecs,
[(12, 2)],
)

View File

@ -27,7 +27,7 @@ from jax._src.api import device_put
from jax.interpreters.sharded_jit import PartitionSpec
Shape = Tuple[int, ...]
MeshAxes = Sequence[Union[str, Tuple[str], None]]
MeshAxes = PartitionSpec
DeviceArray = xc.Buffer
Device = xc.Device
ArrayLike = Union[np.ndarray, DeviceArray]
@ -50,11 +50,7 @@ def _get_array_mapping(mesh_axes):
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)
else:
pspec = mesh_axes
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
return get_array_mapping(parsed_pspec)
@ -297,7 +293,7 @@ class GlobalDeviceArray:
self._local_shards = self._create_local_shards()
ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
assert all(db.shape == ss for db in device_buffers), (
f"Expected shard shape {ss} doesn't match the device buffer "
f"shape, got: {[db.shape for db in device_buffers]}")
@ -322,8 +318,8 @@ class GlobalDeviceArray:
def __repr__(self):
return (f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype}, '
f'global_mesh_shape={dict(self._global_mesh.shape)}, '
f'mesh_axes={self._mesh_axes})')
f'global_mesh_shape={dict(self.mesh.shape)}, '
f'mesh_axes={self.mesh_axes})')
@property
def shape(self) -> Shape:
@ -341,6 +337,10 @@ class GlobalDeviceArray:
def mesh(self):
return self._global_mesh
@property
def mesh_axes(self) -> MeshAxes:
return self._mesh_axes
@property
def is_fully_replicated(self) -> bool:
return self.shape == self.local_data(0).shape
@ -350,7 +350,7 @@ class GlobalDeviceArray:
global_indices_rid = self._gda_fast_path_args.global_indices_replica_ids
else:
global_indices_rid = get_shard_indices_replica_ids(
self._global_shape, self._global_mesh, self._mesh_axes)
self._global_shape, self._global_mesh, self.mesh_axes)
out = []
for db in self._device_buffers:
@ -379,7 +379,7 @@ class GlobalDeviceArray:
# Also, as this is a cached property, once calculated, it should be cached. So
# multiple accesses should be cheap.
global_indices_rid = get_shard_indices_replica_ids(
self._global_shape, self._global_mesh, self._mesh_axes)
self._global_shape, self._global_mesh, self.mesh_axes)
device_to_buffer = dict((db.device(), db) for db in self._device_buffers)
global_shards = []
for device, (index, rid) in global_indices_rid.items():
@ -410,10 +410,11 @@ class GlobalDeviceArray:
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 8)
>>> mesh_axes = ['x', 'y']
>>> mesh_axes = P('x', 'y')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
...
@ -456,10 +457,11 @@ class GlobalDeviceArray:
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 2)
>>> mesh_axes = ['x']
>>> mesh_axes = P('x')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
...
@ -502,10 +504,11 @@ class GlobalDeviceArray:
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 2)
>>> mesh_axes = [('x', 'y')]
>>> mesh_axes = P(('x', 'y'))
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
...

View File

@ -2028,7 +2028,7 @@ def _check_gda_xmap_partitioning(axis_resources, resource_env,
f"mesh: {resource_env.physical_mesh},\n"
f"GDA mesh: {arg.mesh}")
gda_array_mapping = _get_array_mapping(arg._mesh_axes)
gda_array_mapping = _get_array_mapping(arg.mesh_axes)
if gda_array_mapping != xmap_array_mapping:
raise ValueError(
"Got an input GDA to xmap with different partitioning than "

View File

@ -1053,7 +1053,9 @@ def _create_cpspec(x):
def _maybe_replace_from_gda_with_pspec(
in_axis_resources_flat: CanonicalizedParsedPartitionSpec, arg) -> CanonicalizedParsedPartitionSpec:
if isinstance(arg, GDA):
gda_cpspec = gda_mesh_axes_to_canonicalized_parsed_pspec(arg._mesh_axes)
gda_cpspec = CanonicalizedParsedPartitionSpec(
ParsedPartitionSpec.from_user_input(
arg.mesh_axes, arg_name="GDA mesh_axes"))
assert type(gda_cpspec) is CanonicalizedParsedPartitionSpec
if (not _is_from_gda(in_axis_resources_flat) and
in_axis_resources_flat != gda_cpspec):
@ -1066,13 +1068,6 @@ def _maybe_replace_from_gda_with_pspec(
return gda_cpspec
return in_axis_resources_flat
def gda_mesh_axes_to_canonicalized_parsed_pspec(mesh_axes) -> CanonicalizedParsedPartitionSpec:
if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)
else:
pspec = mesh_axes
return CanonicalizedParsedPartitionSpec(ParsedPartitionSpec.from_user_input(
pspec, arg_name='GDA mesh_axes'))
def _maybe_check_pjit_gda_mesh(args, mesh):
for x in args:

View File

@ -414,20 +414,24 @@ mesh devices ndarray would have to be transposed before flattening and assignmen
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
AxisResource = Tuple[Optional[Tuple[Any, ...]], ...]
def array_mapping_to_axis_resources(array_mapping: ArrayMapping) -> AxisResource:
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
# TODO(yashkatariya): Move PartitionSpec into a place where all files can
# import it without cyclic dependency.
from jax.interpreters.sharded_jit import PartitionSpec
if not array_mapping:
return tuple()
return PartitionSpec()
max_index = -1
reverse_map = defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
if index > max_index:
max_index = index
return tuple(
tuple(reverse_map[i]) if reverse_map[i] else None for i in range(max_index + 1)
)
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
for i in range(max_index + 1))
return PartitionSpec(*partitions)
def local_aval_to_result_handler(
aval: core.AbstractValue,
@ -465,14 +469,13 @@ local_result_handlers[ConcreteArray] = sda_array_result_handler
def global_aval_to_result_handler(
aval: core.AbstractValue,
out_axis_resources: Optional[AxisResource], global_mesh,
aval: core.AbstractValue, out_axis_resources, global_mesh,
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
aval: The global output AbstractValue.
out_axis_resources: A tuple specifying the sharding of outputs.
out_axis_resources: A PartitionSpec specifying the sharding of outputs.
Used for creating GSDAs.
global_mesh: The global device mesh that generated this output. Used
for creating GSDAs.

View File

@ -35,33 +35,29 @@ config.parse_flags_with_absl()
class GDATest(jtu.JaxTestCase):
@parameterized.named_parameters(
("mesh_x_y", ["x", "y"],
("mesh_x_y", P("x", "y"),
# There are more slices, but for convenience we check only
# 2. The indices + shard_shape + replica_id should be unique enough.
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
(2, 1),
[0, 0, 0, 0, 0, 0, 0, 0], False),
("mesh_x_y_pspec", P("x", "y"),
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
(2, 1),
[0, 0, 0, 0, 0, 0, 0, 0], False),
("mesh_x", ["x"],
("mesh_x", P("x"),
((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
(2, 2),
[0, 1, 0, 1, 0, 1, 0, 1], False),
("mesh_y", ["y"],
("mesh_y", P("y"),
((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
(4, 2),
[0, 0, 1, 1, 2, 2, 3, 3], False),
("mesh_none_y", [None, "y"],
("mesh_none_y", P(None, "y"),
((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
(8, 1),
[0, 0, 1, 1, 2, 2, 3, 3], False),
("mesh_xy", [("x", "y")],
("mesh_xy", P(("x", "y")),
((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
(1, 2),
[0, 0, 0, 0, 0, 0, 0, 0], False),
("mesh_fully_replicated", [],
("mesh_fully_replicated", P(),
((slice(None), slice(None)), (slice(None), slice(None))),
(8, 2),
[0, 1, 2, 3, 4, 5, 6, 7], True),
@ -79,6 +75,7 @@ class GDATest(jtu.JaxTestCase):
mesh_axes, cb)
self.assertEqual(gda.ndim, 2)
self.assertEqual(gda.size, 16)
self.assertEqual(gda.mesh_axes, mesh_axes)
self.assertEqual(gda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])
@ -103,17 +100,17 @@ class GDATest(jtu.JaxTestCase):
@parameterized.named_parameters(
("mesh_x_y_z", ["x", "y", "z"],
("mesh_x_y_z", P("x", "y", "z"),
# There are more slices, but for convenience we check only
# 2. The indices + shard_shape + replica_id should be unique enough.
((slice(0, 4), slice(0, 2), slice(0, 1)), (slice(0, 4), slice(0, 2), slice(1, 2))),
(4, 2, 1),
[0, 0, 0, 0, 0, 0, 0, 0]),
("mesh_xy_z", [("x", "y"), "z"],
("mesh_xy_z", P(("x", "y"), "z"),
((slice(0, 2), slice(0, 2), slice(None)), (slice(0, 2), slice(2, 4), slice(None))),
(2, 2, 2),
[0, 0, 0, 0, 0, 0, 0, 0]),
("mesh_z", ["z"],
("mesh_z", P("z"),
((slice(0, 4), slice(None), slice(None)), (slice(4, 8), slice(None), slice(None))),
(4, 4, 2),
[0, 0, 1, 1, 2, 2, 3, 3]),
@ -143,13 +140,13 @@ class GDATest(jtu.JaxTestCase):
self.assertListEqual(replica_ids, expected_replica_ids)
@parameterized.named_parameters(
("mesh_x", ["x"],
("mesh_x", P("x"),
# There are more slices, but for convenience we check only
# 2. The indices + shard_shape + replica_id should be unique enough.
((slice(0, 2),), (slice(2, 4),)),
(2,),
[0, 0, 0, 0, 0, 0, 0, 0]),
("mesh_none", [],
("mesh_none", P(),
((slice(None),), (slice(None),)),
(16,),
[0, 1, 2, 3, 4, 5, 6, 7]),
@ -179,7 +176,7 @@ class GDATest(jtu.JaxTestCase):
def test_gda_shape_0_1d_mesh(self):
global_mesh = jtu.create_global_mesh((8,), ('x'))
global_input_shape = (0,)
mesh_axes = [None]
mesh_axes = P(None)
def cb(index):
return np.array([])
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
@ -197,7 +194,7 @@ class GDATest(jtu.JaxTestCase):
@parameterized.named_parameters(
("mesh_x_y", ["x", "y"],
("mesh_x_y", P("x", "y"),
# There are more slices, but for convenience we check only
# 2. The indices + shard_shape + replica_id should be unique enough.
((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
@ -233,7 +230,7 @@ class GDATest(jtu.JaxTestCase):
def test_gda_batched_callback(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [('x', 'y')]
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
@ -253,7 +250,7 @@ class GDATest(jtu.JaxTestCase):
def test_gda_batched_callback_with_devices(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x']
mesh_axes = P('x')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
@ -279,7 +276,7 @@ class GDATest(jtu.JaxTestCase):
def test_gda_str_repr(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [('x', 'y')]
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
@ -289,9 +286,9 @@ class GDATest(jtu.JaxTestCase):
self.assertEqual(str(gda),
'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
self.assertEqual(
repr(gda),
("GlobalDeviceArray(shape=(8, 2), dtype=int32, "
"global_mesh_shape={'x': 4, 'y': 2}, mesh_axes=[('x', 'y')])"))
repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
"global_mesh_shape={'x': 4, 'y': 2}, "
"mesh_axes=PartitionSpec(('x', 'y'),))"))
def test_gda_equality_raises_not_implemented(self):
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
@ -329,7 +326,7 @@ class GDATest(jtu.JaxTestCase):
[devices[7], devices[5]]])
global_mesh = Mesh(mesh_devices, ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
mesh_axes = P('x', 'y')
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)
@ -347,7 +344,7 @@ class GDATest(jtu.JaxTestCase):
def test_gda_block_until_ready(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [('x', 'y')]
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)

View File

@ -991,7 +991,7 @@ class GDAPjitTest(jtu.JaxTestCase):
def test_pjit_gda_mesh_mismatch(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
mesh_axes = P('x', 'y')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
def cb(index):
@ -1012,7 +1012,7 @@ class GDAPjitTest(jtu.JaxTestCase):
def test_pjit_gda_wrong_resource_for_gda_input(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x']
mesh_axes = P('x')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
def cb(index):
@ -1066,7 +1066,7 @@ class GDAPjitTest(jtu.JaxTestCase):
def test_partition_spec_mismatch_semantically_equivalent(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [None]
mesh_axes = P(None)
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
@ -1082,15 +1082,15 @@ class GDAPjitTest(jtu.JaxTestCase):
return x
output_gda = f(gda_obj)
# Ensure output_gda._mesh_axes = P() is matched with P(None).
self.assertEqual(output_gda._mesh_axes, ())
# Ensure output_gda.mesh_axes = P() is matched with P(None).
self.assertEqual(output_gda.mesh_axes, ())
# P(None) is in_axis_resources.
f(output_gda)
def test_from_gda_duplicates(self):
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
mesh_axes = P('x', 'y')
input_gda = create_gda(global_input_shape, global_mesh, mesh_axes)
# It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
@ -1114,7 +1114,7 @@ class GDAPjitTest(jtu.JaxTestCase):
with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
out_gda = f(input_gda)
self.assertEqual(out_gda._mesh_axes, ())
self.assertEqual(out_gda.mesh_axes, ())
before_cache = pjit_lib._pjit_lower.cache_info()
f(out_gda)
@ -1395,10 +1395,10 @@ class UtilTest(jtu.JaxTestCase):
roundtrip(P(*spec))
@parameterized.named_parameters(
("linear", {'x': 0, 'y': 1, 'z': 2}, (('x',), ('y',), ('z',))),
("combine", {'x': 0, 'y': 0, 'z': 1}, (('x', 'y'), ('z',))),
("skip", {'x': 0, 'y': 0, 'z': 2}, (('x', 'y'), None, ('z',))),
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, (('x',), ('y',), None, ('z',))),
("linear", {'x': 0, 'y': 1, 'z': 2}, P(('x',), ('y',), ('z',))),
("combine", {'x': 0, 'y': 0, 'z': 1}, P(('x', 'y'), ('z',))),
("skip", {'x': 0, 'y': 0, 'z': 2}, P(('x', 'y'), None, ('z',))),
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, P(('x',), ('y',), None, ('z',))),
)
def test_array_mapping_to_axis_resources(self, inp, expected_out):
self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out)