mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Cleanup ParsedPartitionSpec
and remove CanonicalizedParsedPartitionSpec
. Also mark user_spec
as private.
PiperOrigin-RevId: 676498946
This commit is contained in:
parent
73bbd80b80
commit
c9bbf71ec6
@ -1041,8 +1041,8 @@ def _create_sharding_for_array(mesh, x, name, api_name):
|
||||
' then the mesh context manager is not required.')
|
||||
# A nice user error is raised in prepare_axis_resources.
|
||||
assert x is None or isinstance(x, ParsedPartitionSpec), x
|
||||
return (pxla.create_mesh_pspec_sharding(mesh, x)
|
||||
if x is None else pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x))
|
||||
return (pxla.create_mesh_pspec_sharding(mesh, x) if x is None else
|
||||
pxla.create_mesh_pspec_sharding(mesh, x.get_partition_spec(), x))
|
||||
|
||||
|
||||
def _create_sharding_with_device_backend(device, backend):
|
||||
|
@ -18,7 +18,6 @@ import collections
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping, Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
@ -955,43 +954,20 @@ get_single_pspec = lambda p: array_mapping_to_axis_resources(
|
||||
cast(ArrayMapping, get_array_mapping(p)))
|
||||
|
||||
|
||||
class SpecSync(enum.IntEnum):
|
||||
"""Encodes how much out of sync the real value of partitions is compared to the user specified one.
|
||||
|
||||
We use this to make sure we don't show garbage modified values while claiming
|
||||
that the users have specified them like that.
|
||||
"""
|
||||
OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted
|
||||
DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes
|
||||
IN_SYNC = 2 # Entirely in sync
|
||||
|
||||
class ParsedPartitionSpec:
|
||||
__slots__ = ('unsafe_user_spec', 'partitions', 'sync')
|
||||
__slots__ = ('_user_spec', 'partitions')
|
||||
|
||||
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
|
||||
self.unsafe_user_spec = user_spec
|
||||
def __init__(self, user_spec, partitions):
|
||||
self._user_spec = user_spec
|
||||
# None in partitions represents unconstrained dim.
|
||||
# TODO(yashkatariya): May use a sentinel value.
|
||||
self.partitions = tuple(partitions)
|
||||
self.sync = sync
|
||||
|
||||
@property
|
||||
def user_spec(self):
|
||||
return self.unsynced_user_spec(SpecSync.IN_SYNC)
|
||||
|
||||
def get_partition_spec(self) -> PartitionSpec:
|
||||
if self.sync < SpecSync.IN_SYNC:
|
||||
return get_single_pspec(self)
|
||||
if isinstance(self._user_spec, PartitionSpec):
|
||||
return self._user_spec
|
||||
else:
|
||||
if isinstance(self.unsafe_user_spec, PartitionSpec):
|
||||
return self.unsafe_user_spec
|
||||
else:
|
||||
return get_single_pspec(self)
|
||||
|
||||
def unsynced_user_spec(self, min_sync):
|
||||
if self.sync < min_sync:
|
||||
raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
|
||||
return self.unsafe_user_spec
|
||||
return get_single_pspec(self)
|
||||
|
||||
def insert_axis_partitions(self, dim, val):
|
||||
parts = self.partitions
|
||||
@ -999,8 +975,7 @@ class ParsedPartitionSpec:
|
||||
if too_short > 0:
|
||||
parts += ((),) * too_short
|
||||
new_partitions = util.tuple_insert(parts, dim, val)
|
||||
new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC
|
||||
return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)
|
||||
return ParsedPartitionSpec(None, new_partitions)
|
||||
|
||||
@classmethod
|
||||
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
|
||||
@ -1027,13 +1002,12 @@ class ParsedPartitionSpec:
|
||||
return cls(new_entry, axis_specs)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.partitions, self.sync))
|
||||
return hash(self.partitions)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, ParsedPartitionSpec):
|
||||
return False
|
||||
return (self.partitions == other.partitions and
|
||||
self.sync == other.sync)
|
||||
return self.partitions == other.partitions
|
||||
|
||||
def __len__(self):
|
||||
return len(self.partitions)
|
||||
@ -1045,58 +1019,19 @@ class ParsedPartitionSpec:
|
||||
return iter(self.partitions)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"ParsedPartitionSpec(partitions={self.partitions}, "
|
||||
f"unsafe_user_spec={self.unsafe_user_spec}, "
|
||||
f"sync={self.sync})")
|
||||
|
||||
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
|
||||
"""ParsedPartitionSpecs that are canonicalized.
|
||||
|
||||
ParsedPartitionSpecs may contain trailing empty tuples, that make them
|
||||
semantically different in general, and yet in some situations we prefer
|
||||
to regard them as equivalent. For example, partitions of () and ((),)
|
||||
cannot be always considered equivalent, since the first one is a valid
|
||||
spec for a scalar value, while the second is not! However, when either of
|
||||
those are applied to a 2D array, they both mean that the array is fully
|
||||
replicated.
|
||||
|
||||
So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from
|
||||
partitions.
|
||||
"""
|
||||
|
||||
def __init__(self, parsed_pspec: ParsedPartitionSpec):
|
||||
partitions = list(parsed_pspec.partitions)
|
||||
while partitions and partitions[-1] == ():
|
||||
partitions.pop()
|
||||
|
||||
super().__init__(parsed_pspec.unsafe_user_spec, partitions,
|
||||
parsed_pspec.sync)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
|
||||
f"unsafe_user_spec={self.unsafe_user_spec}, "
|
||||
f"sync={self.sync})")
|
||||
return f"ParsedPartitionSpec(partitions={self.partitions})"
|
||||
|
||||
|
||||
def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
|
||||
# This split exists because you can pass `_parsed_pspec` that has been
|
||||
# modified from the original. For example: Adding extra dimension to
|
||||
# axis_resources for vmap handlers. In such cases you need to preserve the
|
||||
# `sync` attribute of parsed pspecs.
|
||||
# PartitionSpec is inferred from the parsed pspec in this case.
|
||||
# TODO(yaskatariya): Remove this and replace this with a normalized
|
||||
# representation of Parsed Pspec
|
||||
if parsed_pspec is None:
|
||||
parsed_pspec = prepare_axis_resources(
|
||||
PartitionSpec() if spec is None else spec,
|
||||
"NamedSharding spec", allow_unconstrained_dims=True)
|
||||
|
||||
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
|
||||
return parsed_pspec
|
||||
|
||||
|
||||
def prepare_axis_resources(axis_resources,
|
||||
arg_name,
|
||||
def prepare_axis_resources(axis_resources, arg_name,
|
||||
allow_unconstrained_dims=False):
|
||||
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
|
||||
entries, treedef = tree_util.tree_flatten(
|
||||
@ -1133,9 +1068,11 @@ def _check_unique_resources(axis_resources, arg_name):
|
||||
if resource_counts.most_common(1)[0][1] > 1:
|
||||
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
|
||||
if multiple_uses:
|
||||
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
|
||||
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
|
||||
f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")
|
||||
raise ValueError(
|
||||
f'A single {arg_name} specification can map every mesh axis to at'
|
||||
' most one positional dimension, but'
|
||||
f' {arg_axis_resources.get_partition_spec()} has duplicate entries'
|
||||
f' for {mesh_lib.show_axes(multiple_uses)}')
|
||||
|
||||
# Axis environments
|
||||
|
||||
@ -1314,8 +1251,7 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
|
||||
out.extend(parse_flatten_op_sharding(s, mesh))
|
||||
return out
|
||||
elif hlo_sharding.is_replicated():
|
||||
return [CanonicalizedParsedPartitionSpec(
|
||||
ParsedPartitionSpec(PartitionSpec(), ()))]
|
||||
return [ParsedPartitionSpec(PartitionSpec(), ())]
|
||||
elif hlo_sharding.is_tiled():
|
||||
mesh_shape = mesh.shape
|
||||
mesh_axis_order = unflatten_array(
|
||||
@ -1339,8 +1275,9 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
|
||||
)
|
||||
if hlo_sharding.replicate_on_last_tile_dim():
|
||||
partitions = partitions[:-1]
|
||||
return [CanonicalizedParsedPartitionSpec(
|
||||
ParsedPartitionSpec('<internally generated spec>', partitions))]
|
||||
while partitions and partitions[-1] == ():
|
||||
partitions.pop()
|
||||
return [ParsedPartitionSpec(None, partitions)]
|
||||
else:
|
||||
raise AssertionError("Unhandled OpSharding type. Please open a bug report!")
|
||||
|
||||
|
@ -5208,11 +5208,6 @@ class UtilTest(jtu.JaxTestCase):
|
||||
self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(),
|
||||
P('x', 'y'))
|
||||
|
||||
out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec(
|
||||
P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC)
|
||||
self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(),
|
||||
P('x', 'y'))
|
||||
|
||||
def test_mesh_with_list_devices(self):
|
||||
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
|
||||
self.assertIsInstance(mesh.devices, np.ndarray)
|
||||
|
Loading…
x
Reference in New Issue
Block a user