Cleanup ParsedPartitionSpec and remove CanonicalizedParsedPartitionSpec. Also mark user_spec as private.

PiperOrigin-RevId: 676498946
This commit is contained in:
Yash Katariya 2024-09-19 11:38:01 -07:00 committed by jax authors
parent 73bbd80b80
commit c9bbf71ec6
3 changed files with 22 additions and 90 deletions

View File

@ -1041,8 +1041,8 @@ def _create_sharding_for_array(mesh, x, name, api_name):
' then the mesh context manager is not required.')
# A nice user error is raised in prepare_axis_resources.
assert x is None or isinstance(x, ParsedPartitionSpec), x
return (pxla.create_mesh_pspec_sharding(mesh, x)
if x is None else pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x))
return (pxla.create_mesh_pspec_sharding(mesh, x) if x is None else
pxla.create_mesh_pspec_sharding(mesh, x.get_partition_spec(), x))
def _create_sharding_with_device_backend(device, backend):

View File

@ -18,7 +18,6 @@ import collections
from collections import OrderedDict
from collections.abc import Mapping, Sequence
import dataclasses
import enum
import functools
import itertools
import math
@ -955,43 +954,20 @@ get_single_pspec = lambda p: array_mapping_to_axis_resources(
cast(ArrayMapping, get_array_mapping(p)))
class SpecSync(enum.IntEnum):
"""Encodes how much out of sync the real value of partitions is compared to the user specified one.
We use this to make sure we don't show garbage modified values while claiming
that the users have specified them like that.
"""
OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted
DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes
IN_SYNC = 2 # Entirely in sync
class ParsedPartitionSpec:
__slots__ = ('unsafe_user_spec', 'partitions', 'sync')
__slots__ = ('_user_spec', 'partitions')
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
self.unsafe_user_spec = user_spec
def __init__(self, user_spec, partitions):
self._user_spec = user_spec
# None in partitions represents unconstrained dim.
# TODO(yashkatariya): May use a sentinel value.
self.partitions = tuple(partitions)
self.sync = sync
@property
def user_spec(self):
return self.unsynced_user_spec(SpecSync.IN_SYNC)
def get_partition_spec(self) -> PartitionSpec:
if self.sync < SpecSync.IN_SYNC:
return get_single_pspec(self)
if isinstance(self._user_spec, PartitionSpec):
return self._user_spec
else:
if isinstance(self.unsafe_user_spec, PartitionSpec):
return self.unsafe_user_spec
else:
return get_single_pspec(self)
def unsynced_user_spec(self, min_sync):
if self.sync < min_sync:
raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
return self.unsafe_user_spec
return get_single_pspec(self)
def insert_axis_partitions(self, dim, val):
parts = self.partitions
@ -999,8 +975,7 @@ class ParsedPartitionSpec:
if too_short > 0:
parts += ((),) * too_short
new_partitions = util.tuple_insert(parts, dim, val)
new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC
return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)
return ParsedPartitionSpec(None, new_partitions)
@classmethod
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
@ -1027,13 +1002,12 @@ class ParsedPartitionSpec:
return cls(new_entry, axis_specs)
def __hash__(self):
return hash((self.partitions, self.sync))
return hash(self.partitions)
def __eq__(self, other):
if not isinstance(other, ParsedPartitionSpec):
return False
return (self.partitions == other.partitions and
self.sync == other.sync)
return self.partitions == other.partitions
def __len__(self):
return len(self.partitions)
@ -1045,58 +1019,19 @@ class ParsedPartitionSpec:
return iter(self.partitions)
def __repr__(self):
return (f"ParsedPartitionSpec(partitions={self.partitions}, "
f"unsafe_user_spec={self.unsafe_user_spec}, "
f"sync={self.sync})")
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
"""ParsedPartitionSpecs that are canonicalized.
ParsedPartitionSpecs may contain trailing empty tuples, that make them
semantically different in general, and yet in some situations we prefer
to regard them as equivalent. For example, partitions of () and ((),)
cannot be always considered equivalent, since the first one is a valid
spec for a scalar value, while the second is not! However, when either of
those are applied to a 2D array, they both mean that the array is fully
replicated.
So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from
partitions.
"""
def __init__(self, parsed_pspec: ParsedPartitionSpec):
partitions = list(parsed_pspec.partitions)
while partitions and partitions[-1] == ():
partitions.pop()
super().__init__(parsed_pspec.unsafe_user_spec, partitions,
parsed_pspec.sync)
def __repr__(self):
return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
f"unsafe_user_spec={self.unsafe_user_spec}, "
f"sync={self.sync})")
return f"ParsedPartitionSpec(partitions={self.partitions})"
def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
# This split exists because you can pass `_parsed_pspec` that has been
# modified from the original. For example: Adding extra dimension to
# axis_resources for vmap handlers. In such cases you need to preserve the
# `sync` attribute of parsed pspecs.
# PartitionSpec is inferred from the parsed pspec in this case.
# TODO(yaskatariya): Remove this and replace this with a normalized
# representation of Parsed Pspec
if parsed_pspec is None:
parsed_pspec = prepare_axis_resources(
PartitionSpec() if spec is None else spec,
"NamedSharding spec", allow_unconstrained_dims=True)
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
return parsed_pspec
def prepare_axis_resources(axis_resources,
arg_name,
def prepare_axis_resources(axis_resources, arg_name,
allow_unconstrained_dims=False):
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
entries, treedef = tree_util.tree_flatten(
@ -1133,9 +1068,11 @@ def _check_unique_resources(axis_resources, arg_name):
if resource_counts.most_common(1)[0][1] > 1:
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
if multiple_uses:
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")
raise ValueError(
f'A single {arg_name} specification can map every mesh axis to at'
' most one positional dimension, but'
f' {arg_axis_resources.get_partition_spec()} has duplicate entries'
f' for {mesh_lib.show_axes(multiple_uses)}')
# Axis environments
@ -1314,8 +1251,7 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
out.extend(parse_flatten_op_sharding(s, mesh))
return out
elif hlo_sharding.is_replicated():
return [CanonicalizedParsedPartitionSpec(
ParsedPartitionSpec(PartitionSpec(), ()))]
return [ParsedPartitionSpec(PartitionSpec(), ())]
elif hlo_sharding.is_tiled():
mesh_shape = mesh.shape
mesh_axis_order = unflatten_array(
@ -1339,8 +1275,9 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
)
if hlo_sharding.replicate_on_last_tile_dim():
partitions = partitions[:-1]
return [CanonicalizedParsedPartitionSpec(
ParsedPartitionSpec('<internally generated spec>', partitions))]
while partitions and partitions[-1] == ():
partitions.pop()
return [ParsedPartitionSpec(None, partitions)]
else:
raise AssertionError("Unhandled OpSharding type. Please open a bug report!")

View File

@ -5208,11 +5208,6 @@ class UtilTest(jtu.JaxTestCase):
self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(),
P('x', 'y'))
out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec(
P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC)
self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(),
P('x', 'y'))
def test_mesh_with_list_devices(self):
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
self.assertIsInstance(mesh.devices, np.ndarray)