mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Make mesh_axes
on GDA strict by only allowing PartitionSpecs to be consistent with pjit.
PiperOrigin-RevId: 432957496
This commit is contained in:
parent
17f11e05e0
commit
99a103723c
@ -21,6 +21,7 @@ import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax.config import config
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.gda_serialization import serialization
|
||||
from jax.experimental.maps import Mesh
|
||||
@ -43,7 +44,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
def test_checkpointing(self):
|
||||
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x', 'y']
|
||||
mesh_axes = P('x', 'y')
|
||||
num = util.prod(global_input_shape)
|
||||
|
||||
# First GDA
|
||||
@ -66,7 +67,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
def cb3(index):
|
||||
return np.array([])
|
||||
global_mesh1d = create_global_mesh((8,), ('x',))
|
||||
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, [None], cb3)
|
||||
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
|
||||
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
|
||||
@ -76,7 +77,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
m1, m2, m3 = serialization.run_deserialization(
|
||||
[global_mesh, global_mesh, global_mesh1d],
|
||||
[mesh_axes, ['x'], [None]],
|
||||
[mesh_axes, P('x'), P(None)],
|
||||
tspecs)
|
||||
|
||||
self.assertArraysEqual(m1.local_shards[0].data.to_py(),
|
||||
@ -109,7 +110,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
def cb1(index):
|
||||
return global_input_data1[index]
|
||||
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
['x', 'y'], cb1)
|
||||
P('x', 'y'), cb1)
|
||||
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
|
||||
|
||||
ckpt_paths = [str(ckpt_dir1)]
|
||||
@ -119,7 +120,7 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
|
||||
m1, = serialization.run_deserialization(
|
||||
[create_global_mesh((4, 2), ('x', 'y'))],
|
||||
[['x', 'y']],
|
||||
[P('x', 'y')],
|
||||
tspecs,
|
||||
[(12, 2)],
|
||||
)
|
||||
|
@ -27,7 +27,7 @@ from jax._src.api import device_put
|
||||
from jax.interpreters.sharded_jit import PartitionSpec
|
||||
|
||||
Shape = Tuple[int, ...]
|
||||
MeshAxes = Sequence[Union[str, Tuple[str], None]]
|
||||
MeshAxes = PartitionSpec
|
||||
DeviceArray = xc.Buffer
|
||||
Device = xc.Device
|
||||
ArrayLike = Union[np.ndarray, DeviceArray]
|
||||
@ -50,11 +50,7 @@ def _get_array_mapping(mesh_axes):
|
||||
# Import here to avoid cyclic import error when importing gda in pjit.py.
|
||||
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
|
||||
|
||||
if not isinstance(mesh_axes, PartitionSpec):
|
||||
pspec = PartitionSpec(*mesh_axes)
|
||||
else:
|
||||
pspec = mesh_axes
|
||||
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
|
||||
parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
|
||||
return get_array_mapping(parsed_pspec)
|
||||
|
||||
|
||||
@ -297,7 +293,7 @@ class GlobalDeviceArray:
|
||||
|
||||
self._local_shards = self._create_local_shards()
|
||||
|
||||
ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
|
||||
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
assert all(db.shape == ss for db in device_buffers), (
|
||||
f"Expected shard shape {ss} doesn't match the device buffer "
|
||||
f"shape, got: {[db.shape for db in device_buffers]}")
|
||||
@ -322,8 +318,8 @@ class GlobalDeviceArray:
|
||||
|
||||
def __repr__(self):
|
||||
return (f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype}, '
|
||||
f'global_mesh_shape={dict(self._global_mesh.shape)}, '
|
||||
f'mesh_axes={self._mesh_axes})')
|
||||
f'global_mesh_shape={dict(self.mesh.shape)}, '
|
||||
f'mesh_axes={self.mesh_axes})')
|
||||
|
||||
@property
|
||||
def shape(self) -> Shape:
|
||||
@ -341,6 +337,10 @@ class GlobalDeviceArray:
|
||||
def mesh(self):
|
||||
return self._global_mesh
|
||||
|
||||
@property
|
||||
def mesh_axes(self) -> MeshAxes:
|
||||
return self._mesh_axes
|
||||
|
||||
@property
|
||||
def is_fully_replicated(self) -> bool:
|
||||
return self.shape == self.local_data(0).shape
|
||||
@ -350,7 +350,7 @@ class GlobalDeviceArray:
|
||||
global_indices_rid = self._gda_fast_path_args.global_indices_replica_ids
|
||||
else:
|
||||
global_indices_rid = get_shard_indices_replica_ids(
|
||||
self._global_shape, self._global_mesh, self._mesh_axes)
|
||||
self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
|
||||
out = []
|
||||
for db in self._device_buffers:
|
||||
@ -379,7 +379,7 @@ class GlobalDeviceArray:
|
||||
# Also as this a cached property, once calculated, it should be cached. So
|
||||
# multiple accesses should be cheap.
|
||||
global_indices_rid = get_shard_indices_replica_ids(
|
||||
self._global_shape, self._global_mesh, self._mesh_axes)
|
||||
self._global_shape, self._global_mesh, self.mesh_axes)
|
||||
device_to_buffer = dict((db.device(), db) for db in self._device_buffers)
|
||||
global_shards = []
|
||||
for device, (index, rid) in global_indices_rid.items():
|
||||
@ -410,10 +410,11 @@ class GlobalDeviceArray:
|
||||
Example:
|
||||
|
||||
>>> from jax.experimental.maps import Mesh
|
||||
>>> from jax.experimental import PartitionSpec as P
|
||||
>>> import numpy as np
|
||||
...
|
||||
>>> global_input_shape = (8, 8)
|
||||
>>> mesh_axes = ['x', 'y']
|
||||
>>> mesh_axes = P('x', 'y')
|
||||
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
|
||||
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
|
||||
...
|
||||
@ -456,10 +457,11 @@ class GlobalDeviceArray:
|
||||
Example:
|
||||
|
||||
>>> from jax.experimental.maps import Mesh
|
||||
>>> from jax.experimental import PartitionSpec as P
|
||||
>>> import numpy as np
|
||||
...
|
||||
>>> global_input_shape = (8, 2)
|
||||
>>> mesh_axes = ['x']
|
||||
>>> mesh_axes = P('x')
|
||||
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
|
||||
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
|
||||
...
|
||||
@ -502,10 +504,11 @@ class GlobalDeviceArray:
|
||||
Example:
|
||||
|
||||
>>> from jax.experimental.maps import Mesh
|
||||
>>> from jax.experimental import PartitionSpec as P
|
||||
>>> import numpy as np
|
||||
...
|
||||
>>> global_input_shape = (8, 2)
|
||||
>>> mesh_axes = [('x', 'y')]
|
||||
>>> mesh_axes = P(('x', 'y'))
|
||||
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
|
||||
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
|
||||
...
|
||||
|
@ -2028,7 +2028,7 @@ def _check_gda_xmap_partitioning(axis_resources, resource_env,
|
||||
f"mesh: {resource_env.physical_mesh},\n"
|
||||
f"GDA mesh: {arg.mesh}")
|
||||
|
||||
gda_array_mapping = _get_array_mapping(arg._mesh_axes)
|
||||
gda_array_mapping = _get_array_mapping(arg.mesh_axes)
|
||||
if gda_array_mapping != xmap_array_mapping:
|
||||
raise ValueError(
|
||||
"Got an input GDA to xmap with different partitioning than "
|
||||
|
@ -1053,7 +1053,9 @@ def _create_cpspec(x):
|
||||
def _maybe_replace_from_gda_with_pspec(
|
||||
in_axis_resources_flat: CanonicalizedParsedPartitionSpec, arg) -> CanonicalizedParsedPartitionSpec:
|
||||
if isinstance(arg, GDA):
|
||||
gda_cpspec = gda_mesh_axes_to_canonicalized_parsed_pspec(arg._mesh_axes)
|
||||
gda_cpspec = CanonicalizedParsedPartitionSpec(
|
||||
ParsedPartitionSpec.from_user_input(
|
||||
arg.mesh_axes, arg_name="GDA mesh_axes"))
|
||||
assert type(gda_cpspec) is CanonicalizedParsedPartitionSpec
|
||||
if (not _is_from_gda(in_axis_resources_flat) and
|
||||
in_axis_resources_flat != gda_cpspec):
|
||||
@ -1066,13 +1068,6 @@ def _maybe_replace_from_gda_with_pspec(
|
||||
return gda_cpspec
|
||||
return in_axis_resources_flat
|
||||
|
||||
def gda_mesh_axes_to_canonicalized_parsed_pspec(mesh_axes) -> CanonicalizedParsedPartitionSpec:
|
||||
if not isinstance(mesh_axes, PartitionSpec):
|
||||
pspec = PartitionSpec(*mesh_axes)
|
||||
else:
|
||||
pspec = mesh_axes
|
||||
return CanonicalizedParsedPartitionSpec(ParsedPartitionSpec.from_user_input(
|
||||
pspec, arg_name='GDA mesh_axes'))
|
||||
|
||||
def _maybe_check_pjit_gda_mesh(args, mesh):
|
||||
for x in args:
|
||||
|
@ -414,20 +414,24 @@ mesh devices ndarray would have to be transposed before flattening and assignmen
|
||||
"""
|
||||
ArrayMapping = OrderedDictType[MeshAxisName, int]
|
||||
|
||||
AxisResource = Tuple[Optional[Tuple[Any, ...]], ...]
|
||||
|
||||
def array_mapping_to_axis_resources(array_mapping: ArrayMapping) -> AxisResource:
|
||||
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
|
||||
# TODO(yashkatariya): Move PartitionSpec into a place where all files can
|
||||
# import it without cyclic dependency.
|
||||
from jax.interpreters.sharded_jit import PartitionSpec
|
||||
|
||||
if not array_mapping:
|
||||
return tuple()
|
||||
return PartitionSpec()
|
||||
max_index = -1
|
||||
reverse_map = defaultdict(list)
|
||||
for axis, index in array_mapping.items():
|
||||
reverse_map[index].append(axis)
|
||||
if index > max_index:
|
||||
max_index = index
|
||||
return tuple(
|
||||
tuple(reverse_map[i]) if reverse_map[i] else None for i in range(max_index + 1)
|
||||
)
|
||||
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
|
||||
for i in range(max_index + 1))
|
||||
return PartitionSpec(*partitions)
|
||||
|
||||
|
||||
def local_aval_to_result_handler(
|
||||
aval: core.AbstractValue,
|
||||
@ -465,14 +469,13 @@ local_result_handlers[ConcreteArray] = sda_array_result_handler
|
||||
|
||||
|
||||
def global_aval_to_result_handler(
|
||||
aval: core.AbstractValue,
|
||||
out_axis_resources: Optional[AxisResource], global_mesh,
|
||||
aval: core.AbstractValue, out_axis_resources, global_mesh,
|
||||
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
|
||||
"""Returns a function for handling the raw buffers of a single output aval.
|
||||
|
||||
Args:
|
||||
aval: The global output AbstractValue.
|
||||
out_axis_resources: A tuple specifying the sharding of outputs.
|
||||
out_axis_resources: A PartitionSpec specifying the sharding of outputs.
|
||||
Used for creating GSDAs.
|
||||
global_mesh: The global device mesh that generated this output. Used
|
||||
for creating GSDAs.
|
||||
|
@ -35,33 +35,29 @@ config.parse_flags_with_absl()
|
||||
class GDATest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", ["x", "y"],
|
||||
("mesh_x_y", P("x", "y"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
|
||||
(2, 1),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0], False),
|
||||
("mesh_x_y_pspec", P("x", "y"),
|
||||
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
|
||||
(2, 1),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0], False),
|
||||
("mesh_x", ["x"],
|
||||
("mesh_x", P("x"),
|
||||
((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
|
||||
(2, 2),
|
||||
[0, 1, 0, 1, 0, 1, 0, 1], False),
|
||||
("mesh_y", ["y"],
|
||||
("mesh_y", P("y"),
|
||||
((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
|
||||
(4, 2),
|
||||
[0, 0, 1, 1, 2, 2, 3, 3], False),
|
||||
("mesh_none_y", [None, "y"],
|
||||
("mesh_none_y", P(None, "y"),
|
||||
((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
|
||||
(8, 1),
|
||||
[0, 0, 1, 1, 2, 2, 3, 3], False),
|
||||
("mesh_xy", [("x", "y")],
|
||||
("mesh_xy", P(("x", "y")),
|
||||
((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
|
||||
(1, 2),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0], False),
|
||||
("mesh_fully_replicated", [],
|
||||
("mesh_fully_replicated", P(),
|
||||
((slice(None), slice(None)), (slice(None), slice(None))),
|
||||
(8, 2),
|
||||
[0, 1, 2, 3, 4, 5, 6, 7], True),
|
||||
@ -79,6 +75,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.ndim, 2)
|
||||
self.assertEqual(gda.size, 16)
|
||||
self.assertEqual(gda.mesh_axes, mesh_axes)
|
||||
self.assertEqual(gda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
@ -103,17 +100,17 @@ class GDATest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y_z", ["x", "y", "z"],
|
||||
("mesh_x_y_z", P("x", "y", "z"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 4), slice(0, 2), slice(0, 1)), (slice(0, 4), slice(0, 2), slice(1, 2))),
|
||||
(4, 2, 1),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
("mesh_xy_z", [("x", "y"), "z"],
|
||||
("mesh_xy_z", P(("x", "y"), "z"),
|
||||
((slice(0, 2), slice(0, 2), slice(None)), (slice(0, 2), slice(2, 4), slice(None))),
|
||||
(2, 2, 2),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
("mesh_z", ["z"],
|
||||
("mesh_z", P("z"),
|
||||
((slice(0, 4), slice(None), slice(None)), (slice(4, 8), slice(None), slice(None))),
|
||||
(4, 4, 2),
|
||||
[0, 0, 1, 1, 2, 2, 3, 3]),
|
||||
@ -143,13 +140,13 @@ class GDATest(jtu.JaxTestCase):
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x", ["x"],
|
||||
("mesh_x", P("x"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 2),), (slice(2, 4),)),
|
||||
(2,),
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
("mesh_none", [],
|
||||
("mesh_none", P(),
|
||||
((slice(None),), (slice(None),)),
|
||||
(16,),
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
@ -179,7 +176,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
def test_gda_shape_0_1d_mesh(self):
|
||||
global_mesh = jtu.create_global_mesh((8,), ('x'))
|
||||
global_input_shape = (0,)
|
||||
mesh_axes = [None]
|
||||
mesh_axes = P(None)
|
||||
def cb(index):
|
||||
return np.array([])
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
@ -197,7 +194,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", ["x", "y"],
|
||||
("mesh_x_y", P("x", "y"),
|
||||
# There are more slices but for convienient purposes, checking for only
|
||||
# 2. The indices + shard_shape + replica_id should be unique enough.
|
||||
((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
|
||||
@ -233,7 +230,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
def test_gda_batched_callback(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = [('x', 'y')]
|
||||
mesh_axes = P(('x', 'y'))
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
|
||||
@ -253,7 +250,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
def test_gda_batched_callback_with_devices(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x']
|
||||
mesh_axes = P('x')
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
|
||||
@ -279,7 +276,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
def test_gda_str_repr(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = [('x', 'y')]
|
||||
mesh_axes = P(('x', 'y'))
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
@ -289,9 +286,9 @@ class GDATest(jtu.JaxTestCase):
|
||||
self.assertEqual(str(gda),
|
||||
'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
|
||||
self.assertEqual(
|
||||
repr(gda),
|
||||
("GlobalDeviceArray(shape=(8, 2), dtype=int32, "
|
||||
"global_mesh_shape={'x': 4, 'y': 2}, mesh_axes=[('x', 'y')])"))
|
||||
repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
|
||||
"global_mesh_shape={'x': 4, 'y': 2}, "
|
||||
"mesh_axes=PartitionSpec(('x', 'y'),))"))
|
||||
|
||||
def test_gda_equality_raises_not_implemented(self):
|
||||
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
@ -329,7 +326,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
[devices[7], devices[5]]])
|
||||
global_mesh = Mesh(mesh_devices, ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x', 'y']
|
||||
mesh_axes = P('x', 'y')
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)
|
||||
@ -347,7 +344,7 @@ class GDATest(jtu.JaxTestCase):
|
||||
def test_gda_block_until_ready(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = [('x', 'y')]
|
||||
mesh_axes = P(('x', 'y'))
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
|
||||
|
@ -991,7 +991,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
def test_pjit_gda_mesh_mismatch(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x', 'y']
|
||||
mesh_axes = P('x', 'y')
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
@ -1012,7 +1012,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
def test_pjit_gda_wrong_resource_for_gda_input(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x']
|
||||
mesh_axes = P('x')
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
@ -1066,7 +1066,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
def test_partition_spec_mismatch_semantically_equivalent(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = [None]
|
||||
mesh_axes = P(None)
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
|
||||
@ -1082,15 +1082,15 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
return x
|
||||
|
||||
output_gda = f(gda_obj)
|
||||
# Ensure output_gda._mesh_axes = P() is matched with P(None).
|
||||
self.assertEqual(output_gda._mesh_axes, ())
|
||||
# Ensure output_gda.mesh_axes = P() is matched with P(None).
|
||||
self.assertEqual(output_gda.mesh_axes, ())
|
||||
# P(None) is in_axis_resources.
|
||||
f(output_gda)
|
||||
|
||||
def test_from_gda_duplicates(self):
|
||||
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x', 'y']
|
||||
mesh_axes = P('x', 'y')
|
||||
input_gda = create_gda(global_input_shape, global_mesh, mesh_axes)
|
||||
|
||||
# It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
|
||||
@ -1114,7 +1114,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
|
||||
out_gda = f(input_gda)
|
||||
self.assertEqual(out_gda._mesh_axes, ())
|
||||
self.assertEqual(out_gda.mesh_axes, ())
|
||||
|
||||
before_cache = pjit_lib._pjit_lower.cache_info()
|
||||
f(out_gda)
|
||||
@ -1395,10 +1395,10 @@ class UtilTest(jtu.JaxTestCase):
|
||||
roundtrip(P(*spec))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("linear", {'x': 0, 'y': 1, 'z': 2}, (('x',), ('y',), ('z',))),
|
||||
("combine", {'x': 0, 'y': 0, 'z': 1}, (('x', 'y'), ('z',))),
|
||||
("skip", {'x': 0, 'y': 0, 'z': 2}, (('x', 'y'), None, ('z',))),
|
||||
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, (('x',), ('y',), None, ('z',))),
|
||||
("linear", {'x': 0, 'y': 1, 'z': 2}, P(('x',), ('y',), ('z',))),
|
||||
("combine", {'x': 0, 'y': 0, 'z': 1}, P(('x', 'y'), ('z',))),
|
||||
("skip", {'x': 0, 'y': 0, 'z': 2}, P(('x', 'y'), None, ('z',))),
|
||||
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, P(('x',), ('y',), None, ('z',))),
|
||||
)
|
||||
def test_array_mapping_to_axis_resources(self, inp, expected_out):
|
||||
self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user