mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Speed up NamedSharding construction.
* Compute the size of a mesh eagerly. We're almost always going to need it, because NamedSharding's constructor asks for it.
* Speed up mesh equality. It's likely that we have only one mesh, so the identity equality test will usually hit; do it first.
* Don't call prepare_axis_resources in ParsedPartitionSpec construction. It does a bunch of pointless tree flattening and list manipulation, but we know we have exactly one PartitionSpec and can directly do the one check we need, which is _check_unique_resources.
* Only call _check_unique_resources on PartitionSpecs; it's easy to avoid calling it in other cases, and then we don't need a bunch of isinstance checks.
* Avoid use of collections.Counter when checking for unique resources. collections.Counter has a surprisingly slow isinstance test.

PiperOrigin-RevId: 724431847
This commit is contained in:
parent
cd0753751c
commit
f21b0f03b4
@ -228,6 +228,7 @@ class Mesh(contextlib.ContextDecorator):
|
||||
self.axis_names = axis_names
|
||||
self.axis_types = axis_types
|
||||
self._axis_types_tuple = axis_types_tuple
|
||||
self._size = math.prod(self.shape.values()) if self.devices.ndim else 0
|
||||
_mesh_object_dict[key] = self
|
||||
return self
|
||||
|
||||
@ -236,12 +237,12 @@ class Mesh(contextlib.ContextDecorator):
|
||||
{'axis_types': self.axis_types})
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Mesh):
|
||||
return False
|
||||
# This is a performance optimization. Comparing thousands of devices
|
||||
# can be expensive.
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
if not isinstance(other, Mesh):
|
||||
return False
|
||||
return (self.axis_names == other.axis_names and
|
||||
self.devices.shape == other.devices.shape and
|
||||
self._axis_types_tuple == other._axis_types_tuple and
|
||||
@ -306,7 +307,7 @@ class Mesh(contextlib.ContextDecorator):
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return math.prod(self.shape.values()) if self.devices.ndim else 0
|
||||
return self._size
|
||||
|
||||
@property
|
||||
def empty(self):
|
||||
@ -413,6 +414,7 @@ class AbstractMesh:
|
||||
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
|
||||
else:
|
||||
self._axis_names, self._axis_sizes = (), ()
|
||||
self._size = math.prod(self._axis_sizes) if self._axis_sizes else 0
|
||||
self.axis_types = ({AxisTypes.Auto: self._axis_names}
|
||||
if axis_types is None else axis_types)
|
||||
self._axis_types_tuple = to_axis_types_tuple(self.axis_types)
|
||||
@ -426,10 +428,10 @@ class AbstractMesh:
|
||||
return hash((self.shape_tuple, self._axis_types_tuple))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, AbstractMesh):
|
||||
return False
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
if not isinstance(other, AbstractMesh):
|
||||
return False
|
||||
return (self.shape_tuple == other.shape_tuple and
|
||||
self._axis_types_tuple == other._axis_types_tuple)
|
||||
|
||||
@ -451,9 +453,9 @@ class AbstractMesh:
|
||||
def _name_to_type(self):
|
||||
return axis_names_to_types(self.axis_types)
|
||||
|
||||
@functools.cached_property
|
||||
@property
|
||||
def size(self):
|
||||
return math.prod(self._axis_sizes) if self._axis_sizes else 0
|
||||
return self._size
|
||||
|
||||
@functools.cached_property
|
||||
def shape(self):
|
||||
|
@ -19,7 +19,6 @@ from collections import OrderedDict
|
||||
from collections.abc import Mapping, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
from typing import Any, NamedTuple, Union, cast
|
||||
|
||||
@ -1090,6 +1089,9 @@ get_single_pspec = lambda p: array_mapping_to_axis_resources(
|
||||
class ParsedPartitionSpec:
|
||||
__slots__ = ('_user_spec', 'partitions')
|
||||
|
||||
_user_spec: PartitionSpec | None
|
||||
partitions: tuple[tuple[MeshAxisName, ...] | None, ...]
|
||||
|
||||
def __init__(self, user_spec, partitions):
|
||||
self._user_spec = user_spec
|
||||
# None in partitions represents unconstrained dim.
|
||||
@ -1111,7 +1113,12 @@ class ParsedPartitionSpec:
|
||||
return ParsedPartitionSpec(None, new_partitions)
|
||||
|
||||
@classmethod
|
||||
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
|
||||
def from_user_input(
|
||||
cls,
|
||||
entry: PartitionSpec | None,
|
||||
arg_name: str,
|
||||
allow_unconstrained_dims: bool = False,
|
||||
) -> ParsedPartitionSpec:
|
||||
if entry is None:
|
||||
return cls(entry, ())
|
||||
if not isinstance(entry, PartitionSpec):
|
||||
@ -1157,9 +1164,10 @@ class ParsedPartitionSpec:
|
||||
|
||||
def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
|
||||
if parsed_pspec is None:
|
||||
parsed_pspec = prepare_axis_resources(
|
||||
PartitionSpec() if spec is None else spec,
|
||||
"NamedSharding spec", allow_unconstrained_dims=True)
|
||||
spec = PartitionSpec() if spec is None else spec
|
||||
parsed_pspec = ParsedPartitionSpec.from_user_input(
|
||||
spec, "NamedSharding spec", allow_unconstrained_dims=True)
|
||||
_check_unique_resources(parsed_pspec, "NamedSharding spec")
|
||||
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
|
||||
_check_axis_type_consistency(mesh, parsed_pspec)
|
||||
return parsed_pspec
|
||||
@ -1182,30 +1190,34 @@ def prepare_axis_resources(axis_resources, arg_name,
|
||||
'allowed.')
|
||||
new_entries.append(entry)
|
||||
else:
|
||||
new_entries.append(ParsedPartitionSpec.from_user_input(
|
||||
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
|
||||
parsed_pspec = ParsedPartitionSpec.from_user_input(
|
||||
entry, what, allow_unconstrained_dims=allow_unconstrained_dims)
|
||||
_check_unique_resources(parsed_pspec, arg_name)
|
||||
new_entries.append(parsed_pspec)
|
||||
|
||||
_check_unique_resources(new_entries, arg_name)
|
||||
return tree_util.tree_unflatten(treedef, new_entries)
|
||||
|
||||
|
||||
def _check_unique_resources(axis_resources, arg_name):
|
||||
for arg_axis_resources in axis_resources:
|
||||
if not arg_axis_resources: continue
|
||||
if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, jsharding.Sharding)):
|
||||
continue
|
||||
constrained_dims = [d for d in arg_axis_resources if d is not None]
|
||||
resource_counts = collections.Counter(
|
||||
itertools.chain.from_iterable(constrained_dims))
|
||||
if not resource_counts: continue
|
||||
if resource_counts.most_common(1)[0][1] > 1:
|
||||
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
|
||||
if multiple_uses:
|
||||
raise ValueError(
|
||||
f'A single {arg_name} specification can map every mesh axis to at'
|
||||
' most one positional dimension, but'
|
||||
f' {arg_axis_resources.get_partition_spec()} has duplicate entries'
|
||||
f' for {mesh_lib.show_axes(multiple_uses)}')
|
||||
def _check_unique_resources(
|
||||
arg_axis_resources: ParsedPartitionSpec, arg_name: str
|
||||
) -> None:
|
||||
resource_counts: dict[MeshAxisName, int] = {}
|
||||
duplicate = False
|
||||
for d in arg_axis_resources:
|
||||
if d is not None:
|
||||
for resource in d:
|
||||
count = resource_counts.get(resource, 0)
|
||||
if count > 0:
|
||||
duplicate = True
|
||||
resource_counts[resource] = count + 1
|
||||
if duplicate:
|
||||
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
|
||||
raise ValueError(
|
||||
f'A single {arg_name} specification can map every mesh axis to at'
|
||||
' most one positional dimension, but'
|
||||
f' {arg_axis_resources.get_partition_spec()} has duplicate entries'
|
||||
f' for {mesh_lib.show_axes(multiple_uses)}')
|
||||
|
||||
|
||||
# Axis environments
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user