From 2a4a0e8d6fb36b59f9c6f24e0018d42c8c8d8ee9 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 5 Dec 2024 11:32:43 -0800 Subject: [PATCH] [jax:custom_partitioning] Implement SdyShardingRule to support Shardy custom_partitioning. The parsing of the sharding rule string very closely follows how einops parses their rules in einops/parsing.py. When a SdyShardingRule object is constructed, we check the syntax of the Einsum like notation string and its consistency with the user provided factor_sizes, and report errors accordingly. This is done during f.def_partition. When SdyShardingRule.build is called, during JAX to MLIR lowering, we check the consistency between the Einsum like notation string, the factor_sizes and the MLIR operation, and report errors accordingly. PiperOrigin-RevId: 703187962 --- jax/BUILD | 1 + jax/_src/custom_partitioning_sharding_rule.py | 380 +++++++++++++++++ tests/BUILD | 10 + .../custom_partitioning_sharding_rule_test.py | 396 ++++++++++++++++++ 4 files changed, 787 insertions(+) create mode 100644 jax/_src/custom_partitioning_sharding_rule.py create mode 100644 tests/custom_partitioning_sharding_rule_test.py diff --git a/jax/BUILD b/jax/BUILD index 053b05027..31020eb1d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -193,6 +193,7 @@ py_library_providing_imports_info( "_src/custom_batching.py", "_src/custom_derivatives.py", "_src/custom_partitioning.py", + "_src/custom_partitioning_sharding_rule.py", "_src/custom_transpose.py", "_src/debugging.py", "_src/dispatch.py", diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py new file mode 100644 index 000000000..5193c9126 --- /dev/null +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -0,0 +1,380 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements SdyShardingRule.""" + +from collections import OrderedDict + +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import sdy + + +_CompoundFactor = tuple[str, ...] +_DimMapping = tuple[str | _CompoundFactor, ...] + +# A single character replacement for ... to simplify parsing. +_ELLIPSIS: str = "…" + +# A prefix for names of batching dimension factors, used for expanding the +# leading ... into factors. +_BATCHING_DIM_FACTOR_PREFIX = "?" + +def _get_batching_dim_factor_name(batch_dim_order : int): + """Constructs a factor name for a batching dimension. + + We expand the leading ... into factors representing the batching dimensions + to support building the MLIR representation for the sharding rule. For this + reason, we construct a factor name that won't be used by users for the + batching dimensions. + """ + return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}" + +def _parse_values( + rule: str, +) -> tuple[_DimMapping, ...]: + """Parses the LHS or RHS of an Einsum notation like string. + + Converts each operand or result in the Einsum notation like string to a tuple + of _DimMapping. 
This very closely follows how einops parses their rules in + einops/parsing.py. + + Args: + rule: The Einsum notation for the operands or results of an operation. + + Returns: + The tuple of values. + + Raises: + ValueError: If the rule is not balanced or contains unknown characters. + """ + + # Remove unnecessary spaces in the rule to simplify the parsing process. + words = rule.split() + rule = " ".join(words) + + # Similar to einops rules, an empty LHS/RHS has a single scalar value. + if not rule: + return ((),) + + all_values = [] + # Represent all dimensions of an value. When an value[0]==_ELLIPSIS, the + # value may have 0 or more leading dimensions. + value = [] + current_factor = None + # A value of None indicates the current dimension is not a compound dimension, + # while a value of [] indicates that we have just started parsing a compound + # dimension. + current_compound_dim: list[str] | None = None + + def add_factor(x): + if current_compound_dim is None: + value.append(x) + else: + current_compound_dim.append(x) + + for char in rule: + if char == _ELLIPSIS: + if (current_factor is not None or current_compound_dim is not None + or value): + raise ValueError( + "Ellipsis can only be used at the beginning of a dimension") + add_factor(_ELLIPSIS) + continue + if char in "(), ": + if current_factor is not None: + add_factor(current_factor) + current_factor = None + if char == "(": + if current_compound_dim is not None: + raise ValueError( + "Compound factors should be one level, nested brackets are not" + " allowed") + current_compound_dim = [] + elif char == ")": + if current_compound_dim is None: + raise ValueError("Brackets are not balanced") + if len(current_compound_dim) <= 1: + raise ValueError("Brackets should contain at least two factors") + value.append(tuple(current_compound_dim)) + current_compound_dim = None + elif char == ",": + all_values.append(tuple(value)) + value = [] + elif char == "_" or char.isdigit() or char.isalpha(): + if current_factor is None: + if str.isdigit(char): + raise ValueError(f"Factor names have to start with a letter, but got '{char}'") + current_factor = char + else: + current_factor += char + else: + raise ValueError(f"Unknown character '{char}'") + + if current_compound_dim is not None: + raise ValueError(f"Brackets are not balanced in rule: '{rule}'") + if current_factor is not None: + add_factor(current_factor) + all_values.append(tuple(value)) + + return tuple(all_values) + + +class SdyShardingRule: + """A representation for Shardy sharding rule. + + A SdyShardingRule includes an Enisum notation like string and an optional + list of factor sizes. A factor is a name in the Einsum notation. If a factor + is only used in compound factors, its size must be specified. + + SdyShardingRule examples: + + * Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k') + * Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k') + * A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j') + * Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)`, i=4, j=2) + * An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...') + """ + + def __init__(self, rule: str, **factor_sizes): + """Constructs a SdyShardingRule object from the Einsum notation like string. + + This is done by verifying that the input Einsum notation like string and + with optional factor sizes represents a valid sharding rule and converting + it to an internal representation. 
+ + Args: + rule: The Einsum notation like string for an operation. + **factor_sizes: The optional factor sizes. + + Raises: + ValueError: If there is any problem with the rule or factor_sizes. + """ + if not isinstance(rule, str): + raise TypeError(f"rule must be a str, but got {type(rule)}") + if not all(isinstance(size, int) for size in factor_sizes.values()): + raise TypeError( + f"factor_sizes must be a dict of str to int, but got {factor_sizes}") + + # Replace ... with a single char to simplify parsing. + if _ELLIPSIS in rule: + raise ValueError(f"Unknown character '{_ELLIPSIS}'") + if "." in rule: + rule = rule.replace("...", _ELLIPSIS) + if "." in rule: + raise ValueError("Character '.' must be used inside ellipsis '...'") + + try: + operands, results = rule.split("->") + except ValueError as e: + raise ValueError(f"There is no -> in rule: '{rule}'") from e + + self.operands = _parse_values(operands) + self.results = _parse_values(results) + + # Find all factors and mark whether their size can be inferred. + factors_inferrable = dict() + for value in self.operands + self.results: + for dim in value: + if dim == _ELLIPSIS: + continue + if isinstance(dim, str): + factors_inferrable[dim] = True + else: + for factor in dim: + if factor not in factors_inferrable.keys(): + factors_inferrable[factor] = False + + # Check that factors in factor_sizes are used in the rule. + for factor in factor_sizes: + if factor not in factors_inferrable: + raise ValueError( + f"Factor {factor} is not used in the rule, but size is provided") + + # Check that factors that are used for a whole dimension aren't in + # factor_sizes and factors that are never used for a whole dimension are + # in factor_sizes. + for factor, inferrable in factors_inferrable.items(): + if factor not in factor_sizes and not inferrable: + raise ValueError( + f"Factor {factor} is only used in compound factors; must specify" + " its size") + if factor in factor_sizes and inferrable: + raise ValueError( + f"Factor {factor} represents a whole dimension; do not specify its" + " size") + + self.factor_sizes = factor_sizes + + def __str__(self): + return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})" + + def build( + self, + operand_types: list[ir.Type], + result_types: list[ir.Type],) -> ir.Attribute: + """Builds the MLIR representation for the sharding rule. + + This is done by verifying that the rule is consistent with the types of + the operation and converting the Einsum notation like string to + OpShardingRuleAttr. + """ + if len(self.operands) != len(operand_types): + raise ValueError( + f"Sharding rule has {len(self.operands)} operands, but the operation" + f" has {len(operand_types)} operands" + ) + if len(self.results) != len(result_types): + raise ValueError( + f"Sharding rule has {len(self.results)} results, but the operation" + f" has {len(result_types)} results" + ) + + factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict() + types = operand_types + result_types + UNKNOWN = -1 # Representation for unknown factor size or factor index. + + def get_message_for_value(i): + if i >= len(operand_types): + return f"{i - len(operand_types)}th result" + else: + return f"{i}th operand" + + def get_rank_for_value(i): + return ir.ShapedType(types[i]).rank + + def get_size_for_value_dim(i, j): + return ir.ShapedType(types[i]).shape[j] + + def add_factor(factor, size): + """Adds a factor to factors_to_indices_sizes. 
+ + `size` may be a dimensions size, a user specified factor size, or UNKNOWN + if a factor is first used as in a compound factor and then used for a + whole dimension. + """ + factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN]) + if factor_index != UNKNOWN: + # Not the first time seeing the factor. + if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size: + factor_or_batching_dim = ( + f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor + else f"Batching dimension {factor[1:]}") + raise ValueError( + f"{factor_or_batching_dim} corresponds to two sizes:" + f" {factor_size} and {size}") + if size != UNKNOWN and factor_size == UNKNOWN: + factors_to_indices_sizes[factor] = [factor_index, size] + else: + # First time seeing the factor. + factor_index = len(factors_to_indices_sizes) + factors_to_indices_sizes[factor] = [factor_index, size] + + def add_batching_dim_factor(batch_dim_order, factor_size): + ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order) + add_factor(ellipsis_batch_dim_name, factor_size) + + def build_dim_mapping_for_compound_factors(i, j, factors): + accumulated_size = 1 + all_indices = [] + for factor in factors: + factor_index, factor_size = factors_to_indices_sizes[factor] + accumulated_size *= factor_size + all_indices.append(factor_index) + + dim_size = get_size_for_value_dim(i, j) + if accumulated_size != dim_size: + raise ValueError( + f"{get_message_for_value(i)} actual size {dim_size} doesn't match" + f" the size {accumulated_size} derived from the compound factors" + f" {factors}") + + return sdy.DimMappingAttr.get(factor_indices=all_indices) + + # Add factors and their sizes in the order they appear in the rule, + # including the batching dimensions represented by ellipsis. + ellipsis_rank = None + for i, value in enumerate(self.operands + self.results): + if value and value[0] == _ELLIPSIS: + has_ellipsis = True + value = value[1:] + else: + has_ellipsis = False + rule_rank = len(value) + op_rank = get_rank_for_value(i) + # The number of dimensions represented by ellipsis. + current_ellipsis_rank = 0 + if has_ellipsis and op_rank > rule_rank: + current_ellipsis_rank = op_rank - rule_rank + if has_ellipsis: + if ellipsis_rank is None: + ellipsis_rank = current_ellipsis_rank + elif ellipsis_rank != current_ellipsis_rank: + raise ValueError( + "Ellipsis represents different number of leading dimensions" + f" {ellipsis_rank} and {current_ellipsis_rank}") + rule_rank += current_ellipsis_rank + if rule_rank != op_rank: + msg = get_message_for_value(i) + raise ValueError( + f"Sharding rule {msg} has rank {rule_rank}, but the operation" + f" {msg} has rank {op_rank}") + + for j in range(current_ellipsis_rank): + add_batching_dim_factor(j, get_size_for_value_dim(i, j)) + + for j, dim in enumerate(value): + if isinstance(dim, str): + add_factor( + dim, get_size_for_value_dim(i, j + current_ellipsis_rank)) + else: + for factor in dim: + add_factor(factor, self.factor_sizes.get(factor, UNKNOWN)) + + # Build the tensor mappings for each operand and result. 
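+    # Each dimension maps to the indices of the factors it is composed of;
+    # leading dimensions matched by the ellipsis reuse the batching-dimension
+    # factors registered above.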
+ tensor_mappings = [] + for i, value in enumerate(self.operands + self.results): + dim_mappings = [] + + if value and value[0] == _ELLIPSIS: + value = value[1:] + if ellipsis_rank is None: + current_ellipsis_rank = 0 + else: + current_ellipsis_rank = ellipsis_rank + else: + current_ellipsis_rank = 0 + + for j in range(current_ellipsis_rank): + dim_mappings.append( + sdy.DimMappingAttr.get(factor_indices=[ + factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]])) + + for j, dim in enumerate(value): + if isinstance(dim, str): + dim_mappings.append( + sdy.DimMappingAttr.get( + factor_indices=[factors_to_indices_sizes[dim][0]])) + else: + dim_mappings.append( + build_dim_mapping_for_compound_factors( + i, j + current_ellipsis_rank, dim)) + + tensor_mappings.append( + sdy.TensorMappingAttr.get(dim_mappings=dim_mappings)) + + op_sharding_rule = sdy.OpShardingRuleAttr.get( + factor_sizes=[item[1] for item in factors_to_indices_sizes.values()], + operand_mappings=tensor_mappings[0:len(operand_types)], + result_mappings=tensor_mappings[len(operand_types):]) + return op_sharding_rule diff --git a/tests/BUILD b/tests/BUILD index ce887181b..f80f17e54 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1557,6 +1557,16 @@ jax_multiplatform_test( tags = ["multiaccelerator"], ) +jax_py_test( + name = "custom_partitioning_sharding_rule_test", + srcs = ["custom_partitioning_sharding_rule_test.py"], + deps = [ + "//jax", + "//jax:experimental", + "//jax:test_util", + ], +) + exports_files( [ "api_test.py", diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py new file mode 100644 index 000000000..2aac4e048 --- /dev/null +++ b/tests/custom_partitioning_sharding_rule_test.py @@ -0,0 +1,396 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from jax._src import test_util as jtu +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import sdy +from jax._src.custom_partitioning_sharding_rule import SdyShardingRule +from jax._src.lib.mlir.dialects import hlo as stablehlo + + +class SdyShardingRuleTest(jtu.JaxTestCase): + + def test_rule_is_not_a_str(self): + with self.assertRaisesRegex(TypeError, "rule must be a str"): + SdyShardingRule(1) + + def test_factor_sizes_is_not_a_proper_dict(self): + with self.assertRaisesRegex( + TypeError, "factor_sizes must be a dict of str to int"): + SdyShardingRule("i->j", i="j") + + def test_sharding_rule_ellipsis_not_complete(self): + with self.assertRaisesRegex( + ValueError, "Character '.' 
must be used inside ellipsis '...'"): + SdyShardingRule(".i -> j") + + def test_sharding_rule_invalid_factor_name(self): + with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"): + SdyShardingRule("2i -> j") + + def test_sharding_rule_missing_results(self): + with self.assertRaisesRegex(ValueError, "There is no -> in rule"): + SdyShardingRule("i") + + def test_sharding_rule_inbalenced_brackets(self): + with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): + SdyShardingRule("i j, k)->j") + + def test_sharding_rule_inbalenced_brackets2(self): + with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): + SdyShardingRule("i (j k->j") + + def test_sharding_rule_empty_compound_dim(self): + with self.assertRaisesRegex( + ValueError, "Brackets should contain at least two factors"): + SdyShardingRule("i ( ) j k->j") + + def test_sharding_rule_one_factorcompound_dim(self): + with self.assertRaisesRegex( + ValueError, "Brackets should contain at least two factors"): + SdyShardingRule("i (j ) k->j") + + def test_sharding_rule_nested_brackets(self): + with self.assertRaisesRegex( + ValueError, "Compound factors should be one level"): + SdyShardingRule("i (j (k))->j") + + def test_sharding_rule_unknown_char(self): + with self.assertRaisesRegex(ValueError, "Unknown character"): + SdyShardingRule("i; j->j") + + def test_sharding_rule_unknown_single_char_ellipse(self): + with self.assertRaisesRegex(ValueError, "Unknown character"): + SdyShardingRule("…j->…j") + + def test_sharding_rule_ellipsis_not_leading_dim(self): + with self.assertRaisesRegex( + ValueError, "Ellipsis can only be used at the beginning of a dimension"): + SdyShardingRule("i ... -> j") + + def test_sharding_rule_ellipsis_inside_compound_dim(self): + with self.assertRaisesRegex( + ValueError, "Ellipsis can only be used at the beginning of a dimension"): + SdyShardingRule("i, (..., j) -> j") + + def test_sharding_rule_scalar_operand_scalar_result(self): + rule = SdyShardingRule("->") + self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})") + + def test_sharding_rule_one_scalar_operand(self): + rule = SdyShardingRule("i j, , k->j") + self.assertEqual( + str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})") + + def test_sharding_rule_factor_size_not_used(self): + with self.assertRaisesRegex(ValueError, "Factor k is not used"): + SdyShardingRule("i->j", k=10) + + def test_sharding_rule_factor_size_not_necessary(self): + with self.assertRaisesRegex( + ValueError, + "Factor i represents a whole dimension; do not specify its size"): + SdyShardingRule("i->j", i=10) + + def test_sharding_rule_compound_factor_size_not_necessary(self): + with self.assertRaisesRegex( + ValueError, + "Factor i represents a whole dimension; do not specify its size"): + SdyShardingRule("(i j) -> i", i=10, j=20) + + def test_sharding_rule_factor_sizes_missing(self): + with self.assertRaisesRegex( + ValueError, + "Factor k is only used in compound factors; must specify its size"): + SdyShardingRule("i j -> (j k)") + + def test_sharding_rule_factor_elementwise_add(self): + rule = SdyShardingRule("... 
i j, ...i j -> ...i j") + self.assertEqual( + str(rule), + "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i'," + " 'j'),), {})") + + def test_sharding_rule_factor_vector_scalar_add(self): + rule = SdyShardingRule("...i, -> ...i") + self.assertEqual( + str(rule), + "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})") + + def test_sharding_rule_factor_reshape_combining(self): + rule = SdyShardingRule("i j -> (i j)") + self.assertEqual( + str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})") + + def test_sharding_rule_factor_reshape_reordering(self): + rule = SdyShardingRule("(j i) -> (i j)", i=10, j=20) + self.assertEqual( + str(rule), + "SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':" + " 20})") + + def test_sharding_rule_factor_compound_then_individual(self): + rule = SdyShardingRule("(i j) (j k) i -> j k") + self.assertEqual( + str(rule), + "SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})") + + def test_sharding_rule_factor_individual_then_compound(self): + rule = SdyShardingRule("i j k -> (i j) (j k)") + self.assertEqual( + str(rule), + "SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})") + + def test_sharding_rule_factor_infer_k(self): + rule = SdyShardingRule("_i (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20) + self.assertEqual( + str(rule), + "SdyShardingRule((('_i', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" + ",), {'k': 10, 'm': 10, 'bar_24': 20})") + + +class SdyShardingRuleConversionTest(jtu.JaxTestCase): + + def run(self, result=None): + with ir.Context() as ctx, ir.Location.unknown(ctx): + sdy.register_dialect(ctx) + stablehlo.register_dialect(ctx) + module = ir.Module.create() + with ir.InsertionPoint(module.body): + super().run(result) + + def get_tensor_type(self, shape): + return ir.RankedTensorType.get(shape, ir.F32Type.get()) + + def create_tensor_value(self, shape): + return ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type(shape)], + attributes=dict(call_target_name=ir.StringAttr.get("dummy_target")) + ).result + + def test_conversion_rule_op_mismatch_in_operands_num(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("i j-> i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule has 1 operands, but the operation has 2 operands"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_operands_rank(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("i j, i j k-> i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 1th operand has rank 3, but the operation 1th " + "operand has rank 2"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_results_num(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 
32))], + operands=[opnd0, + opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("i j, i j -> i j, i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule has 2 results, but the operation has 1 results"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_results_dim(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("i j, i j -> i j k") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 0th result has rank 3, but the operation 0th " + "result has rank 2"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_factor_has_two_sizes(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 64))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("i j, i j -> i j") + with self.assertRaisesRegex( + ValueError, + "Factor j corresponds to two sizes: 32 and 64"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_batching_dim_has_two_sizes(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 64))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("..., ... -> ...") + with self.assertRaisesRegex( + ValueError, + "Batching dimension 1 corresponds to two sizes: 32 and 64"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,],) + + def test_conversion_compound_dimension_size_mismatch(self): + opnd = self.create_tensor_value((2, 4)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((9,))], + operands=[opnd,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("i j -> (i j)") + with self.assertRaisesRegex( + ValueError, + "0th result actual size 9 doesn't match the size 8 derived from the" + " compound factors"): + rule.build( + [result.operands[0].type], + [result.result.type,]) + + def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16,)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("..., ... 
-> ...") + with self.assertRaisesRegex( + ValueError, + "Ellipsis represents different number of leading dimensions 2 and 1"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_elementwise_rule_scalar_instance(self): + opnd0 = self.create_tensor_value(()) + opnd1 = self.create_tensor_value(()) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type(())], + operands=[opnd0, opnd1], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("..., ... -> ...") + mlir_rule = rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([], [])->([])>") + + def test_conversion_elementwise_rule_2D_instance(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("..., ... -> ...") + mlir_rule = rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>") + + def test_conversion_vector_scalar_add_2D_instance(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value(()) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("..., -> ...") + mlir_rule = rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>") + + def test_conversion_reshape_rule(self): + opnd0 = self.create_tensor_value((2, 4)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((8,))], + operands=[opnd0,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("i j -> (i j)") + mlir_rule = rule.build( + [result.operands[0].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>") + + def test_conversion_contracting_dim_matmul(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((32, 8)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 8))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("... contracting_dim, contracting_dim k -> ... k") + mlir_rule = rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>") + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())
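
A minimal usage sketch of the API added by this change, mirroring test_conversion_contracting_dim_matmul above: it constructs a rule for a contracting-dimension matmul and builds the corresponding OpShardingRuleAttr directly from MLIR tensor types. The MLIR context setup follows the test fixture; wiring the rule into custom_partitioning's def_partition (mentioned in the commit message) is not shown here.

from jax._src.custom_partitioning_sharding_rule import SdyShardingRule
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy

# "..." stands for any number of leading batching dimensions; "contracting_dim"
# and "k" are ordinary factor names.
rule = SdyShardingRule("... contracting_dim, contracting_dim k -> ... k")

with ir.Context() as ctx, ir.Location.unknown(ctx):
  sdy.register_dialect(ctx)
  f32 = ir.F32Type.get()
  lhs = ir.RankedTensorType.get((16, 32), f32)
  rhs = ir.RankedTensorType.get((32, 8), f32)
  out = ir.RankedTensorType.get((16, 8), f32)
  # Prints #sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>
  print(rule.build([lhs, rhs], [out]))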