rocm_jax/jax/_src/custom_partitioning_sharding_rule.py
Bixia Zheng 0c3de93b79 [jax:custom_partitioning] Support SdyShardingRule with multiple leading
batching dimension groups.

Previously, we allowed the use of an ellipsis ... in the Einsum-like notation to
represent leading batching dimensions in one group of operands and results. We
now allow the use of an ellipsis optionally followed by a single digit, such as
...2, to represent leading batching dimensions for multiple groups of operands
and results.

Add tests.

PiperOrigin-RevId: 718875251
2025-01-23 08:20:38 -08:00

476 lines
18 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements SdyShardingRule."""
from collections import OrderedDict
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
# A single character replacement for ... to simplify parsing. It has to be a
# real (non-empty) character that is not otherwise valid in a rule: an empty
# string is a substring of every rule and never equals a parsed character.
BATCHING: str = "…"
# A prefix for names of batching dimension factors, used for expanding the
# leading ... into factors. Factor names accepted by _check_factor must start
# with a letter, so a name starting with this prefix can't be a user factor.
_BATCHING_DIM_FACTOR_PREFIX = "?"
# A Jax value in general corresponds to an ir.Type or a tuple of ir.Types.
IrTypes = ir.Type | tuple[ir.Type, ...]
def _check_factor(factor: str):
  """Validates a factor name.

  A factor is a non-empty string starting with a letter and containing only
  letters, digits, or underscores.

  Args:
    factor: The factor name to validate.

  Raises:
    ValueError: If the factor is empty, doesn't start with a letter, or
      contains a character other than a letter, a digit, or an underscore.
  """
  # Reject the empty string explicitly; indexing factor[0] below would
  # otherwise surface as an IndexError rather than a ValueError.
  if not factor:
    raise ValueError("Factor names can't be empty")
  if not factor[0].isalpha():
    raise ValueError(f"Factor names have to start with a letter, but got '{factor[0]}'")
  for char in factor[1:]:
    if char != "_" and not char.isdigit() and not char.isalpha():
      raise ValueError(f"Unknown character '{char}'")
def _is_batching(factor: str) -> bool:
  """Returns whether a factor stands for leading batching dimensions.

  Such a factor starts with the BATCHING character (the replacement for ...)
  and is optionally followed by digits only; a bare ... is equivalent to ...0.
  """
  if not factor or factor[0] != BATCHING:
    return False
  group_suffix = factor[1:]
  return not group_suffix or group_suffix.isdigit()
def _get_batching_group(factor: str) -> str:
  """Returns the batching group named by a leading-batching-dimensions factor.

  The group is whatever follows the first character of the factor; when
  nothing follows it, the group defaults to "0".
  """
  return factor[1:] or "0"
class CompoundFactor(tuple):
  """A tuple of the factor names that together describe one dimension.

  At least two factors are required, e.g.
    * CompoundFactor('b', 'c').
  """

  def __new__(cls, *factors):
    return tuple.__new__(CompoundFactor, factors)

  def __init__(self, *factors):
    # Validation is done here; the tuple itself is populated by __new__ from
    # the same positional arguments.
    if len(factors) < 2:
      raise ValueError("A compound factor should contain at least two factors")
    for factor in factors:
      if not isinstance(factor, str):
        raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}")
      if _is_batching(factor):
        raise ValueError("Ellipsis can't be used in a compound factor")
      _check_factor(factor)
class ArrayMapping(tuple):
  """A tuple with one entry per dimension of an operand or result.

  An entry is either a factor name or a CompoundFactor. Only the first entry
  may be a batching factor (BATCHING, optionally followed by digits), which
  represents leading batching dimensions. Examples:
    * ArrayMapping('a')
    * ArrayMapping('b', 'c')
    * ArrayMapping(CompoundFactor('b', 'c'), 'd')
    * ArrayMapping(BATCHING, CompoundFactor('b', 'c'), 'd')
  """

  def __new__(cls, *dim_mappings):
    return tuple.__new__(ArrayMapping, dim_mappings)

  def __init__(self, *dim_mappings):
    for position, d in enumerate(dim_mappings):
      if isinstance(d, CompoundFactor):
        # A CompoundFactor validates its own factors when constructed.
        continue
      if not isinstance(d, str):
        raise ValueError(
            "Each element of ArrayMapping must be a str or CompoundFactor, but"
            f" got {type(d)}")
      if not _is_batching(d):
        _check_factor(d)
      elif position != 0:
        raise ValueError("Ellipsis can only be used at the beginning of a dimension")
class SdyShardingRule:
  """Represents a Shardy sharding rule.

  Holds the ArrayMappings of the operands and of the results, plus the sizes
  given for factors by keyword. A factor is a name used in the ArrayMappings.
  A size must be given exactly for those factors that only ever appear inside
  CompoundFactors; a factor that names a whole dimension must not have one.
  """
  operand_mappings: tuple[ArrayMapping, ...]
  result_mappings: tuple[ArrayMapping, ...]
  factor_sizes: dict[str, int]

  def __init__(self, operand_mappings: tuple[ArrayMapping, ...],
               result_mappings: tuple[ArrayMapping, ...], **factor_sizes):
    # Maps each factor to True when it is used for a whole dimension somewhere
    # (its size can then be inferred), and to False when it has only been seen
    # inside compound factors.
    size_is_inferrable = {}
    for mapping in operand_mappings + result_mappings:
      for dim in mapping:
        if isinstance(dim, str):
          size_is_inferrable[dim] = True
          continue
        for factor in dim:
          size_is_inferrable.setdefault(factor, False)

    # Every provided size has to belong to a factor used in the rule.
    for factor in factor_sizes:
      if factor not in size_is_inferrable:
        raise ValueError(
            f"Factor {factor} is not used in the rule, but size is provided")

    # A size is required exactly when it can't be inferred from a dimension.
    for factor, inferrable in size_is_inferrable.items():
      size_given = factor in factor_sizes
      if not size_given and not inferrable:
        raise ValueError(
            f"Factor {factor} is only used in compound factors; must specify"
            " its size")
      if size_given and inferrable:
        raise ValueError(
            f"Factor {factor} represents a whole dimension; do not specify its"
            " size")

    self.operand_mappings = operand_mappings
    self.result_mappings = result_mappings
    self.factor_sizes = factor_sizes

  def __str__(self):
    return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})"
def _get_batching_dim_factor_name(batch_group: str, batch_dim_order: int):
  """Returns the factor name for one batching dimension of a batching group.

  The leading ... is expanded into one factor per batching dimension when the
  MLIR representation of the sharding rule is built. The name is formed from
  _BATCHING_DIM_FACTOR_PREFIX, the batching group, and the position of the
  dimension within the group, so that it differs from any user factor name.
  """
  factor_name = f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_group}_{batch_dim_order}"
  return factor_name
def _parse_values(
    rule: str,
) -> tuple[ArrayMapping, ...]:
  """Parses the LHS or RHS of an Einsum notation like string.

  Converts each operand or result in the Einsum notation like string to an
  ArrayMapping. Values are separated by ',', factors within a value by spaces,
  and compound factors are written as a bracketed, space separated list of
  factors. This very closely follows how einops parses their rules in
  einops/parsing.py.

  Args:
    rule: The Einsum notation for the operands or results of an operation, in
      which every ... has already been replaced by BATCHING.

  Returns:
    The tuple of ArrayMapping, one per operand or result. An empty rule yields
    a single empty ArrayMapping.

  Raises:
    ValueError: If the rule is not balanced (including a ',' inside brackets),
      contains unknown characters, has an ellipsis anywhere other than at the
      start of a value, or contains an invalid factor.
  """
  # Collapse runs of whitespace and trim the ends to simplify the parsing.
  rule = " ".join(rule.split())

  # Similar to einops rules, an empty LHS/RHS has a single scalar value.
  if not rule:
    return (ArrayMapping(),)

  all_values = []
  # Represent all dimensions of a value. When value[0] is a batching factor,
  # the value may have 0 or more leading dimensions.
  value = []
  current_factor = None
  # A value of None indicates the current dimension is not a compound dimension,
  # while a value of [] indicates that we have just started parsing a compound
  # dimension.
  current_compound_dim: list[str] | None = None

  def add_factor(x):
    # Route the factor to the open compound dimension, if any, else to the
    # value being built.
    if current_compound_dim is None:
      value.append(x)
    else:
      current_compound_dim.append(x)

  rule_len = len(rule)
  rule_index = 0
  while rule_index < rule_len:
    char = rule[rule_index]
    rule_index += 1
    if char == BATCHING:
      if (current_factor is not None or current_compound_dim is not None
          or value):
        raise ValueError(
            "Ellipsis can only be used at the beginning of a dimension")
      # Consume the optional digits that name the batching group.
      digits_start = rule_index
      while rule_index < rule_len and rule[rule_index].isdigit():
        rule_index += 1
      digits = rule[digits_start:rule_index]
      # Normalize the group through int so that e.g. "02" and "2" agree.
      batching_group = str(int(digits)) if digits else "0"
      add_factor(f"{BATCHING}{batching_group}")
      continue

    if char in "(), ":
      # Any separator terminates the factor being read.
      if current_factor is not None:
        add_factor(current_factor)
        current_factor = None
      if char == "(":
        if current_compound_dim is not None:
          raise ValueError(
              "Compound factors should be one level, nested brackets are not"
              " allowed")
        current_compound_dim = []
      elif char == ")":
        if current_compound_dim is None:
          raise ValueError("Brackets are not balanced")
        if len(current_compound_dim) <= 1:
          raise ValueError("Brackets should contain at least two factors")
        value.append(CompoundFactor(*current_compound_dim))
        current_compound_dim = None
      elif char == ",":
        # A ',' ends the current value, so it can't appear while a bracket is
        # still open; otherwise the compound factor would silently end up in
        # the next value.
        if current_compound_dim is not None:
          raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
        all_values.append(ArrayMapping(*value))
        value = []
    elif char == "_" or char.isdigit() or char.isalpha():
      if current_factor is None:
        if char.isdigit():
          raise ValueError(f"Factor names have to start with a letter, but got '{char}'")
        current_factor = char
      else:
        current_factor += char
    else:
      raise ValueError(f"Unknown character '{char}'")

  if current_compound_dim is not None:
    raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
  # Flush the factor and the value that were still open at the end of the rule.
  if current_factor is not None:
    add_factor(current_factor)
  all_values.append(ArrayMapping(*value))

  return tuple(all_values)
def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule:
  """Constructs a SdyShardingRule object from the Einsum notation like string.

  This is done by verifying that the input Einsum notation like string and
  with optional factor sizes represents a valid sharding rule and converting
  it to an internal representation.

  Args:
    rule: The Einsum notation like string for an operation, with the operands
      and the results separated by exactly one '->'.
    **factor_sizes: The optional factor sizes.

  Returns:
    The SdyShardingRule built from the parsed operands, results and
    factor_sizes.

  Raises:
    TypeError: If rule is not a str or a factor size is not an int.
    ValueError: If there is any problem with the rule or factor_sizes.
  """
  if not isinstance(rule, str):
    raise TypeError(f"rule must be a str, but got {type(rule)}")
  if not all(isinstance(size, int) for size in factor_sizes.values()):
    raise TypeError(
        f"factor_sizes must be a dict of str to int, but got {factor_sizes}")

  # Replace ... with a single char to simplify parsing. The replacement char
  # itself is reserved and therefore rejected in the user's rule.
  if BATCHING in rule:
    raise ValueError(f"Unknown character '{BATCHING}'")
  if "." in rule:
    rule = rule.replace("...", BATCHING)
    if "." in rule:
      raise ValueError("Character '.' must be used inside ellipsis '...'")

  # Exactly one -> must separate the operands from the results. Report a
  # missing arrow and a repeated arrow separately so that the message matches
  # the actual problem.
  sides = rule.split("->")
  if len(sides) < 2:
    raise ValueError(f"There is no -> in rule: '{rule}'")
  if len(sides) > 2:
    raise ValueError(f"There is more than one -> in rule: '{rule}'")
  operands, results = sides

  operand_mappings = _parse_values(operands)
  result_mappings = _parse_values(results)

  return SdyShardingRule(operand_mappings, result_mappings, **factor_sizes)
def sdy_sharding_rule_to_mlir(
  rule: SdyShardingRule,
  operand_types: list[IrTypes],
  result_types: list[IrTypes],) -> ir.Attribute:
  """Builds the MLIR representation for the sharding rule.

  This is done by verifying that the rule is consistent with the types of
  the operation and converting the Einsum notation like string to
  OpShardingRuleAttr.

  The work is done in two passes over the operand and result mappings: the
  first pass assigns an index and a size to every factor (expanding each
  leading ellipsis into one factor per batching dimension), and the second
  pass builds the per-dimension mappings from those indices.

  Args:
    rule: The sharding rule to convert.
    operand_types: The types of the operands of the operation; each one must
      be an ir.Type.
    result_types: The types of the results of the operation; each one must be
      an ir.Type.

  Returns:
    The attribute built by sdy.OpShardingRuleAttr.get from the factor sizes
    and the operand and result tensor mappings.

  Raises:
    ValueError: If the number of operands or results, a rank, a factor size,
      or the number of dimensions represented by an ellipsis is inconsistent
      between the rule and the types.
    TypeError: If an element of operand_types or result_types is not an
      ir.Type.
  """
  if len(rule.operand_mappings) != len(operand_types):
    raise ValueError(
      f"Sharding rule has {len(rule.operand_mappings)} operands, but the operation"
      f" has {len(operand_types)} operands")
  if len(rule.result_mappings) != len(result_types):
    raise ValueError(
      f"Sharding rule has {len(rule.result_mappings)} results, but the operation"
      f" has {len(result_types)} results")
  if not all(isinstance(t, ir.Type) for t in operand_types + result_types):
    raise TypeError(
      f"operand_types and result_types must be a list of ir.Type, but got"
      f" {operand_types} and {result_types}")

  # Maps a factor name to [factor index, factor size]. The index is the order
  # in which the factor was first added.
  factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
  # Operands first, then results; a value index i below indexes this list.
  types = operand_types + result_types
  UNKNOWN = -1 # Representation for unknown factor size or factor index.

  def get_message_for_value(i):
    # Describes value i as an operand or a result for error messages.
    if i >= len(operand_types):
      return f"{i - len(operand_types)}th result"
    else:
      return f"{i}th operand"

  def get_rank_for_value(i):
    return ir.ShapedType(types[i]).rank

  def get_size_for_value_dim(i, j):
    return ir.ShapedType(types[i]).shape[j]

  def add_factor(factor, size):
    """Adds a factor to factors_to_indices_sizes.

    `size` may be a dimensions size, a user specified factor size, or UNKNOWN
    if a factor is first used as in a compound factor and then used for a
    whole dimension.

    Raises a ValueError if the factor already has a known size that differs
    from a known `size`.
    """
    factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
    if factor_index != UNKNOWN:
      # Not the first time seeing the factor.
      if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
        factor_or_batching_dim = (
          f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
          else f"Batching dimension {factor[1:]}")
        raise ValueError(
          f"{factor_or_batching_dim} corresponds to two sizes:"
          f" {factor_size} and {size}")
      # Fill in a size that was unknown so far, keeping the factor index.
      if size != UNKNOWN and factor_size == UNKNOWN:
        factors_to_indices_sizes[factor] = [factor_index, size]
    else:
      # First time seeing the factor.
      factor_index = len(factors_to_indices_sizes)
      factors_to_indices_sizes[factor] = [factor_index, size]

  def add_batching_dim_factor(batch_grp, batch_dim_order, factor_size):
    add_factor(_get_batching_dim_factor_name(batch_grp, batch_dim_order), factor_size)

  def build_dim_mapping_for_compound_factors(i, j, factors):
    # Builds the mapping for dimension j of value i from the indices of the
    # given factors, checking that the product of the factor sizes equals the
    # size of the dimension.
    accumulated_size = 1
    all_indices = []
    for factor in factors:
      factor_index, factor_size = factors_to_indices_sizes[factor]
      accumulated_size *= factor_size
      all_indices.append(factor_index)

    dim_size = get_size_for_value_dim(i, j)
    if accumulated_size != dim_size:
      raise ValueError(
        f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
        f" the size {accumulated_size} derived from the compound factors"
        f" {factors}")

    return sdy.DimMappingAttr.get(factor_indices=all_indices)

  # Add factors and their sizes in the order they appear in the rule,
  # including the batching dimensions represented by ellipsis.
  # Maps a batching group to the number of dimensions its ellipsis represents.
  batching_group_to_rank: dict[str, int] = {}
  for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
    value = tuple(mapping)
    # Split off a leading batching factor, if there is one.
    if value and _is_batching(value[0]):
      batching_group = _get_batching_group(value[0])
      value = value[1:]
    else:
      batching_group = None
    rule_rank = len(value)
    op_rank = get_rank_for_value(i)
    # The number of dimensions represented by ellipsis.
    current_batching_rank = 0
    if batching_group is not None and op_rank >= rule_rank:
      current_batching_rank = op_rank - rule_rank
    if batching_group is not None:
      # All uses of the same batching group must represent the same number of
      # leading dimensions.
      ellipsis_rank = batching_group_to_rank.get(batching_group, None)
      if ellipsis_rank is None:
        ellipsis_rank = current_batching_rank
        batching_group_to_rank[batching_group] = ellipsis_rank
      elif ellipsis_rank != current_batching_rank:
        raise ValueError(
          "Ellipsis represents different number of leading dimensions"
          f" {ellipsis_rank} and {current_batching_rank}")
    rule_rank += current_batching_rank
    if rule_rank != op_rank:
      msg = get_message_for_value(i)
      raise ValueError(
        f"Sharding rule {msg} has rank {rule_rank}, but the operation"
        f" {msg} has rank {op_rank}")

    # The batching dimensions come first, so their factors are added first.
    for j in range(current_batching_rank):
      add_batching_dim_factor(batching_group, j, get_size_for_value_dim(i, j))

    for j, dim in enumerate(value):
      if isinstance(dim, str):
        add_factor(dim, get_size_for_value_dim(i, j + current_batching_rank))
      else:
        # Factors of a compound factor get their user specified size, or
        # UNKNOWN when none was specified.
        for factor in dim:
          add_factor(factor, rule.factor_sizes.get(factor, UNKNOWN))

  # Build the tensor mappings for each operand and result.
  tensor_mappings = []
  for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
    value = tuple(mapping)
    dim_mappings = []
    if value and _is_batching(value[0]):
      batching_group = _get_batching_group(value[0])
      value = value[1:]
      if batching_group in batching_group_to_rank:
        # This type check error is not correct, disable it:
        # Incompatible types in assignment (expression has type "int | None"
        current_batching_rank = batching_group_to_rank.get(batching_group) # type: ignore
      else:
        # Every batching group used in the rule was recorded by the loop above.
        raise ValueError("Unreachabled code")
    else:
      current_batching_rank = 0
      batching_group = None

    # One single-factor mapping per batching dimension.
    for j in range(current_batching_rank):
      # This type check error is not correct, disable it:
      # Argument 1 to "_get_batching_dim_factor_name" has incompatible type "str | None"; expected "str"  [arg-type]
      dim_mappings.append(
        sdy.DimMappingAttr.get(factor_indices=[
          factors_to_indices_sizes[_get_batching_dim_factor_name(batching_group, j)][0]])) # type: ignore

    for j, dim in enumerate(value):
      if isinstance(dim, str):
        dim_mappings.append(
          sdy.DimMappingAttr.get(
            factor_indices=[factors_to_indices_sizes[dim][0]]))
      else:
        dim_mappings.append(
          build_dim_mapping_for_compound_factors(
            i, j + current_batching_rank, dim))

    tensor_mappings.append(
      sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))

  # tensor_mappings holds the operands first, followed by the results.
  return sdy.OpShardingRuleAttr.get(
    factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
    operand_mappings=tensor_mappings[0:len(operand_types)],
    result_mappings=tensor_mappings[len(operand_types):])