[jax:custom_partitioning] Implement SdyShardingRule to support Shardy custom_partitioning.

The parsing of the sharding rule string closely follows how einops parses its rules in einops/parsing.py. When an SdyShardingRule object is constructed, we check the syntax of the Einsum like notation string and its consistency with the user provided factor_sizes, and report errors accordingly; this happens during f.def_partition. When SdyShardingRule.build is called, during JAX to MLIR lowering, we check the consistency between the Einsum like notation string, the factor_sizes, and the MLIR operation, and report errors accordingly.

PiperOrigin-RevId: 703187962
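For orientation, a minimal usage sketch (editor's addition; the variable names and MLIR types below are hypothetical stand-ins, not part of this commit):

    # Sketch: a contracting-dim matmul rule, syntax-checked at construction.
    from jax._src.custom_partitioning_sharding_rule import SdyShardingRule

    rule = SdyShardingRule("i j, j k -> i k")
    # During JAX-to-MLIR lowering, build() is given the MLIR operand and
    # result types of the op and returns an #sdy.op_sharding_rule attribute:
    #   attr = rule.build([lhs_type, rhs_type], [out_type])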
parent f73fa7a7ad
commit 2a4a0e8d6f
@@ -193,6 +193,7 @@ py_library_providing_imports_info(
         "_src/custom_batching.py",
         "_src/custom_derivatives.py",
         "_src/custom_partitioning.py",
+        "_src/custom_partitioning_sharding_rule.py",
         "_src/custom_transpose.py",
         "_src/debugging.py",
         "_src/dispatch.py",

jax/_src/custom_partitioning_sharding_rule.py (new file, 380 lines)
@@ -0,0 +1,380 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements SdyShardingRule."""

from collections import OrderedDict

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy

_CompoundFactor = tuple[str, ...]
_DimMapping = tuple[str | _CompoundFactor, ...]

# A single character replacement for ... to simplify parsing.
_ELLIPSIS: str = "…"

# A prefix for names of batching dimension factors, used for expanding the
# leading ... into factors.
_BATCHING_DIM_FACTOR_PREFIX = "?"


def _get_batching_dim_factor_name(batch_dim_order: int):
  """Constructs a factor name for a batching dimension.

  We expand the leading ... into factors representing the batching dimensions
  to support building the MLIR representation for the sharding rule. For this
  reason, we construct a factor name that won't be used by users for the
  batching dimensions.
  """
  return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}"
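
# Editor's note (illustrative, not part of the commit): with the "?" prefix,
# _get_batching_dim_factor_name(0) returns "?0", _get_batching_dim_factor_name(1)
# returns "?1", and so on. User-written rules cannot collide with these names
# because "?" is rejected by the rule parser as an unknown character.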


def _parse_values(
    rule: str,
) -> tuple[_DimMapping, ...]:
  """Parses the LHS or RHS of an Einsum notation like string.

  Converts each operand or result in the Einsum notation like string to a tuple
  of _DimMapping. This very closely follows how einops parses their rules in
  einops/parsing.py.

  Args:
    rule: The Einsum notation for the operands or results of an operation.

  Returns:
    The tuple of values.

  Raises:
    ValueError: If the rule is not balanced or contains unknown characters.
  """

  # Remove unnecessary spaces in the rule to simplify the parsing process.
  words = rule.split()
  rule = " ".join(words)

  # Similar to einops rules, an empty LHS/RHS has a single scalar value.
  if not rule:
    return ((),)

  all_values = []
  # Represents all dimensions of a value. When value[0] == _ELLIPSIS, the
  # value may have 0 or more leading dimensions.
  value = []
  current_factor = None
  # A value of None indicates that the current dimension is not a compound
  # dimension, while a value of [] indicates that we have just started parsing
  # a compound dimension.
  current_compound_dim: list[str] | None = None

  def add_factor(x):
    if current_compound_dim is None:
      value.append(x)
    else:
      current_compound_dim.append(x)

  for char in rule:
    if char == _ELLIPSIS:
      if (current_factor is not None or current_compound_dim is not None
          or value):
        raise ValueError(
            "Ellipsis can only be used at the beginning of a dimension")
      add_factor(_ELLIPSIS)
      continue
    if char in "(), ":
      if current_factor is not None:
        add_factor(current_factor)
        current_factor = None
      if char == "(":
        if current_compound_dim is not None:
          raise ValueError(
              "Compound factors should be one level, nested brackets are not"
              " allowed")
        current_compound_dim = []
      elif char == ")":
        if current_compound_dim is None:
          raise ValueError("Brackets are not balanced")
        if len(current_compound_dim) <= 1:
          raise ValueError("Brackets should contain at least two factors")
        value.append(tuple(current_compound_dim))
        current_compound_dim = None
      elif char == ",":
        all_values.append(tuple(value))
        value = []
    elif char == "_" or char.isdigit() or char.isalpha():
      if current_factor is None:
        if char.isdigit():
          raise ValueError(
              f"Factor names have to start with a letter, but got '{char}'")
        current_factor = char
      else:
        current_factor += char
    else:
      raise ValueError(f"Unknown character '{char}'")

  if current_compound_dim is not None:
    raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
  if current_factor is not None:
    add_factor(current_factor)
  all_values.append(tuple(value))

  return tuple(all_values)
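
# Editor's note (illustrative, not part of the commit): for example,
# _parse_values("i (j k), m") returns (('i', ('j', 'k')), ('m',)): one tuple
# per comma separated value, with a compound factor kept as a nested tuple.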


class SdyShardingRule:
  """A representation for Shardy sharding rule.

  A SdyShardingRule includes an Einsum notation like string and an optional
  list of factor sizes. A factor is a name in the Einsum notation. If a factor
  is only used in compound factors, its size must be specified.

  SdyShardingRule examples:

  * Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k')
  * Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k')
  * A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j')
  * Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)', i=4, j=2)
  * An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...')
  """

  def __init__(self, rule: str, **factor_sizes):
    """Constructs a SdyShardingRule object from the Einsum notation like string.

    This is done by verifying that the input Einsum notation like string,
    together with the optional factor sizes, represents a valid sharding rule
    and by converting it to an internal representation.

    Args:
      rule: The Einsum notation like string for an operation.
      **factor_sizes: The optional factor sizes.

    Raises:
      ValueError: If there is any problem with the rule or factor_sizes.
    """
    if not isinstance(rule, str):
      raise TypeError(f"rule must be a str, but got {type(rule)}")
    if not all(isinstance(size, int) for size in factor_sizes.values()):
      raise TypeError(
          f"factor_sizes must be a dict of str to int, but got {factor_sizes}")

    # Replace ... with a single char to simplify parsing.
    if _ELLIPSIS in rule:
      raise ValueError(f"Unknown character '{_ELLIPSIS}'")
    if "." in rule:
      rule = rule.replace("...", _ELLIPSIS)
    if "." in rule:
      raise ValueError("Character '.' must be used inside ellipsis '...'")

    try:
      operands, results = rule.split("->")
    except ValueError as e:
      raise ValueError(f"There is no -> in rule: '{rule}'") from e

    self.operands = _parse_values(operands)
    self.results = _parse_values(results)

    # Find all factors and mark whether their size can be inferred.
    factors_inferrable = dict()
    for value in self.operands + self.results:
      for dim in value:
        if dim == _ELLIPSIS:
          continue
        if isinstance(dim, str):
          factors_inferrable[dim] = True
        else:
          for factor in dim:
            if factor not in factors_inferrable:
              factors_inferrable[factor] = False

    # Check that factors in factor_sizes are used in the rule.
    for factor in factor_sizes:
      if factor not in factors_inferrable:
        raise ValueError(
            f"Factor {factor} is not used in the rule, but size is provided")

    # Check that factors that are used for a whole dimension aren't in
    # factor_sizes and factors that are never used for a whole dimension are
    # in factor_sizes.
    for factor, inferrable in factors_inferrable.items():
      if factor not in factor_sizes and not inferrable:
        raise ValueError(
            f"Factor {factor} is only used in compound factors; must specify"
            " its size")
      if factor in factor_sizes and inferrable:
        raise ValueError(
            f"Factor {factor} represents a whole dimension; do not specify its"
            " size")

    self.factor_sizes = factor_sizes

  def __str__(self):
    return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})"

  def build(
      self,
      operand_types: list[ir.Type],
      result_types: list[ir.Type],
  ) -> ir.Attribute:
    """Builds the MLIR representation for the sharding rule.

    This is done by verifying that the rule is consistent with the types of
    the operation and converting the Einsum notation like string to
    OpShardingRuleAttr.
    """
    if len(self.operands) != len(operand_types):
      raise ValueError(
          f"Sharding rule has {len(self.operands)} operands, but the operation"
          f" has {len(operand_types)} operands")
    if len(self.results) != len(result_types):
      raise ValueError(
          f"Sharding rule has {len(self.results)} results, but the operation"
          f" has {len(result_types)} results")

    factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
    types = operand_types + result_types
    UNKNOWN = -1  # Representation for unknown factor size or factor index.

    def get_message_for_value(i):
      if i >= len(operand_types):
        return f"{i - len(operand_types)}th result"
      else:
        return f"{i}th operand"

    def get_rank_for_value(i):
      return ir.ShapedType(types[i]).rank

    def get_size_for_value_dim(i, j):
      return ir.ShapedType(types[i]).shape[j]

    def add_factor(factor, size):
      """Adds a factor to factors_to_indices_sizes.

      `size` may be a dimension size, a user specified factor size, or UNKNOWN
      if a factor is first used in a compound factor and then used for a
      whole dimension.
      """
      factor_index, factor_size = factors_to_indices_sizes.get(
          factor, [UNKNOWN, UNKNOWN])
      if factor_index != UNKNOWN:
        # Not the first time seeing the factor.
        if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
          factor_or_batching_dim = (
              f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
              else f"Batching dimension {factor[1:]}")
          raise ValueError(
              f"{factor_or_batching_dim} corresponds to two sizes:"
              f" {factor_size} and {size}")
        if size != UNKNOWN and factor_size == UNKNOWN:
          factors_to_indices_sizes[factor] = [factor_index, size]
      else:
        # First time seeing the factor.
        factor_index = len(factors_to_indices_sizes)
        factors_to_indices_sizes[factor] = [factor_index, size]

    def add_batching_dim_factor(batch_dim_order, factor_size):
      ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
      add_factor(ellipsis_batch_dim_name, factor_size)

    def build_dim_mapping_for_compound_factors(i, j, factors):
      accumulated_size = 1
      all_indices = []
      for factor in factors:
        factor_index, factor_size = factors_to_indices_sizes[factor]
        accumulated_size *= factor_size
        all_indices.append(factor_index)

      dim_size = get_size_for_value_dim(i, j)
      if accumulated_size != dim_size:
        raise ValueError(
            f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
            f" the size {accumulated_size} derived from the compound factors"
            f" {factors}")

      return sdy.DimMappingAttr.get(factor_indices=all_indices)

    # Add factors and their sizes in the order they appear in the rule,
    # including the batching dimensions represented by ellipsis.
    ellipsis_rank = None
    for i, value in enumerate(self.operands + self.results):
      if value and value[0] == _ELLIPSIS:
        has_ellipsis = True
        value = value[1:]
      else:
        has_ellipsis = False
      rule_rank = len(value)
      op_rank = get_rank_for_value(i)
      # The number of dimensions represented by ellipsis.
      current_ellipsis_rank = 0
      if has_ellipsis and op_rank > rule_rank:
        current_ellipsis_rank = op_rank - rule_rank
      if has_ellipsis:
        if ellipsis_rank is None:
          ellipsis_rank = current_ellipsis_rank
        elif ellipsis_rank != current_ellipsis_rank:
          raise ValueError(
              "Ellipsis represents different number of leading dimensions"
              f" {ellipsis_rank} and {current_ellipsis_rank}")
      rule_rank += current_ellipsis_rank
      if rule_rank != op_rank:
        msg = get_message_for_value(i)
        raise ValueError(
            f"Sharding rule {msg} has rank {rule_rank}, but the operation"
            f" {msg} has rank {op_rank}")

      for j in range(current_ellipsis_rank):
        add_batching_dim_factor(j, get_size_for_value_dim(i, j))

      for j, dim in enumerate(value):
        if isinstance(dim, str):
          add_factor(
              dim, get_size_for_value_dim(i, j + current_ellipsis_rank))
        else:
          for factor in dim:
            add_factor(factor, self.factor_sizes.get(factor, UNKNOWN))

    # Build the tensor mappings for each operand and result.
    tensor_mappings = []
    for i, value in enumerate(self.operands + self.results):
      dim_mappings = []

      if value and value[0] == _ELLIPSIS:
        value = value[1:]
        if ellipsis_rank is None:
          current_ellipsis_rank = 0
        else:
          current_ellipsis_rank = ellipsis_rank
      else:
        current_ellipsis_rank = 0

      for j in range(current_ellipsis_rank):
        dim_mappings.append(
            sdy.DimMappingAttr.get(factor_indices=[
                factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))

      for j, dim in enumerate(value):
        if isinstance(dim, str):
          dim_mappings.append(
              sdy.DimMappingAttr.get(
                  factor_indices=[factors_to_indices_sizes[dim][0]]))
        else:
          dim_mappings.append(
              build_dim_mapping_for_compound_factors(
                  i, j + current_ellipsis_rank, dim))

      tensor_mappings.append(
          sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))

    op_sharding_rule = sdy.OpShardingRuleAttr.get(
        factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
        operand_mappings=tensor_mappings[0:len(operand_types)],
        result_mappings=tensor_mappings[len(operand_types):])
    return op_sharding_rule
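
# Editor's note (illustrative, not part of the commit): for the rule
# "... contracting_dim, contracting_dim k -> ... k" applied to types
# f32[16, 32] x f32[32, 8] -> f32[16, 8], build() returns an attribute that
# prints as:
#   #sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>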

tests/BUILD
@@ -1557,6 +1557,16 @@ jax_multiplatform_test(
     tags = ["multiaccelerator"],
 )
 
+jax_py_test(
+    name = "custom_partitioning_sharding_rule_test",
+    srcs = ["custom_partitioning_sharding_rule_test.py"],
+    deps = [
+        "//jax",
+        "//jax:experimental",
+        "//jax:test_util",
+    ],
+)
+
 exports_files(
     [
         "api_test.py",

tests/custom_partitioning_sharding_rule_test.py (new file, 396 lines)
@@ -0,0 +1,396 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
from jax._src.custom_partitioning_sharding_rule import SdyShardingRule
from jax._src.lib.mlir.dialects import hlo as stablehlo


class SdyShardingRuleTest(jtu.JaxTestCase):

  def test_rule_is_not_a_str(self):
    with self.assertRaisesRegex(TypeError, "rule must be a str"):
      SdyShardingRule(1)

  def test_factor_sizes_is_not_a_proper_dict(self):
    with self.assertRaisesRegex(
        TypeError, "factor_sizes must be a dict of str to int"):
      SdyShardingRule("i->j", i="j")

  def test_sharding_rule_ellipsis_not_complete(self):
    with self.assertRaisesRegex(
        ValueError, "Character '.' must be used inside ellipsis '...'"):
      SdyShardingRule(".i -> j")

  def test_sharding_rule_invalid_factor_name(self):
    with self.assertRaisesRegex(
        ValueError, "Factor names have to start with a letter"):
      SdyShardingRule("2i -> j")

  def test_sharding_rule_missing_results(self):
    with self.assertRaisesRegex(ValueError, "There is no -> in rule"):
      SdyShardingRule("i")

  def test_sharding_rule_unbalanced_brackets(self):
    with self.assertRaisesRegex(ValueError, "Brackets are not balanced"):
      SdyShardingRule("i j, k)->j")

  def test_sharding_rule_unbalanced_brackets2(self):
    with self.assertRaisesRegex(ValueError, "Brackets are not balanced"):
      SdyShardingRule("i (j k->j")

  def test_sharding_rule_empty_compound_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Brackets should contain at least two factors"):
      SdyShardingRule("i ( ) j k->j")

  def test_sharding_rule_one_factor_compound_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Brackets should contain at least two factors"):
      SdyShardingRule("i (j ) k->j")

  def test_sharding_rule_nested_brackets(self):
    with self.assertRaisesRegex(
        ValueError, "Compound factors should be one level"):
      SdyShardingRule("i (j (k))->j")

  def test_sharding_rule_unknown_char(self):
    with self.assertRaisesRegex(ValueError, "Unknown character"):
      SdyShardingRule("i; j->j")

  def test_sharding_rule_unknown_single_char_ellipsis(self):
    with self.assertRaisesRegex(ValueError, "Unknown character"):
      SdyShardingRule("…j->…j")

  def test_sharding_rule_ellipsis_not_leading_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Ellipsis can only be used at the beginning of a dimension"):
      SdyShardingRule("i ... -> j")

  def test_sharding_rule_ellipsis_inside_compound_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Ellipsis can only be used at the beginning of a dimension"):
      SdyShardingRule("i, (..., j) -> j")

  def test_sharding_rule_scalar_operand_scalar_result(self):
    rule = SdyShardingRule("->")
    self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})")

  def test_sharding_rule_one_scalar_operand(self):
    rule = SdyShardingRule("i j, , k->j")
    self.assertEqual(
        str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})")

  def test_sharding_rule_factor_size_not_used(self):
    with self.assertRaisesRegex(ValueError, "Factor k is not used"):
      SdyShardingRule("i->j", k=10)

  def test_sharding_rule_factor_size_not_necessary(self):
    with self.assertRaisesRegex(
        ValueError,
        "Factor i represents a whole dimension; do not specify its size"):
      SdyShardingRule("i->j", i=10)

  def test_sharding_rule_compound_factor_size_not_necessary(self):
    with self.assertRaisesRegex(
        ValueError,
        "Factor i represents a whole dimension; do not specify its size"):
      SdyShardingRule("(i j) -> i", i=10, j=20)

  def test_sharding_rule_factor_sizes_missing(self):
    with self.assertRaisesRegex(
        ValueError,
        "Factor k is only used in compound factors; must specify its size"):
      SdyShardingRule("i j -> (j k)")

  def test_sharding_rule_factor_elementwise_add(self):
    rule = SdyShardingRule("... i j, ...i j -> ...i j")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i',"
        " 'j'),), {})")

  def test_sharding_rule_factor_vector_scalar_add(self):
    rule = SdyShardingRule("...i, -> ...i")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})")

  def test_sharding_rule_factor_reshape_combining(self):
    rule = SdyShardingRule("i j -> (i j)")
    self.assertEqual(
        str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})")

  def test_sharding_rule_factor_reshape_reordering(self):
    rule = SdyShardingRule("(j i) -> (i j)", i=10, j=20)
    self.assertEqual(
        str(rule),
        "SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':"
        " 20})")

  def test_sharding_rule_factor_compound_then_individual(self):
    rule = SdyShardingRule("(i j) (j k) i -> j k")
    self.assertEqual(
        str(rule),
        "SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})")

  def test_sharding_rule_factor_individual_then_compound(self):
    rule = SdyShardingRule("i j k -> (i j) (j k)")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})")

  def test_sharding_rule_factor_infer_k(self):
    rule = SdyShardingRule("_i (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20)
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('_i', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))"
        ",), {'k': 10, 'm': 10, 'bar_24': 20})")


class SdyShardingRuleConversionTest(jtu.JaxTestCase):

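  # Editor's note (not part of the commit): run() is overridden so that every
  # test body executes inside a fresh MLIR context and module, letting tests
  # create ir.Type and ir.Attribute values directly.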
  def run(self, result=None):
    with ir.Context() as ctx, ir.Location.unknown(ctx):
      sdy.register_dialect(ctx)
      stablehlo.register_dialect(ctx)
      module = ir.Module.create()
      with ir.InsertionPoint(module.body):
        super().run(result)

  def get_tensor_type(self, shape):
    return ir.RankedTensorType.get(shape, ir.F32Type.get())

  def create_tensor_value(self, shape):
    return ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type(shape)],
        attributes=dict(call_target_name=ir.StringAttr.get("dummy_target"))
    ).result

  def test_conversion_rule_op_mismatch_in_operands_num(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((16, 32))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 32))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("i j-> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule has 1 operands, but the operation has 2 operands"):
      rule.build(
          [result.operands[0].type, result.operands[1].type],
          [result.result.type])

  def test_conversion_rule_op_mismatch_in_operands_rank(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((16, 32))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 32))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("i j, i j k-> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule 1th operand has rank 3, but the operation 1th "
        "operand has rank 2"):
      rule.build(
          [result.operands[0].type, result.operands[1].type],
          [result.result.type])

  def test_conversion_rule_op_mismatch_in_results_num(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((16, 32))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 32))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("i j, i j -> i j, i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule has 2 results, but the operation has 1 results"):
      rule.build(
          [result.operands[0].type, result.operands[1].type],
          [result.result.type])

  def test_conversion_rule_op_mismatch_in_results_dim(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((16, 32))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 32))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("i j, i j -> i j k")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule 0th result has rank 3, but the operation 0th "
        "result has rank 2"):
      rule.build(
          [result.operands[0].type, result.operands[1].type],
          [result.result.type])

  def test_conversion_factor_has_two_sizes(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((16, 32))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 64))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("i j, i j -> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Factor j corresponds to two sizes: 32 and 64"):
      rule.build(
          [result.operands[0].type, result.operands[1].type],
          [result.result.type])

  def test_conversion_batching_dim_has_two_sizes(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((16, 32))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 64))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("..., ... -> ...")
    with self.assertRaisesRegex(
        ValueError,
        "Batching dimension 1 corresponds to two sizes: 32 and 64"):
      rule.build(
          [result.operands[0].type, result.operands[1].type],
          [result.result.type])

  def test_conversion_compound_dimension_size_mismatch(self):
    opnd = self.create_tensor_value((2, 4))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((9,))],
        operands=[opnd],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("i j -> (i j)")
    with self.assertRaisesRegex(
        ValueError,
        "0th result actual size 9 doesn't match the size 8 derived from the"
        " compound factors"):
      rule.build(
          [result.operands[0].type],
          [result.result.type])

  def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((16,))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 32))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("..., ... -> ...")
    with self.assertRaisesRegex(
        ValueError,
        "Ellipsis represents different number of leading dimensions 2 and 1"):
      rule.build(
          [result.operands[0].type, result.operands[1].type],
          [result.result.type])

  def test_conversion_elementwise_rule_scalar_instance(self):
    opnd0 = self.create_tensor_value(())
    opnd1 = self.create_tensor_value(())
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type(())],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("..., ... -> ...")
    mlir_rule = rule.build(
        [result.operands[0].type, result.operands[1].type],
        [result.result.type])
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([], [])->([])>")

  def test_conversion_elementwise_rule_2D_instance(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((16, 32))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 32))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("..., ... -> ...")
    mlir_rule = rule.build(
        [result.operands[0].type, result.operands[1].type],
        [result.result.type])
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>")

  def test_conversion_vector_scalar_add_2D_instance(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value(())
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 32))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("..., -> ...")
    mlir_rule = rule.build(
        [result.operands[0].type, result.operands[1].type],
        [result.result.type])
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>")

  def test_conversion_reshape_rule(self):
    opnd0 = self.create_tensor_value((2, 4))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((8,))],
        operands=[opnd0],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("i j -> (i j)")
    mlir_rule = rule.build(
        [result.operands[0].type],
        [result.result.type])
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>")

  def test_conversion_contracting_dim_matmul(self):
    opnd0 = self.create_tensor_value((16, 32))
    opnd1 = self.create_tensor_value((32, 8))
    result = ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type((16, 8))],
        operands=[opnd0, opnd1],
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))
    rule = SdyShardingRule("... contracting_dim, contracting_dim k -> ... k")
    mlir_rule = rule.build(
        [result.operands[0].type, result.operands[1].type],
        [result.result.type])
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>")


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())