[jax:custom_partitioning] Support SdyShardingRule with multiple leading batching dimension groups.

Previously, we allowed the use of an ellipsis (...) in the einsum-like notation
to represent the leading batching dimensions shared by one group of operands and
results. We now allow the ellipsis to be optionally followed by a group number,
such as ...2, to represent leading batching dimensions for multiple groups of
operands and results.

Add tests.
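For illustration, a minimal sketch of the new notation (assuming the parser is
imported from the internal module jax._src.custom_partitioning_sharding_rule,
where this change lives; the exact import path is an implementation detail):

    from jax._src.custom_partitioning_sharding_rule import str_to_sdy_sharding_rule

    # Group 0 batching dimensions lead the first operand and the result, while
    # the second operand carries an independent group 1. A bare ... is
    # shorthand for ...0.
    rule = str_to_sdy_sharding_rule("...0 i j, ...1 j k -> ... i k")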

PiperOrigin-RevId: 718875251
Bixia Zheng 2025-01-23 08:20:04 -08:00 committed by jax authors
parent 6b95ad0a53
commit 0c3de93b79
2 changed files with 80 additions and 29 deletions


@@ -42,6 +42,20 @@ def _check_factor(factor:str):
     if char != "_" and not char.isdigit() and not char.isalpha():
       raise ValueError(f"Unknown character '{char}'")
 
+def _is_batching(factor: str) -> bool:
+  """Checks if a factor is a representation for leading batching dimensions.
+
+  Leading batching dimensions are represented by a factor containing ... and
+  optionally followed by digits, and ... is equivalent to ...0.
+  """
+  if len(factor) < 1 or factor[0] != BATCHING:
+    return False
+  return len(factor) == 1 or factor[1:].isdigit()
+
+def _get_batching_group(factor: str) -> str:
+  """Extracts the batching group from a factor for leading batching dimensions."""
+  return factor[1:] if len(factor) > 1 else "0"
+
 class CompoundFactor(tuple):
   """Describes the factors for a compound factor.
@@ -54,7 +68,7 @@ class CompoundFactor(tuple):
     for factor in factors:
       if not isinstance(factor, str):
         raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}")
-      if factor == BATCHING:
+      if _is_batching(factor):
         raise ValueError("Ellipsis can't be used in a compound factor")
       else:
         _check_factor(factor)
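Given the definitions above, a few illustrative checks of the helper semantics
(a minimal sketch; BATCHING is assumed here to be the single ellipsis character
"…", matching the factor encoding the parser produces):

    BATCHING = "…"  # assumed value of the module-level batching marker

    def _is_batching(factor: str) -> bool:
      # True for "…" alone or "…" followed only by digits.
      return bool(factor) and factor[0] == BATCHING and (
          len(factor) == 1 or factor[1:].isdigit())

    def _get_batching_group(factor: str) -> str:
      # A bare "…" belongs to group "0".
      return factor[1:] if len(factor) > 1 else "0"

    assert _is_batching("…") and _is_batching("…12")
    assert not _is_batching("i") and not _is_batching("…x")
    assert _get_batching_group("…") == "0"
    assert _get_batching_group("…87") == "87"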
@@ -80,7 +94,7 @@ class ArrayMapping(tuple):
             "Each element of ArrayMapping must be a str or CompoundFactor, but"
             f" got {type(d)}")
       if isinstance(d, str):
-        if d == BATCHING:
+        if _is_batching(d):
           if i != 0:
             raise ValueError("Ellipsis can only be used at the beginning of a dimension")
         else:
@@ -141,7 +155,7 @@ class SdyShardingRule:
     return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})"
 
 
-def _get_batching_dim_factor_name(batch_dim_order : int):
+def _get_batching_dim_factor_name(batch_group: str, batch_dim_order : int):
   """Constructs a factor name for a batching dimension.
 
   We expand the leading ... into factors representing the batching dimensions
@@ -149,7 +163,7 @@ def _get_batching_dim_factor_name(batch_dim_order : int):
   reason, we construct a factor name that won't be used by users for the
   batching dimensions.
   """
-  return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}"
+  return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_group}_{batch_dim_order}"
 
 def _parse_values(
     rule: str,
@@ -194,13 +208,26 @@ def _parse_values(
     else:
       current_compound_dim.append(x)
 
-  for char in rule:
+  rule_len = len(rule)
+  rule_index = 0
+  while rule_index < rule_len:
+    char = rule[rule_index]
+    rule_index += 1
     if char == BATCHING:
       if (current_factor is not None or current_compound_dim is not None
           or value):
         raise ValueError(
             "Ellipsis can only be used at the beginning of a dimension")
-      add_factor(BATCHING)
+      if rule_index < rule_len and rule[rule_index].isdigit():
+        batching_group_str = ""
+        while rule_index < rule_len and rule[rule_index].isdigit():
+          batching_group_str += rule[rule_index]
+          rule_index += 1
+        batching_group = str(int(batching_group_str))
+      else:
+        batching_group = "0"
+      add_factor(f"{BATCHING}{batching_group}")
       continue
     if char in "(), ":
      if current_factor is not None:
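The scan above normalizes the group number through str(int(...)), so leading
zeros collapse and ...007 names the same group as ...7. A standalone sketch of
just this tokenizing step, with scan_batching_group as a hypothetical helper
rather than part of the change:

    def scan_batching_group(rule: str, rule_index: int) -> tuple[str, int]:
      """Returns the normalized batching group and the index past its digits."""
      digits = ""
      while rule_index < len(rule) and rule[rule_index].isdigit():
        digits += rule[rule_index]
        rule_index += 1
      return (str(int(digits)) if digits else "0"), rule_index

    # Index 3 is just past the three dots of "...".
    assert scan_batching_group("...007 i j", 3) == ("7", 6)
    assert scan_batching_group("... i j", 3) == ("0", 3)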
@@ -342,9 +369,8 @@ def sdy_sharding_rule_to_mlir(
       factor_index = len(factors_to_indices_sizes)
       factors_to_indices_sizes[factor] = [factor_index, size]
 
-  def add_batching_dim_factor(batch_dim_order, factor_size):
-    ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
-    add_factor(ellipsis_batch_dim_name, factor_size)
+  def add_batching_dim_factor(batch_grp, batch_dim_order, factor_size):
+    add_factor(_get_batching_dim_factor_name(batch_grp, batch_dim_order), factor_size)
 
   def build_dim_mapping_for_compound_factors(i, j, factors):
     accumulated_size = 1
@@ -365,23 +391,25 @@ def sdy_sharding_rule_to_mlir(
   # Add factors and their sizes in the order they appear in the rule,
   # including the batching dimensions represented by ellipsis.
-  ellipsis_rank = None
+  batching_group_to_rank: dict[str, int] = {}
   for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
     value = tuple(mapping)
-    if value and value[0] == BATCHING:
-      has_batching = True
+    if value and _is_batching(value[0]):
+      batching_group = _get_batching_group(value[0])
       value = value[1:]
     else:
-      has_batching = False
+      batching_group = None
     rule_rank = len(value)
     op_rank = get_rank_for_value(i)
     # The number of dimensions represented by ellipsis.
     current_batching_rank = 0
-    if has_batching and op_rank >= rule_rank:
+    if batching_group is not None and op_rank >= rule_rank:
       current_batching_rank = op_rank - rule_rank
-    if has_batching:
+    if batching_group is not None:
+      ellipsis_rank = batching_group_to_rank.get(batching_group, None)
       if ellipsis_rank is None:
         ellipsis_rank = current_batching_rank
+        batching_group_to_rank[batching_group] = ellipsis_rank
       elif ellipsis_rank != current_batching_rank:
         raise ValueError(
             "Ellipsis represents different number of leading dimensions"
@@ -394,7 +422,7 @@ def sdy_sharding_rule_to_mlir(
           f" {msg} has rank {op_rank}")
 
     for j in range(current_batching_rank):
-      add_batching_dim_factor(j, get_size_for_value_dim(i, j))
+      add_batching_dim_factor(batching_group, j, get_size_for_value_dim(i, j))
 
     for j, dim in enumerate(value):
       if isinstance(dim, str):
@@ -408,20 +436,25 @@ def sdy_sharding_rule_to_mlir(
   for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
     value = tuple(mapping)
     dim_mappings = []
-    if value and value[0] == BATCHING:
+    if value and _is_batching(value[0]):
+      batching_group = _get_batching_group(value[0])
       value = value[1:]
-      if ellipsis_rank is None:
-        current_batching_rank = 0
+      if batching_group in batching_group_to_rank:
+        # This type check error is not correct, disable it:
+        # Incompatible types in assignment (expression has type "int | None")
+        current_batching_rank = batching_group_to_rank.get(batching_group)  # type: ignore
       else:
-        current_batching_rank = ellipsis_rank
+        raise ValueError("Unreachable code")
     else:
       current_batching_rank = 0
+      batching_group = None
 
     for j in range(current_batching_rank):
+      # This type check error is not correct, disable it:
+      # Argument 1 to "_get_batching_dim_factor_name" has incompatible type
+      # "str | None"; expected "str" [arg-type]
       dim_mappings.append(
           sdy.DimMappingAttr.get(factor_indices=[
-            factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
+            factors_to_indices_sizes[_get_batching_dim_factor_name(batching_group, j)][0]]))  # type: ignore
 
     for j, dim in enumerate(value):
       if isinstance(dim, str):


@@ -50,8 +50,8 @@ class SdyShardingRuleTest(jtu.JaxTestCase):
       ArrayMapping("i_j", BATCHING)
 
   def test_value_mapping_str(self):
-    v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k")
-    self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')")
+    v = ArrayMapping(f"{BATCHING}2", "m", CompoundFactor("i", "j"), "k")
+    self.assertEqual(str(v), f"('{BATCHING}2', 'm', ('i', 'j'), 'k')")
 
   def test_sdy_sharding_rule_factor_size_not_used(self):
     with self.assertRaisesRegex(ValueError, "Factor k is not used"):
@@ -158,17 +158,18 @@ class StrToSdyShardingRuleTest(jtu.JaxTestCase):
         str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})")
 
   def test_sharding_rule_factor_elementwise_add(self):
-    rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j")
+    # An ellipsis without a number, ..., is treated the same as ...0.
+    rule = str_to_sdy_sharding_rule("...0 i j, ...1 i j -> ...i j")
     self.assertEqual(
         str(rule),
-        "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i',"
+        "SdyShardingRule((('…0', 'i', 'j'), ('…1', 'i', 'j')), (('…0', 'i',"
         " 'j'),), {})")
 
   def test_sharding_rule_factor_vector_scalar_add(self):
-    rule = str_to_sdy_sharding_rule("...i, -> ...i")
+    rule = str_to_sdy_sharding_rule("...87 i, -> ...87 i")
     self.assertEqual(
         str(rule),
-        "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})")
+        "SdyShardingRule((('…87', 'i'), ()), (('…87', 'i'),), {})")
 
   def test_sharding_rule_factor_reshape_combining(self):
     rule = str_to_sdy_sharding_rule("i j -> (i j)")
@@ -316,7 +317,7 @@ class SdyShardingRuleConversionTest(jtu.JaxTestCase):
     rule = str_to_sdy_sharding_rule("..., ... -> ...")
     with self.assertRaisesRegex(
         ValueError,
-        "Batching dimension 1 corresponds to two sizes: 32 and 64"):
+        "Batching dimension 0_1 corresponds to two sizes: 32 and 64"):
       sdy_sharding_rule_to_mlir(rule,
           [result.operands[0].type, result.operands[1].type],
           [result.result.type,],)
@@ -464,5 +465,22 @@ class SdyShardingRuleConversionTest(jtu.JaxTestCase):
         "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>")
 
+  def test_conversion_multiple_batching_groups(self):
+    # opnd0 and the result share batching group 0 (sizes 4, 5); opnd1 carries
+    # its own group 1 (sizes 6, 7, 8).
+    opnd0 = self.create_tensor_value((4, 5, 16, 32))
+    opnd1 = self.create_tensor_value((6, 7, 8, 32, 16))
+    result = ir.Operation.create(
+        "stablehlo.custom_call",
+        results=[self.get_tensor_type((4, 5, 32, 16))],
+        operands=[opnd0, opnd1,],
+        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
+    rule = str_to_sdy_sharding_rule("... j i, ...1 i j -> ...i j")
+    mlir_rule = sdy_sharding_rule_to_mlir(rule,
+        [result.operands[0].type, result.operands[1].type],
+        [result.result.type,])
+    self.assertEqual(
+        str(mlir_rule),
+        "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k])"
+        " {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>")
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())