mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[jax:custom_partitioning] Support SdyShardingRule with multiple leading
batching dimension groups. Previously, we allow the use of ellipsis ... in the Einsum like notation to represent leading batching dimensions in one group of operands and results. We now allow the use of ellipsis optionally followed by a single digit, such as ...2, to represent leading batching dimensions for multiple groups of operands and results. Add tests. PiperOrigin-RevId: 718875251
This commit is contained in:
parent
6b95ad0a53
commit
0c3de93b79
@ -42,6 +42,20 @@ def _check_factor(factor:str):
|
||||
if char != "_" and not char.isdigit() and not char.isalpha():
|
||||
raise ValueError(f"Unknown character '{char}'")
|
||||
|
||||
def _is_batching(factor: str) -> bool:
|
||||
"""Checks if a factor is a representation for leading batching dimensions.
|
||||
|
||||
Leading batching dimensions is represented by a factor containing ... and
|
||||
optionally followed by a digit, and ... is equivalent to ...0.
|
||||
"""
|
||||
if len(factor) < 1 or factor[0] != BATCHING:
|
||||
return False
|
||||
return len(factor) == 1 or factor[1:].isdigit()
|
||||
|
||||
def _get_batching_group(factor: str) -> str:
|
||||
"""Extracts the batching group from a factor for leading batching dimensions."""
|
||||
return factor[1:] if len(factor) > 1 else "0"
|
||||
|
||||
class CompoundFactor(tuple):
|
||||
"""Describes the factors for a compound factor.
|
||||
|
||||
@ -54,7 +68,7 @@ class CompoundFactor(tuple):
|
||||
for factor in factors:
|
||||
if not isinstance(factor, str):
|
||||
raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}")
|
||||
if factor == BATCHING:
|
||||
if _is_batching(factor):
|
||||
raise ValueError("Ellipsis can't be used in a compound factor")
|
||||
else:
|
||||
_check_factor(factor)
|
||||
@ -80,7 +94,7 @@ class ArrayMapping(tuple):
|
||||
"Each element of ArrayMapping must be a str or CompoundFactor, but"
|
||||
f" got {type(d)}")
|
||||
if isinstance(d, str):
|
||||
if d == BATCHING:
|
||||
if _is_batching(d):
|
||||
if i != 0:
|
||||
raise ValueError("Ellipsis can only be used at the beginning of a dimension")
|
||||
else:
|
||||
@ -141,7 +155,7 @@ class SdyShardingRule:
|
||||
return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})"
|
||||
|
||||
|
||||
def _get_batching_dim_factor_name(batch_dim_order : int):
|
||||
def _get_batching_dim_factor_name(batch_group: str,batch_dim_order : int):
|
||||
"""Constructs a factor name for a batching dimension.
|
||||
|
||||
We expand the leading ... into factors representing the batching dimensions
|
||||
@ -149,7 +163,7 @@ def _get_batching_dim_factor_name(batch_dim_order : int):
|
||||
reason, we construct a factor name that won't be used by users for the
|
||||
batching dimensions.
|
||||
"""
|
||||
return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}"
|
||||
return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_group}_{batch_dim_order}"
|
||||
|
||||
def _parse_values(
|
||||
rule: str,
|
||||
@ -194,13 +208,26 @@ def _parse_values(
|
||||
else:
|
||||
current_compound_dim.append(x)
|
||||
|
||||
for char in rule:
|
||||
rule_len = len(rule)
|
||||
rule_index = 0
|
||||
while rule_index < rule_len:
|
||||
char = rule[rule_index]
|
||||
rule_index += 1
|
||||
if char == BATCHING:
|
||||
if (current_factor is not None or current_compound_dim is not None
|
||||
or value):
|
||||
raise ValueError(
|
||||
"Ellipsis can only be used at the beginning of a dimension")
|
||||
add_factor(BATCHING)
|
||||
if rule_index < rule_len and rule[rule_index].isdigit():
|
||||
batching_group_str = ""
|
||||
while rule_index < rule_len and rule[rule_index].isdigit():
|
||||
batching_group_str += rule[rule_index]
|
||||
rule_index += 1
|
||||
batching_group = str(int(batching_group_str))
|
||||
else:
|
||||
batching_group = "0"
|
||||
|
||||
add_factor(f"{BATCHING}{batching_group}")
|
||||
continue
|
||||
if char in "(), ":
|
||||
if current_factor is not None:
|
||||
@ -342,9 +369,8 @@ def sdy_sharding_rule_to_mlir(
|
||||
factor_index = len(factors_to_indices_sizes)
|
||||
factors_to_indices_sizes[factor] = [factor_index, size]
|
||||
|
||||
def add_batching_dim_factor(batch_dim_order, factor_size):
|
||||
ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
|
||||
add_factor(ellipsis_batch_dim_name, factor_size)
|
||||
def add_batching_dim_factor(batch_grp, batch_dim_order, factor_size):
|
||||
add_factor(_get_batching_dim_factor_name(batch_grp, batch_dim_order), factor_size)
|
||||
|
||||
def build_dim_mapping_for_compound_factors(i, j, factors):
|
||||
accumulated_size = 1
|
||||
@ -365,23 +391,25 @@ def sdy_sharding_rule_to_mlir(
|
||||
|
||||
# Add factors and their sizes in the order they appear in the rule,
|
||||
# including the batching dimensions represented by ellipsis.
|
||||
ellipsis_rank = None
|
||||
batching_group_to_rank: dict[str, int] = {}
|
||||
for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
|
||||
value = tuple(mapping)
|
||||
if value and value[0] == BATCHING:
|
||||
has_batching = True
|
||||
if value and _is_batching(value[0]):
|
||||
batching_group = _get_batching_group(value[0])
|
||||
value = value[1:]
|
||||
else:
|
||||
has_batching = False
|
||||
batching_group = None
|
||||
rule_rank = len(value)
|
||||
op_rank = get_rank_for_value(i)
|
||||
# The number of dimensions represented by ellipsis.
|
||||
current_batching_rank = 0
|
||||
if has_batching and op_rank >= rule_rank:
|
||||
if batching_group is not None and op_rank >= rule_rank:
|
||||
current_batching_rank = op_rank - rule_rank
|
||||
if has_batching:
|
||||
if batching_group is not None:
|
||||
ellipsis_rank = batching_group_to_rank.get(batching_group, None)
|
||||
if ellipsis_rank is None:
|
||||
ellipsis_rank = current_batching_rank
|
||||
batching_group_to_rank[batching_group] = ellipsis_rank
|
||||
elif ellipsis_rank != current_batching_rank:
|
||||
raise ValueError(
|
||||
"Ellipsis represents different number of leading dimensions"
|
||||
@ -394,7 +422,7 @@ def sdy_sharding_rule_to_mlir(
|
||||
f" {msg} has rank {op_rank}")
|
||||
|
||||
for j in range(current_batching_rank):
|
||||
add_batching_dim_factor(j, get_size_for_value_dim(i, j))
|
||||
add_batching_dim_factor(batching_group, j, get_size_for_value_dim(i, j))
|
||||
|
||||
for j, dim in enumerate(value):
|
||||
if isinstance(dim, str):
|
||||
@ -408,20 +436,25 @@ def sdy_sharding_rule_to_mlir(
|
||||
for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
|
||||
value = tuple(mapping)
|
||||
dim_mappings = []
|
||||
|
||||
if value and value[0] == BATCHING:
|
||||
if value and _is_batching(value[0]):
|
||||
batching_group = _get_batching_group(value[0])
|
||||
value = value[1:]
|
||||
if ellipsis_rank is None:
|
||||
current_batching_rank = 0
|
||||
if batching_group in batching_group_to_rank:
|
||||
# This type check error is not correct, disable it:
|
||||
# Incompatible types in assignment (expression has type "int | None"
|
||||
current_batching_rank = batching_group_to_rank.get(batching_group) # type: ignore
|
||||
else:
|
||||
current_batching_rank = ellipsis_rank
|
||||
raise ValueError("Unreachabled code")
|
||||
else:
|
||||
current_batching_rank = 0
|
||||
batching_group = None
|
||||
|
||||
for j in range(current_batching_rank):
|
||||
# This type check error is not correct, disable it:
|
||||
# Argument 1 to "_get_batching_dim_factor_name" has incompatible type "str | None"; expected "str" [arg-type]
|
||||
dim_mappings.append(
|
||||
sdy.DimMappingAttr.get(factor_indices=[
|
||||
factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
|
||||
factors_to_indices_sizes[_get_batching_dim_factor_name(batching_group, j)][0]])) # type: ignore
|
||||
|
||||
for j, dim in enumerate(value):
|
||||
if isinstance(dim, str):
|
||||
|
@ -50,8 +50,8 @@ class SdyShardingRuleTest(jtu.JaxTestCase):
|
||||
ArrayMapping("i_j", BATCHING)
|
||||
|
||||
def test_value_mapping_str(self):
|
||||
v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k")
|
||||
self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')")
|
||||
v = ArrayMapping(f"{BATCHING}2", "m", CompoundFactor("i", "j"), "k")
|
||||
self.assertEqual(str(v), f"('{BATCHING}2', 'm', ('i', 'j'), 'k')")
|
||||
|
||||
def test_sdy_sharding_rule_factor_size_not_used(self):
|
||||
with self.assertRaisesRegex(ValueError, "Factor k is not used"):
|
||||
@ -158,17 +158,18 @@ class StrToSdyShardingRuleTest(jtu.JaxTestCase):
|
||||
str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})")
|
||||
|
||||
def test_sharding_rule_factor_elementwise_add(self):
|
||||
rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j")
|
||||
# An ellipsis without a number ... is treated as the same as ...0.
|
||||
rule = str_to_sdy_sharding_rule("...0 i j, ...1 i j -> ...i j")
|
||||
self.assertEqual(
|
||||
str(rule),
|
||||
"SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i',"
|
||||
"SdyShardingRule((('…0', 'i', 'j'), ('…1', 'i', 'j')), (('…0', 'i',"
|
||||
" 'j'),), {})")
|
||||
|
||||
def test_sharding_rule_factor_vector_scalar_add(self):
|
||||
rule = str_to_sdy_sharding_rule("...i, -> ...i")
|
||||
rule = str_to_sdy_sharding_rule("...87 i, -> ...87 i")
|
||||
self.assertEqual(
|
||||
str(rule),
|
||||
"SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})")
|
||||
"SdyShardingRule((('…87', 'i'), ()), (('…87', 'i'),), {})")
|
||||
|
||||
def test_sharding_rule_factor_reshape_combining(self):
|
||||
rule = str_to_sdy_sharding_rule("i j -> (i j)")
|
||||
@ -316,7 +317,7 @@ class SdyShardingRuleConversionTest(jtu.JaxTestCase):
|
||||
rule = str_to_sdy_sharding_rule("..., ... -> ...")
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Batching dimension 1 corresponds to two sizes: 32 and 64"):
|
||||
"Batching dimension 0_1 corresponds to two sizes: 32 and 64"):
|
||||
sdy_sharding_rule_to_mlir(rule,
|
||||
[result.operands[0].type, result.operands[1].type],
|
||||
[result.result.type,],)
|
||||
@ -464,5 +465,22 @@ class SdyShardingRuleConversionTest(jtu.JaxTestCase):
|
||||
"#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>")
|
||||
|
||||
|
||||
def test_conversion_multiple_batching_groups(self):
|
||||
opnd0 = self.create_tensor_value((4, 5, 16, 32))
|
||||
opnd1 = self.create_tensor_value((6, 7, 8, 32, 16))
|
||||
result = ir.Operation.create(
|
||||
"stablehlo.custom_call",
|
||||
results=[self.get_tensor_type((4, 5, 32, 16))],
|
||||
operands=[opnd0, opnd1,],
|
||||
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
|
||||
rule = str_to_sdy_sharding_rule("... j i, ...1 i j -> ...i j")
|
||||
mlir_rule = sdy_sharding_rule_to_mlir(rule,
|
||||
[result.operands[0].type, result.operands[1].type],
|
||||
[result.result.type,])
|
||||
self.assertEqual(
|
||||
str(mlir_rule),
|
||||
"#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user