[jax:custom_partitioning] Implement SdyShardingRule to support Shardy custom_partitioning.

The parsing of the sharding rule string very closely follows how einops
parses its rules in einops/parsing.py.

When a SdyShardingRule object is constructed, we check the syntax of the
Einsum-like notation string and its consistency with the user-provided
factor_sizes, and report errors accordingly. This is done during f.def_partition.

When SdyShardingRule.build is called during JAX-to-MLIR lowering, we check
the consistency between the Einsum-like notation string, the factor_sizes,
and the MLIR operation, and report errors accordingly.
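
For illustration, a minimal sketch of constructing such rules (the rule
strings are the ones exercised by the tests added below; the def_partition
integration is not shown):

    from jax._src.custom_partitioning_sharding_rule import SdyShardingRule

    # Contracting-dim matmul with leading batch dimensions: the shared factor
    # "contracting_dim" ties the contracting dimensions of the two operands.
    matmul_rule = SdyShardingRule(
        "... contracting_dim, contracting_dim k -> ... k")

    # Reshape that reorders factors: "i" and "j" only appear inside compound
    # factors, so their sizes must be given explicitly.
    reshape_rule = SdyShardingRule("(j i) -> (i j)", i=10, j=20)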

PiperOrigin-RevId: 703187962
Bixia Zheng 2024-12-05 11:32:43 -08:00 committed by jax authors
parent f73fa7a7ad
commit 2a4a0e8d6f
4 changed files with 787 additions and 0 deletions


@@ -193,6 +193,7 @@ py_library_providing_imports_info(
"_src/custom_batching.py",
"_src/custom_derivatives.py",
"_src/custom_partitioning.py",
"_src/custom_partitioning_sharding_rule.py",
"_src/custom_transpose.py",
"_src/debugging.py",
"_src/dispatch.py",


@@ -0,0 +1,380 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements SdyShardingRule."""
from collections import OrderedDict
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
_CompoundFactor = tuple[str, ...]
_DimMapping = tuple[str | _CompoundFactor, ...]
# A single character replacement for ... to simplify parsing.
_ELLIPSIS: str = "…"
# A prefix for names of batching dimension factors, used for expanding the
# leading ... into factors.
_BATCHING_DIM_FACTOR_PREFIX = "?"
def _get_batching_dim_factor_name(batch_dim_order : int):
"""Constructs a factor name for a batching dimension.
We expand the leading ... into factors representing the batching dimensions
to support building the MLIR representation for the sharding rule. For this
reason, we construct factor names for the batching dimensions that users
cannot use for their own factors.
"""
return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}"
def _parse_values(
rule: str,
) -> tuple[_DimMapping, ...]:
"""Parses the LHS or RHS of an Einsum notation like string.
Converts each operand or result in the Einsum notation like string to a tuple
of _DimMapping. This very closely follows how einops parses their rules in
einops/parsing.py.
Args:
rule: The Einsum notation for the operands or results of an operation.
Returns:
The tuple of values.
Raises:
ValueError: If the rule is not balanced or contains unknown characters.
"""
# Remove unnecessary spaces in the rule to simplify the parsing process.
words = rule.split()
rule = " ".join(words)
# Similar to einops rules, an empty LHS/RHS has a single scalar value.
if not rule:
return ((),)
all_values = []
# Represents all dimensions of a value. When value[0] == _ELLIPSIS, the
# value may have 0 or more leading dimensions.
value = []
current_factor = None
# A value of None indicates the current dimension is not a compound dimension,
# while a value of [] indicates that we have just started parsing a compound
# dimension.
current_compound_dim: list[str] | None = None
def add_factor(x):
if current_compound_dim is None:
value.append(x)
else:
current_compound_dim.append(x)
for char in rule:
if char == _ELLIPSIS:
if (current_factor is not None or current_compound_dim is not None
or value):
raise ValueError(
"Ellipsis can only be used at the beginning of a dimension")
add_factor(_ELLIPSIS)
continue
if char in "(), ":
if current_factor is not None:
add_factor(current_factor)
current_factor = None
if char == "(":
if current_compound_dim is not None:
raise ValueError(
"Compound factors should be one level, nested brackets are not"
" allowed")
current_compound_dim = []
elif char == ")":
if current_compound_dim is None:
raise ValueError("Brackets are not balanced")
if len(current_compound_dim) <= 1:
raise ValueError("Brackets should contain at least two factors")
value.append(tuple(current_compound_dim))
current_compound_dim = None
elif char == ",":
all_values.append(tuple(value))
value = []
elif char == "_" or char.isdigit() or char.isalpha():
if current_factor is None:
if str.isdigit(char):
raise ValueError(f"Factor names have to start with a letter, but got '{char}'")
current_factor = char
else:
current_factor += char
else:
raise ValueError(f"Unknown character '{char}'")
if current_compound_dim is not None:
raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
if current_factor is not None:
add_factor(current_factor)
all_values.append(tuple(value))
return tuple(all_values)
class SdyShardingRule:
"""A representation for Shardy sharding rule.
A SdyShardingRule includes an Enisum notation like string and an optional
list of factor sizes. A factor is a name in the Einsum notation. If a factor
is only used in compound factors, its size must be specified.
SdyShardingRule examples:
* Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k')
* Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k')
* A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j')
* Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)', i=4, j=2)
* An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...')
"""
def __init__(self, rule: str, **factor_sizes):
"""Constructs a SdyShardingRule object from the Einsum notation like string.
This is done by verifying that the input Einsum notation like string and
with optional factor sizes represents a valid sharding rule and converting
it to an internal representation.
Args:
rule: The Einsum notation like string for an operation.
**factor_sizes: The optional factor sizes.
Raises:
ValueError: If there is any problem with the rule or factor_sizes.
"""
if not isinstance(rule, str):
raise TypeError(f"rule must be a str, but got {type(rule)}")
if not all(isinstance(size, int) for size in factor_sizes.values()):
raise TypeError(
f"factor_sizes must be a dict of str to int, but got {factor_sizes}")
# Replace ... with a single char to simplify parsing.
if _ELLIPSIS in rule:
raise ValueError(f"Unknown character '{_ELLIPSIS}'")
if "." in rule:
rule = rule.replace("...", _ELLIPSIS)
if "." in rule:
raise ValueError("Character '.' must be used inside ellipsis '...'")
try:
operands, results = rule.split("->")
except ValueError as e:
raise ValueError(f"There is no -> in rule: '{rule}'") from e
self.operands = _parse_values(operands)
self.results = _parse_values(results)
# Find all factors and mark whether their size can be inferred.
factors_inferrable = dict()
for value in self.operands + self.results:
for dim in value:
if dim == _ELLIPSIS:
continue
if isinstance(dim, str):
factors_inferrable[dim] = True
else:
for factor in dim:
if factor not in factors_inferrable.keys():
factors_inferrable[factor] = False
# Check that factors in factor_sizes are used in the rule.
for factor in factor_sizes:
if factor not in factors_inferrable:
raise ValueError(
f"Factor {factor} is not used in the rule, but size is provided")
# Check that factors that are used for a whole dimension aren't in
# factor_sizes and factors that are never used for a whole dimension are
# in factor_sizes.
for factor, inferrable in factors_inferrable.items():
if factor not in factor_sizes and not inferrable:
raise ValueError(
f"Factor {factor} is only used in compound factors; must specify"
" its size")
if factor in factor_sizes and inferrable:
raise ValueError(
f"Factor {factor} represents a whole dimension; do not specify its"
" size")
self.factor_sizes = factor_sizes
def __str__(self):
return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})"
def build(
self,
operand_types: list[ir.Type],
result_types: list[ir.Type],) -> ir.Attribute:
"""Builds the MLIR representation for the sharding rule.
This is done by verifying that the rule is consistent with the types of
the operation and converting the Einsum notation like string to
OpShardingRuleAttr.
"""
if len(self.operands) != len(operand_types):
raise ValueError(
f"Sharding rule has {len(self.operands)} operands, but the operation"
f" has {len(operand_types)} operands"
)
if len(self.results) != len(result_types):
raise ValueError(
f"Sharding rule has {len(self.results)} results, but the operation"
f" has {len(result_types)} results"
)
factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
types = operand_types + result_types
UNKNOWN = -1 # Representation for unknown factor size or factor index.
def get_message_for_value(i):
if i >= len(operand_types):
return f"{i - len(operand_types)}th result"
else:
return f"{i}th operand"
def get_rank_for_value(i):
return ir.ShapedType(types[i]).rank
def get_size_for_value_dim(i, j):
return ir.ShapedType(types[i]).shape[j]
def add_factor(factor, size):
"""Adds a factor to factors_to_indices_sizes.
`size` may be a dimension size, a user-specified factor size, or UNKNOWN
if a factor is first used in a compound factor and then used for a
whole dimension.
"""
factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
if factor_index != UNKNOWN:
# Not the first time seeing the factor.
if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
factor_or_batching_dim = (
f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
else f"Batching dimension {factor[1:]}")
raise ValueError(
f"{factor_or_batching_dim} corresponds to two sizes:"
f" {factor_size} and {size}")
if size != UNKNOWN and factor_size == UNKNOWN:
factors_to_indices_sizes[factor] = [factor_index, size]
else:
# First time seeing the factor.
factor_index = len(factors_to_indices_sizes)
factors_to_indices_sizes[factor] = [factor_index, size]
def add_batching_dim_factor(batch_dim_order, factor_size):
ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
add_factor(ellipsis_batch_dim_name, factor_size)
def build_dim_mapping_for_compound_factors(i, j, factors):
accumulated_size = 1
all_indices = []
for factor in factors:
factor_index, factor_size = factors_to_indices_sizes[factor]
accumulated_size *= factor_size
all_indices.append(factor_index)
dim_size = get_size_for_value_dim(i, j)
if accumulated_size != dim_size:
raise ValueError(
f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
f" the size {accumulated_size} derived from the compound factors"
f" {factors}")
return sdy.DimMappingAttr.get(factor_indices=all_indices)
# Add factors and their sizes in the order they appear in the rule,
# including the batching dimensions represented by ellipsis.
ellipsis_rank = None
for i, value in enumerate(self.operands + self.results):
if value and value[0] == _ELLIPSIS:
has_ellipsis = True
value = value[1:]
else:
has_ellipsis = False
rule_rank = len(value)
op_rank = get_rank_for_value(i)
# The number of dimensions represented by ellipsis.
current_ellipsis_rank = 0
if has_ellipsis and op_rank > rule_rank:
current_ellipsis_rank = op_rank - rule_rank
if has_ellipsis:
if ellipsis_rank is None:
ellipsis_rank = current_ellipsis_rank
elif ellipsis_rank != current_ellipsis_rank:
raise ValueError(
"Ellipsis represents different number of leading dimensions"
f" {ellipsis_rank} and {current_ellipsis_rank}")
rule_rank += current_ellipsis_rank
if rule_rank != op_rank:
msg = get_message_for_value(i)
raise ValueError(
f"Sharding rule {msg} has rank {rule_rank}, but the operation"
f" {msg} has rank {op_rank}")
for j in range(current_ellipsis_rank):
add_batching_dim_factor(j, get_size_for_value_dim(i, j))
for j, dim in enumerate(value):
if isinstance(dim, str):
add_factor(
dim, get_size_for_value_dim(i, j + current_ellipsis_rank))
else:
for factor in dim:
add_factor(factor, self.factor_sizes.get(factor, UNKNOWN))
# Build the tensor mappings for each operand and result.
tensor_mappings = []
for i, value in enumerate(self.operands + self.results):
dim_mappings = []
if value and value[0] == _ELLIPSIS:
value = value[1:]
if ellipsis_rank is None:
current_ellipsis_rank = 0
else:
current_ellipsis_rank = ellipsis_rank
else:
current_ellipsis_rank = 0
for j in range(current_ellipsis_rank):
dim_mappings.append(
sdy.DimMappingAttr.get(factor_indices=[
factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
for j, dim in enumerate(value):
if isinstance(dim, str):
dim_mappings.append(
sdy.DimMappingAttr.get(
factor_indices=[factors_to_indices_sizes[dim][0]]))
else:
dim_mappings.append(
build_dim_mapping_for_compound_factors(
i, j + current_ellipsis_rank, dim))
tensor_mappings.append(
sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))
op_sharding_rule = sdy.OpShardingRuleAttr.get(
factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
operand_mappings=tensor_mappings[0:len(operand_types)],
result_mappings=tensor_mappings[len(operand_types):])
return op_sharding_rule


@@ -1557,6 +1557,16 @@ jax_multiplatform_test(
tags = ["multiaccelerator"],
)
jax_py_test(
name = "custom_partitioning_sharding_rule_test",
srcs = ["custom_partitioning_sharding_rule_test.py"],
deps = [
"//jax",
"//jax:experimental",
"//jax:test_util",
],
)
exports_files(
[
"api_test.py",


@@ -0,0 +1,396 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
from jax._src.custom_partitioning_sharding_rule import SdyShardingRule
from jax._src.lib.mlir.dialects import hlo as stablehlo
class SdyShardingRuleTest(jtu.JaxTestCase):
def test_rule_is_not_a_str(self):
with self.assertRaisesRegex(TypeError, "rule must be a str"):
SdyShardingRule(1)
def test_factor_sizes_is_not_a_proper_dict(self):
with self.assertRaisesRegex(
TypeError, "factor_sizes must be a dict of str to int"):
SdyShardingRule("i->j", i="j")
def test_sharding_rule_ellipsis_not_complete(self):
with self.assertRaisesRegex(
ValueError, "Character '.' must be used inside ellipsis '...'"):
SdyShardingRule(".i -> j")
def test_sharding_rule_invalid_factor_name(self):
with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"):
SdyShardingRule("2i -> j")
def test_sharding_rule_missing_results(self):
with self.assertRaisesRegex(ValueError, "There is no -> in rule"):
SdyShardingRule("i")
def test_sharding_rule_unbalanced_brackets(self):
with self.assertRaisesRegex(ValueError, "Brackets are not balanced"):
SdyShardingRule("i j, k)->j")
def test_sharding_rule_unbalanced_brackets2(self):
with self.assertRaisesRegex(ValueError, "Brackets are not balanced"):
SdyShardingRule("i (j k->j")
def test_sharding_rule_empty_compound_dim(self):
with self.assertRaisesRegex(
ValueError, "Brackets should contain at least two factors"):
SdyShardingRule("i ( ) j k->j")
def test_sharding_rule_one_factor_compound_dim(self):
with self.assertRaisesRegex(
ValueError, "Brackets should contain at least two factors"):
SdyShardingRule("i (j ) k->j")
def test_sharding_rule_nested_brackets(self):
with self.assertRaisesRegex(
ValueError, "Compound factors should be one level"):
SdyShardingRule("i (j (k))->j")
def test_sharding_rule_unknown_char(self):
with self.assertRaisesRegex(ValueError, "Unknown character"):
SdyShardingRule("i; j->j")
def test_sharding_rule_unknown_single_char_ellipsis(self):
with self.assertRaisesRegex(ValueError, "Unknown character"):
SdyShardingRule("…j->…j")
def test_sharding_rule_ellipsis_not_leading_dim(self):
with self.assertRaisesRegex(
ValueError, "Ellipsis can only be used at the beginning of a dimension"):
SdyShardingRule("i ... -> j")
def test_sharding_rule_ellipsis_inside_compound_dim(self):
with self.assertRaisesRegex(
ValueError, "Ellipsis can only be used at the beginning of a dimension"):
SdyShardingRule("i, (..., j) -> j")
def test_sharding_rule_scalar_operand_scalar_result(self):
rule = SdyShardingRule("->")
self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})")
def test_sharding_rule_one_scalar_operand(self):
rule = SdyShardingRule("i j, , k->j")
self.assertEqual(
str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})")
def test_sharding_rule_factor_size_not_used(self):
with self.assertRaisesRegex(ValueError, "Factor k is not used"):
SdyShardingRule("i->j", k=10)
def test_sharding_rule_factor_size_not_necessary(self):
with self.assertRaisesRegex(
ValueError,
"Factor i represents a whole dimension; do not specify its size"):
SdyShardingRule("i->j", i=10)
def test_sharding_rule_compound_factor_size_not_necessary(self):
with self.assertRaisesRegex(
ValueError,
"Factor i represents a whole dimension; do not specify its size"):
SdyShardingRule("(i j) -> i", i=10, j=20)
def test_sharding_rule_factor_sizes_missing(self):
with self.assertRaisesRegex(
ValueError,
"Factor k is only used in compound factors; must specify its size"):
SdyShardingRule("i j -> (j k)")
def test_sharding_rule_factor_elementwise_add(self):
rule = SdyShardingRule("... i j, ...i j -> ...i j")
self.assertEqual(
str(rule),
"SdyShardingRule((('', 'i', 'j'), ('', 'i', 'j')), (('', 'i',"
" 'j'),), {})")
def test_sharding_rule_factor_vector_scalar_add(self):
rule = SdyShardingRule("...i, -> ...i")
self.assertEqual(
str(rule),
"SdyShardingRule((('', 'i'), ()), (('', 'i'),), {})")
def test_sharding_rule_factor_reshape_combining(self):
rule = SdyShardingRule("i j -> (i j)")
self.assertEqual(
str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})")
def test_sharding_rule_factor_reshape_reordering(self):
rule = SdyShardingRule("(j i) -> (i j)", i=10, j=20)
self.assertEqual(
str(rule),
"SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':"
" 20})")
def test_sharding_rule_factor_compound_then_individual(self):
rule = SdyShardingRule("(i j) (j k) i -> j k")
self.assertEqual(
str(rule),
"SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})")
def test_sharding_rule_factor_individual_then_compound(self):
rule = SdyShardingRule("i j k -> (i j) (j k)")
self.assertEqual(
str(rule),
"SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})")
def test_sharding_rule_factor_infer_k(self):
rule = SdyShardingRule("_i (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20)
self.assertEqual(
str(rule),
"SdyShardingRule((('_i', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))"
",), {'k': 10, 'm': 10, 'bar_24': 20})")
class SdyShardingRuleConversionTest(jtu.JaxTestCase):
def run(self, result=None):
with ir.Context() as ctx, ir.Location.unknown(ctx):
sdy.register_dialect(ctx)
stablehlo.register_dialect(ctx)
module = ir.Module.create()
with ir.InsertionPoint(module.body):
super().run(result)
def get_tensor_type(self, shape):
return ir.RankedTensorType.get(shape, ir.F32Type.get())
def create_tensor_value(self, shape):
return ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type(shape)],
attributes=dict(call_target_name=ir.StringAttr.get("dummy_target"))
).result
def test_conversion_rule_op_mismatch_in_operands_num(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((16, 32))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 32))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
rule = SdyShardingRule("i j-> i j")
with self.assertRaisesRegex(
ValueError,
"Sharding rule has 1 operands, but the operation has 2 operands"):
rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
def test_conversion_rule_op_mismatch_in_operands_rank(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((16, 32))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 32))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
rule = SdyShardingRule("i j, i j k-> i j")
with self.assertRaisesRegex(
ValueError,
"Sharding rule 1th operand has rank 3, but the operation 1th "
"operand has rank 2"):
rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
def test_conversion_rule_op_mismatch_in_results_num(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((16, 32))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 32))],
operands=[opnd0,
opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
rule = SdyShardingRule("i j, i j -> i j, i j")
with self.assertRaisesRegex(
ValueError,
"Sharding rule has 2 results, but the operation has 1 results"):
rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
def test_conversion_rule_op_mismatch_in_results_dim(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((16, 32))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 32))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
rule = SdyShardingRule("i j, i j -> i j k")
with self.assertRaisesRegex(
ValueError,
"Sharding rule 0th result has rank 3, but the operation 0th "
"result has rank 2"):
rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
def test_conversion_factor_has_two_sizes(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((16, 32))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 64))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
rule = SdyShardingRule("i j, i j -> i j")
with self.assertRaisesRegex(
ValueError,
"Factor j corresponds to two sizes: 32 and 64"):
rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
def test_conversion_batching_dim_has_two_sizes(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((16, 32))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 64))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
rule = SdyShardingRule("..., ... -> ...")
with self.assertRaisesRegex(
ValueError,
"Batching dimension 1 corresponds to two sizes: 32 and 64"):
rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,],)
def test_conversion_compound_dimension_size_mismatch(self):
opnd = self.create_tensor_value((2, 4))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((9,))],
operands=[opnd,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
rule = SdyShardingRule("i j -> (i j)")
with self.assertRaisesRegex(
ValueError,
"0th result actual size 9 doesn't match the size 8 derived from the"
" compound factors"):
rule.build(
[result.operands[0].type],
[result.result.type,])
def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((16,))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 32))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
rule = SdyShardingRule("..., ... -> ...")
with self.assertRaisesRegex(
ValueError,
"Ellipsis represents different number of leading dimensions 2 and 1"):
rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
def test_conversion_elementwise_rule_scalar_instance(self):
opnd0 = self.create_tensor_value(())
opnd1 = self.create_tensor_value(())
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type(())],
operands=[opnd0, opnd1],
attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
rule = SdyShardingRule("..., ... -> ...")
mlir_rule = rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
self.assertEqual(
str(mlir_rule),
"#sdy.op_sharding_rule<([], [])->([])>")
def test_conversion_elementwise_rule_2D_instance(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((16, 32))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 32))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
rule = SdyShardingRule("..., ... -> ...")
mlir_rule = rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
self.assertEqual(
str(mlir_rule),
"#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>")
def test_conversion_vector_scalar_add_2D_instance(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value(())
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 32))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
rule = SdyShardingRule("..., -> ...")
mlir_rule = rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
self.assertEqual(
str(mlir_rule),
"#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>")
def test_conversion_reshape_rule(self):
opnd0 = self.create_tensor_value((2, 4))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((8,))],
operands=[opnd0,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
rule = SdyShardingRule("i j -> (i j)")
mlir_rule = rule.build(
[result.operands[0].type],
[result.result.type,])
self.assertEqual(
str(mlir_rule),
"#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>")
def test_conversion_contracting_dim_matmul(self):
opnd0 = self.create_tensor_value((16, 32))
opnd1 = self.create_tensor_value((32, 8))
result = ir.Operation.create(
"stablehlo.custom_call",
results=[self.get_tensor_type((16, 8))],
operands=[opnd0, opnd1,],
attributes=dict(call_target_name=ir.StringAttr.get("foo")))
rule = SdyShardingRule("... contracting_dim, contracting_dim k -> ... k")
mlir_rule = rule.build(
[result.operands[0].type, result.operands[1].type],
[result.result.type,])
self.assertEqual(
str(mlir_rule),
"#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>")
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())