diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py
index 68294b591..5e2e5f4e0 100644
--- a/jax/_src/custom_partitioning_sharding_rule.py
+++ b/jax/_src/custom_partitioning_sharding_rule.py
@@ -42,6 +42,20 @@ def _check_factor(factor:str):
     if char != "_" and not char.isdigit() and not char.isalpha():
       raise ValueError(f"Unknown character '{char}'")
 
+def _is_batching(factor: str) -> bool:
+  """Checks whether a factor represents leading batching dimensions.
+
+  Leading batching dimensions are represented by a factor containing ...,
+  optionally followed by digits, and ... is equivalent to ...0.
+  """
+  if len(factor) < 1 or factor[0] != BATCHING:
+    return False
+  return len(factor) == 1 or factor[1:].isdigit()
+
+def _get_batching_group(factor: str) -> str:
+  """Extracts the batching group from a factor for leading batching dimensions."""
+  return factor[1:] if len(factor) > 1 else "0"
+
 
 class CompoundFactor(tuple):
   """Describes the factors for a compound factor.
@@ -54,7 +68,7 @@ class CompoundFactor(tuple):
     for factor in factors:
       if not isinstance(factor, str):
         raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}")
-      if factor == BATCHING:
+      if _is_batching(factor):
         raise ValueError("Ellipsis can't be used in a compound factor")
       else:
         _check_factor(factor)
@@ -80,7 +94,7 @@ class ArrayMapping(tuple):
             "Each element of ArrayMapping must be a str or CompoundFactor, but"
             f" got {type(d)}")
       if isinstance(d, str):
-        if d == BATCHING:
+        if _is_batching(d):
           if i != 0:
             raise ValueError("Ellipsis can only be used at the beginning of a dimension")
         else:
@@ -141,7 +155,7 @@ class SdyShardingRule:
     return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})"
 
 
-def _get_batching_dim_factor_name(batch_dim_order : int):
+def _get_batching_dim_factor_name(batch_group: str, batch_dim_order: int):
   """Constructs a factor name for a batching dimension.
 
   We expand the leading ... into factors representing the batching dimensions
@@ -149,7 +163,7 @@ def _get_batching_dim_factor_name(batch_dim_order : int):
   reason, we construct a factor name that won't be used by users for the
   batching dimensions.
""" - return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}" + return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_group}_{batch_dim_order}" def _parse_values( rule: str, @@ -194,13 +208,26 @@ def _parse_values( else: current_compound_dim.append(x) - for char in rule: + rule_len = len(rule) + rule_index = 0 + while rule_index < rule_len: + char = rule[rule_index] + rule_index += 1 if char == BATCHING: if (current_factor is not None or current_compound_dim is not None or value): raise ValueError( "Ellipsis can only be used at the beginning of a dimension") - add_factor(BATCHING) + if rule_index < rule_len and rule[rule_index].isdigit(): + batching_group_str = "" + while rule_index < rule_len and rule[rule_index].isdigit(): + batching_group_str += rule[rule_index] + rule_index += 1 + batching_group = str(int(batching_group_str)) + else: + batching_group = "0" + + add_factor(f"{BATCHING}{batching_group}") continue if char in "(), ": if current_factor is not None: @@ -342,9 +369,8 @@ def sdy_sharding_rule_to_mlir( factor_index = len(factors_to_indices_sizes) factors_to_indices_sizes[factor] = [factor_index, size] - def add_batching_dim_factor(batch_dim_order, factor_size): - ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order) - add_factor(ellipsis_batch_dim_name, factor_size) + def add_batching_dim_factor(batch_grp, batch_dim_order, factor_size): + add_factor(_get_batching_dim_factor_name(batch_grp, batch_dim_order), factor_size) def build_dim_mapping_for_compound_factors(i, j, factors): accumulated_size = 1 @@ -365,23 +391,25 @@ def sdy_sharding_rule_to_mlir( # Add factors and their sizes in the order they appear in the rule, # including the batching dimensions represented by ellipsis. - ellipsis_rank = None + batching_group_to_rank: dict[str, int] = {} for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings): value = tuple(mapping) - if value and value[0] == BATCHING: - has_batching = True + if value and _is_batching(value[0]): + batching_group = _get_batching_group(value[0]) value = value[1:] else: - has_batching = False + batching_group = None rule_rank = len(value) op_rank = get_rank_for_value(i) # The number of dimensions represented by ellipsis. 
     current_batching_rank = 0
-    if has_batching and op_rank >= rule_rank:
+    if batching_group is not None and op_rank >= rule_rank:
       current_batching_rank = op_rank - rule_rank
-    if has_batching:
+    if batching_group is not None:
+      ellipsis_rank = batching_group_to_rank.get(batching_group, None)
       if ellipsis_rank is None:
         ellipsis_rank = current_batching_rank
+        batching_group_to_rank[batching_group] = ellipsis_rank
       elif ellipsis_rank != current_batching_rank:
         raise ValueError(
             "Ellipsis represents different number of leading dimensions"
@@ -394,7 +422,7 @@
          f" {msg} has rank {op_rank}")
 
     for j in range(current_batching_rank):
-      add_batching_dim_factor(j, get_size_for_value_dim(i, j))
+      add_batching_dim_factor(batching_group, j, get_size_for_value_dim(i, j))
 
     for j, dim in enumerate(value):
       if isinstance(dim, str):
@@ -408,20 +436,25 @@
   for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
     value = tuple(mapping)
     dim_mappings = []
-
-    if value and value[0] == BATCHING:
+    if value and _is_batching(value[0]):
+      batching_group = _get_batching_group(value[0])
       value = value[1:]
-      if ellipsis_rank is None:
-        current_batching_rank = 0
+      if batching_group in batching_group_to_rank:
+        # This type check error is a false positive, so disable it:
+        # Incompatible types in assignment (expression has type "int | None"
+        current_batching_rank = batching_group_to_rank.get(batching_group)  # type: ignore
       else:
-        current_batching_rank = ellipsis_rank
+        raise ValueError("Unreachable code")
     else:
      current_batching_rank = 0
+      batching_group = None
 
     for j in range(current_batching_rank):
+      # This type check error is a false positive, so disable it:
+      # Argument 1 to "_get_batching_dim_factor_name" has incompatible type "str | None"; expected "str" [arg-type]
      dim_mappings.append(
        sdy.DimMappingAttr.get(factor_indices=[
-          factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
+          factors_to_indices_sizes[_get_batching_dim_factor_name(batching_group, j)][0]]))  # type: ignore
 
     for j, dim in enumerate(value):
       if isinstance(dim, str):
diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py
index 3aed16510..f22721910 100644
--- a/tests/custom_partitioning_sharding_rule_test.py
+++ b/tests/custom_partitioning_sharding_rule_test.py
@@ -50,8 +50,8 @@ class SdyShardingRuleTest(jtu.JaxTestCase):
       ArrayMapping("i_j", BATCHING)
 
   def test_value_mapping_str(self):
-    v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k")
-    self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')")
+    v = ArrayMapping(f"{BATCHING}2", "m", CompoundFactor("i", "j"), "k")
+    self.assertEqual(str(v), f"('{BATCHING}2', 'm', ('i', 'j'), 'k')")
 
   def test_sdy_sharding_rule_factor_size_not_used(self):
     with self.assertRaisesRegex(ValueError, "Factor k is not used"):
@@ -158,17 +158,18 @@ class StrToSdyShardingRuleTest(jtu.JaxTestCase):
        str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})")
 
   def test_sharding_rule_factor_elementwise_add(self):
-    rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j")
+    # An ellipsis without a number, i.e. ..., is treated the same as ...0.
+    rule = str_to_sdy_sharding_rule("...0 i j, ...1 i j -> ...i j")
     self.assertEqual(
         str(rule),
-        "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i',"
+        "SdyShardingRule((('…0', 'i', 'j'), ('…1', 'i', 'j')), (('…0', 'i',"
         " 'j'),), {})")
 
   def test_sharding_rule_factor_vector_scalar_add(self):
-    rule = str_to_sdy_sharding_rule("...i, -> ...i")
+    rule = str_to_sdy_sharding_rule("...87 i, -> ...87 i")
     self.assertEqual(
         str(rule),
-        "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})")
+        "SdyShardingRule((('…87', 'i'), ()), (('…87', 'i'),), {})")
 
   def test_sharding_rule_factor_reshape_combining(self):
     rule = str_to_sdy_sharding_rule("i j -> (i j)")
@@ -316,7 +317,7 @@ class SdyShardingRuleConversionTest(jtu.JaxTestCase):
     rule = str_to_sdy_sharding_rule("..., ... -> ...")
     with self.assertRaisesRegex(
         ValueError,
-        "Batching dimension 1 corresponds to two sizes: 32 and 64"):
+        "Batching dimension 0_1 corresponds to two sizes: 32 and 64"):
      sdy_sharding_rule_to_mlir(rule,
        [result.operands[0].type, result.operands[1].type],
        [result.result.type,],)
@@ -464,5 +465,22 @@
         "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>")
 
+  def test_conversion_multiple_batching_groups(self):
+    opnd0 = self.create_tensor_value((4, 5, 16, 32))
+    opnd1 = self.create_tensor_value((6, 7, 8, 32, 16))
+    result = ir.Operation.create(
+        "stablehlo.custom_call",
+        results=[self.get_tensor_type((4, 5, 32, 16))],
+        operands=[opnd0, opnd1,],
+        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
+    rule = str_to_sdy_sharding_rule("... j i, ...1 i j -> ...i j")
+    mlir_rule = sdy_sharding_rule_to_mlir(rule,
+        [result.operands[0].type, result.operands[1].type],
+        [result.result.type,])
+    self.assertEqual(
+        str(mlir_rule),
+        "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>")
+
+
 
 if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
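
Usage sketch (illustrative only, mirroring the new tests above; not part of the patch): numbered ellipsis groups such as ...0 and ...1 let different operands carry independently sized leading batching dimensions, while a bare ... keeps its old meaning as group 0. The import path below is the internal module touched by this patch and is assumed here purely for illustration.

    # Hypothetical usage sketch, based on the tests in this diff.
    from jax._src.custom_partitioning_sharding_rule import str_to_sdy_sharding_rule

    # Two operands with independent leading batching groups; the result reuses
    # group 0 (a bare "..." is equivalent to "...0").
    rule = str_to_sdy_sharding_rule("...0 i j, ...1 i j -> ...i j")
    print(rule)
    # SdyShardingRule((('…0', 'i', 'j'), ('…1', 'i', 'j')), (('…0', 'i', 'j'),), {})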