Speed up NamedSharding construction.

* Compute the size of a mesh eagerly. We almost always need it, because NamedSharding's constructor asks for it.
* Speed up mesh equality. It's likely we have only one mesh, so the identity test will usually hit; do it before the isinstance check.
* Don't call prepare_axis_resources during ParsedPartitionSpec construction. It does a lot of pointless tree flattening and list manipulation, but here we know we have exactly one PartitionSpec and can directly perform the only check we need, _check_unique_resources.
* Only call _check_unique_resources on PartitionSpecs; it's easy to avoid calling it in the other cases, which lets us drop a number of isinstance checks.
* Avoid collections.Counter when checking for unique resources; collections.Counter has a surprisingly slow isinstance test, and a plain dict suffices here. (An illustrative timing sketch of the affected path follows the commit metadata below.)

PiperOrigin-RevId: 724431847
Author: Peter Hawkins, 2025-02-07 12:20:08 -08:00 (committed by jax authors)
Commit: f21b0f03b4
Parent: cd0753751c
2 changed files with 46 additions and 32 deletions
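
As a rough way to gauge the effect of these changes, the sketch below times NamedSharding construction from an existing mesh, which is the path the commit optimizes. It is illustrative only and not part of the commit; it assumes a working JAX installation (any backend, CPU is fine) and uses only public APIs (jax.sharding.Mesh, NamedSharding, PartitionSpec).

import timeit

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over whatever devices are available; the Python-side
# construction cost being measured does not depend on the backend.
mesh = Mesh(np.array(jax.devices()), ('x',))
spec = PartitionSpec('x')

# NamedSharding construction asks the mesh for its size and parses and
# validates the PartitionSpec, so it exercises the code paths touched here.
n = 100_000
t = timeit.timeit(lambda: NamedSharding(mesh, spec), number=n)
print(f'{t / n * 1e6:.2f} us per NamedSharding construction')

Comparing the per-construction time before and after this commit on the same machine gives a ballpark for the speedup; absolute numbers depend on the device count and hardware.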

Changed file 1 of 2:

@@ -228,6 +228,7 @@ class Mesh(contextlib.ContextDecorator):
     self.axis_names = axis_names
     self.axis_types = axis_types
     self._axis_types_tuple = axis_types_tuple
+    self._size = math.prod(self.shape.values()) if self.devices.ndim else 0
     _mesh_object_dict[key] = self
     return self
@@ -236,12 +237,12 @@ class Mesh(contextlib.ContextDecorator):
             {'axis_types': self.axis_types})

   def __eq__(self, other):
-    if not isinstance(other, Mesh):
-      return False
     # This is a performance optimization. Comparing thousands of devices
     # can be expensive.
     if id(self) == id(other):
       return True
+    if not isinstance(other, Mesh):
+      return False
     return (self.axis_names == other.axis_names and
             self.devices.shape == other.devices.shape and
             self._axis_types_tuple == other._axis_types_tuple and
@@ -306,7 +307,7 @@ class Mesh(contextlib.ContextDecorator):
   @property
   def size(self):
-    return math.prod(self.shape.values()) if self.devices.ndim else 0
+    return self._size

   @property
   def empty(self):
@@ -413,6 +414,7 @@ class AbstractMesh:
       self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
     else:
       self._axis_names, self._axis_sizes = (), ()
+    self._size = math.prod(self._axis_sizes) if self._axis_sizes else 0
     self.axis_types = ({AxisTypes.Auto: self._axis_names}
                        if axis_types is None else axis_types)
     self._axis_types_tuple = to_axis_types_tuple(self.axis_types)
@@ -426,10 +428,10 @@ class AbstractMesh:
     return hash((self.shape_tuple, self._axis_types_tuple))

   def __eq__(self, other):
-    if not isinstance(other, AbstractMesh):
-      return False
     if id(self) == id(other):
       return True
+    if not isinstance(other, AbstractMesh):
+      return False
     return (self.shape_tuple == other.shape_tuple and
             self._axis_types_tuple == other._axis_types_tuple)
@@ -451,9 +453,9 @@ class AbstractMesh:
   def _name_to_type(self):
     return axis_names_to_types(self.axis_types)

-  @functools.cached_property
+  @property
   def size(self):
-    return math.prod(self._axis_sizes) if self._axis_sizes else 0
+    return self._size

   @functools.cached_property
   def shape(self):

Changed file 2 of 2:

@@ -19,7 +19,6 @@ from collections import OrderedDict
 from collections.abc import Mapping, Sequence
 import dataclasses
 import functools
-import itertools
 import math
 from typing import Any, NamedTuple, Union, cast
@@ -1090,6 +1089,9 @@ get_single_pspec = lambda p: array_mapping_to_axis_resources(
 class ParsedPartitionSpec:
   __slots__ = ('_user_spec', 'partitions')

+  _user_spec: PartitionSpec | None
+  partitions: tuple[tuple[MeshAxisName, ...] | None, ...]
+
   def __init__(self, user_spec, partitions):
     self._user_spec = user_spec
     # None in partitions represents unconstrained dim.
@@ -1111,7 +1113,12 @@
     return ParsedPartitionSpec(None, new_partitions)

   @classmethod
-  def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
+  def from_user_input(
+      cls,
+      entry: PartitionSpec | None,
+      arg_name: str,
+      allow_unconstrained_dims: bool = False,
+  ) -> ParsedPartitionSpec:
     if entry is None:
       return cls(entry, ())
     if not isinstance(entry, PartitionSpec):
@@ -1157,9 +1164,10 @@
 def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
   if parsed_pspec is None:
-    parsed_pspec = prepare_axis_resources(
-        PartitionSpec() if spec is None else spec,
-        "NamedSharding spec", allow_unconstrained_dims=True)
+    spec = PartitionSpec() if spec is None else spec
+    parsed_pspec = ParsedPartitionSpec.from_user_input(
+        spec, "NamedSharding spec", allow_unconstrained_dims=True)
+    _check_unique_resources(parsed_pspec, "NamedSharding spec")

   _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
   _check_axis_type_consistency(mesh, parsed_pspec)
   return parsed_pspec
@@ -1182,30 +1190,34 @@ def prepare_axis_resources(axis_resources, arg_name,
                          'allowed.')
       new_entries.append(entry)
     else:
-      new_entries.append(ParsedPartitionSpec.from_user_input(
-          entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
+      parsed_pspec = ParsedPartitionSpec.from_user_input(
+          entry, what, allow_unconstrained_dims=allow_unconstrained_dims)
+      _check_unique_resources(parsed_pspec, arg_name)
+      new_entries.append(parsed_pspec)

-  _check_unique_resources(new_entries, arg_name)
   return tree_util.tree_unflatten(treedef, new_entries)


-def _check_unique_resources(axis_resources, arg_name):
-  for arg_axis_resources in axis_resources:
-    if not arg_axis_resources: continue
-    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, jsharding.Sharding)):
-      continue
-    constrained_dims = [d for d in arg_axis_resources if d is not None]
-    resource_counts = collections.Counter(
-        itertools.chain.from_iterable(constrained_dims))
-    if not resource_counts: continue
-    if resource_counts.most_common(1)[0][1] > 1:
-      multiple_uses = [r for r, c in resource_counts.items() if c > 1]
-      if multiple_uses:
-        raise ValueError(
-            f'A single {arg_name} specification can map every mesh axis to at'
-            ' most one positional dimension, but'
-            f' {arg_axis_resources.get_partition_spec()} has duplicate entries'
-            f' for {mesh_lib.show_axes(multiple_uses)}')
+def _check_unique_resources(
+    arg_axis_resources: ParsedPartitionSpec, arg_name: str
+) -> None:
+  resource_counts: dict[MeshAxisName, int] = {}
+  duplicate = False
+  for d in arg_axis_resources:
+    if d is not None:
+      for resource in d:
+        count = resource_counts.get(resource, 0)
+        if count > 0:
+          duplicate = True
+        resource_counts[resource] = count + 1
+  if duplicate:
+    multiple_uses = [r for r, c in resource_counts.items() if c > 1]
+    raise ValueError(
+        f'A single {arg_name} specification can map every mesh axis to at'
+        ' most one positional dimension, but'
+        f' {arg_axis_resources.get_partition_spec()} has duplicate entries'
+        f' for {mesh_lib.show_axes(multiple_uses)}')


 # Axis environments