2024-12-05 11:32:43 -08:00
|
|
|
# Copyright 2024 The JAX Authors.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from jax._src import test_util as jtu
|
|
|
|
from jax._src.lib.mlir import ir
|
|
|
|
from jax._src.lib.mlir.dialects import sdy
|
2024-12-10 13:05:05 -08:00
|
|
|
from jax._src.custom_partitioning_sharding_rule import ArrayMapping, BATCHING, CompoundFactor, sdy_sharding_rule_to_mlir, str_to_sdy_sharding_rule, SdyShardingRule
|
2024-12-05 11:32:43 -08:00
|
|
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
|
|
|
|
|
|
|
|
|
|
|
class SdyShardingRuleTest(jtu.JaxTestCase):
  """Tests for constructing and printing CompoundFactor, ArrayMapping and
  SdyShardingRule.

  Error cases use the callable form of `assertRaisesRegex`. Where an argument
  is itself built from other constructors, a lambda defers that construction
  so it happens inside the assertion, exactly as with the `with` form.
  """

  def test_compound_factor_not_enough_factors(self):
    self.assertRaisesRegex(
        ValueError,
        "A compound factor should contain at least two factors",
        CompoundFactor, "i")

  def test_compound_factor_batching_now_allowed(self):
    # BATCHING is rejected when it appears inside a compound factor.
    self.assertRaisesRegex(
        ValueError,
        "Ellipsis can't be used in a compound factor",
        CompoundFactor, BATCHING, "i")

  def test_compound_factor_element_not_a_str(self):
    self.assertRaisesRegex(
        ValueError,
        "Each element of CompoundFactor must be a str",
        CompoundFactor, "i", 2)

  def test_compound_factor_str(self):
    factor = CompoundFactor("i", "j", "k")
    expected = "('i', 'j', 'k')"
    self.assertEqual(str(factor), expected)

  def test_value_mapping_element_not_a_str_or_compound_factor(self):
    self.assertRaisesRegex(
        ValueError,
        "Each element of ArrayMapping must be a str or CompoundFactor",
        lambda: ArrayMapping(CompoundFactor("i", "j"), 3))

  def test_value_mapping_factor_name_not_start_with_letter(self):
    self.assertRaisesRegex(
        ValueError,
        "Factor names have to start with a letter",
        ArrayMapping, "3i", "j")

  def test_value_mapping_ellipsis_not_first(self):
    self.assertRaisesRegex(
        ValueError,
        "Ellipsis can only be used at the beginning of a dimension",
        ArrayMapping, "i_j", BATCHING)

  def test_value_mapping_str(self):
    mapping = ArrayMapping(f"{BATCHING}2", "m", CompoundFactor("i", "j"), "k")
    expected = f"('{BATCHING}2', 'm', ('i', 'j'), 'k')"
    self.assertEqual(str(mapping), expected)

  def test_sdy_sharding_rule_factor_size_not_used(self):
    self.assertRaisesRegex(
        ValueError,
        "Factor k is not used",
        SdyShardingRule, ("i",), ("j",), k=10)

  def test_sdy_sharding_rule_factor_sizes_missing(self):
    self.assertRaisesRegex(
        ValueError,
        "Factor k is only used in compound factors; must specify its size",
        lambda: SdyShardingRule(
            (ArrayMapping("i"), ArrayMapping("j")),
            (ArrayMapping(CompoundFactor("j", "k")),)))

  def test_sdy_sharding_rule_factor_size_not_necessary(self):
    self.assertRaisesRegex(
        ValueError,
        "Factor i represents a whole dimension; do not specify its size",
        lambda: SdyShardingRule(
            (ArrayMapping("i"),), (ArrayMapping("j"),), i=10))

  def test_sdy_sharding_rule_compound_factor_size_not_necessary(self):
    self.assertRaisesRegex(
        ValueError,
        "Factor i represents a whole dimension; do not specify its size",
        lambda: SdyShardingRule(
            (ArrayMapping(CompoundFactor("i", "j")),),
            (ArrayMapping("i"),), i=10, j=20))

  def test_sdy_sharding_rule_str(self):
    rule = SdyShardingRule(
        (ArrayMapping("i"), ArrayMapping("j")),
        (ArrayMapping(CompoundFactor("j", "k")),),
        k=10)
    expected = "SdyShardingRule((('i',), ('j',)), ((('j', 'k'),),), {'k': 10})"
    self.assertEqual(str(rule), expected)
|
|
|
|
|
|
|
|
|
|
|
|
class StrToSdyShardingRuleTest(jtu.JaxTestCase):
  """Tests for parsing einsum-like rule strings with str_to_sdy_sharding_rule.

  Error cases check the exception type and message. Success cases compare
  `str()` of the parsed SdyShardingRule against its expected representation.
  """

  def test_rule_is_not_a_str(self):
    with self.assertRaisesRegex(TypeError, "rule must be a str"):
      str_to_sdy_sharding_rule(1)

  def test_factor_sizes_is_not_a_proper_dict(self):
    with self.assertRaisesRegex(
        TypeError, "factor_sizes must be a dict of str to int"):
      str_to_sdy_sharding_rule("i->j", i="j")

  def test_sharding_rule_ellipsis_not_complete(self):
    # Escape '.' so the pattern matches the literal message; an unescaped '.'
    # (and '...') would match any character(s).
    with self.assertRaisesRegex(
        ValueError, r"Character '\.' must be used inside ellipsis '\.\.\.'"):
      str_to_sdy_sharding_rule(".i -> j")

  def test_sharding_rule_invalid_factor_name(self):
    with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"):
      str_to_sdy_sharding_rule("2i -> j")

  def test_sharding_rule_missing_results(self):
    with self.assertRaisesRegex(ValueError, "There is no -> in rule"):
      str_to_sdy_sharding_rule("i")

  def test_sharding_rule_inbalenced_brackets(self):
    with self.assertRaisesRegex(ValueError, "Brackets are not balanced"):
      str_to_sdy_sharding_rule("i j, k)->j")

  def test_sharding_rule_inbalenced_brackets2(self):
    with self.assertRaisesRegex(ValueError, "Brackets are not balanced"):
      str_to_sdy_sharding_rule("i (j k->j")

  def test_sharding_rule_empty_compound_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Brackets should contain at least two factors"):
      str_to_sdy_sharding_rule("i ( ) j k->j")

  def test_sharding_rule_one_factorcompound_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Brackets should contain at least two factors"):
      str_to_sdy_sharding_rule("i (j ) k->j")

  def test_sharding_rule_nested_brackets(self):
    with self.assertRaisesRegex(
        ValueError, "Compound factors should be one level"):
      str_to_sdy_sharding_rule("i (j (k))->j")

  def test_sharding_rule_unknown_char(self):
    with self.assertRaisesRegex(ValueError, "Unknown character"):
      str_to_sdy_sharding_rule("i; j->j")

  def test_sharding_rule_unknown_single_char_ellipse(self):
    # The single Unicode ellipsis character is not accepted in place of "...".
    with self.assertRaisesRegex(ValueError, "Unknown character"):
      str_to_sdy_sharding_rule("…j->…j")

  def test_sharding_rule_ellipsis_not_leading_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Ellipsis can only be used at the beginning of a dimension"):
      str_to_sdy_sharding_rule("i ... -> j")

  def test_sharding_rule_ellipsis_inside_compound_dim(self):
    with self.assertRaisesRegex(
        ValueError, "Ellipsis can only be used at the beginning of a dimension"):
      str_to_sdy_sharding_rule("i, (..., j) -> j")

  def test_sharding_rule_scalar_operand_scalar_result(self):
    rule = str_to_sdy_sharding_rule("->")
    self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})")

  def test_sharding_rule_one_scalar_operand(self):
    rule = str_to_sdy_sharding_rule("i j, , k->j")
    self.assertEqual(
        str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})")

  def test_sharding_rule_factor_elementwise_add(self):
    # An ellipsis without a number ... is treated as the same as ...0.
    rule = str_to_sdy_sharding_rule("...0 i j, ...1 i j -> ...i j")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('…0', 'i', 'j'), ('…1', 'i', 'j')), (('…0', 'i',"
        " 'j'),), {})")

  def test_sharding_rule_factor_vector_scalar_add(self):
    rule = str_to_sdy_sharding_rule("...87 i, -> ...87 i")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('…87', 'i'), ()), (('…87', 'i'),), {})")

  def test_sharding_rule_factor_reshape_combining(self):
    rule = str_to_sdy_sharding_rule("i j -> (i j)")
    self.assertEqual(
        str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})")

  def test_sharding_rule_factor_reshape_reordering(self):
    rule = str_to_sdy_sharding_rule("(j i) -> (i j)", i=10, j=20)
    self.assertEqual(
        str(rule),
        "SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':"
        " 20})")

  def test_sharding_rule_factor_compound_then_individual(self):
    rule = str_to_sdy_sharding_rule("(i j) (j k) i -> j k")
    self.assertEqual(
        str(rule),
        "SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})")

  def test_sharding_rule_factor_individual_then_compound(self):
    rule = str_to_sdy_sharding_rule("i j k -> (i j) (j k)")
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})")

  def test_sharding_rule_factor_infer_k(self):
    rule = str_to_sdy_sharding_rule("i_ (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20)
    self.assertEqual(
        str(rule),
        "SdyShardingRule((('i_', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))"
        ",), {'k': 10, 'm': 10, 'bar_24': 20})")
|
|
|
|
|
|
|
|
|
|
|
|
class SdyShardingRuleConversionTest(jtu.JaxTestCase):
  """Tests for sdy_sharding_rule_to_mlir.

  Each test runs inside a fresh MLIR context with the sdy and stablehlo
  dialects registered and an insertion point at the body of a new module.
  Tests create `stablehlo.custom_call` ops only to obtain operand and result
  types for the conversion.
  """

  def run(self, result=None):
    with ir.Context() as ctx, ir.Location.unknown(ctx):
      sdy.register_dialect(ctx)
      stablehlo.register_dialect(ctx)
      module = ir.Module.create()
      with ir.InsertionPoint(module.body):
        # Return the base-class result so `run()` keeps the TestCase.run
        # contract (it returns the TestResult that was used).
        return super().run(result)

  def get_tensor_type(self, shape):
    """Returns a ranked f32 tensor type with the given shape."""
    return ir.RankedTensorType.get(shape, ir.F32Type.get())

  def create_tensor_value(self, shape):
    """Creates a dummy custom_call producing an f32 tensor of `shape`.

    Returns the op's single result value, for use as an operand.
    """
    return ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type(shape)],
        attributes=dict(call_target_name=ir.StringAttr.get("dummy_target"))
    ).result

  def _create_custom_call(self, operand_shapes, result_shape):
    """Creates a "foo" custom_call with one f32 operand per shape in
    `operand_shapes` and a single f32 result of `result_shape`."""
    operands = [self.create_tensor_value(shape) for shape in operand_shapes]
    return ir.Operation.create(
        "stablehlo.custom_call",
        results=[self.get_tensor_type(result_shape)],
        operands=operands,
        attributes=dict(call_target_name=ir.StringAttr.get("foo")))

  def _convert(self, rule, op):
    """Converts `rule` against the operand and result types of `op`."""
    return sdy_sharding_rule_to_mlir(
        rule,
        [operand.type for operand in op.operands],
        [res.type for res in op.results])

  def test_conversion_rule_op_mismatch_in_operands_num(self):
    op = self._create_custom_call([(16, 32), (16, 32)], (16, 32))
    rule = str_to_sdy_sharding_rule("i j-> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule has 1 operands, but the operation has 2 operands"):
      self._convert(rule, op)

  def test_conversion_rule_op_mismatch_in_operands_rank(self):
    op = self._create_custom_call([(16, 32), (16, 32)], (16, 32))
    rule = str_to_sdy_sharding_rule("i j, i j k-> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule 1th operand has rank 3, but the operation 1th "
        "operand has rank 2"):
      self._convert(rule, op)

  def test_conversion_rule_op_mismatch_in_results_num(self):
    op = self._create_custom_call([(16, 32), (16, 32)], (16, 32))
    rule = str_to_sdy_sharding_rule("i j, i j -> i j, i j")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule has 2 results, but the operation has 1 results"):
      self._convert(rule, op)

  def test_conversion_rule_op_mismatch_in_results_dim(self):
    op = self._create_custom_call([(16, 32), (16, 32)], (16, 32))
    rule = str_to_sdy_sharding_rule("i j, i j -> i j k")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule 0th result has rank 3, but the operation 0th "
        "result has rank 2"):
      self._convert(rule, op)

  def test_conversion_factor_has_two_sizes(self):
    op = self._create_custom_call([(16, 32), (16, 32)], (16, 64))
    rule = str_to_sdy_sharding_rule("i j, i j -> i j")
    with self.assertRaisesRegex(
        ValueError,
        "Factor j corresponds to two sizes: 32 and 64"):
      self._convert(rule, op)

  def test_conversion_batching_dim_has_two_sizes(self):
    op = self._create_custom_call([(16, 32), (16, 32)], (16, 64))
    rule = str_to_sdy_sharding_rule("..., ... -> ...")
    with self.assertRaisesRegex(
        ValueError,
        "Batching dimension 0_1 corresponds to two sizes: 32 and 64"):
      self._convert(rule, op)

  def test_conversion_invalid_batching_dim(self):
    op = self._create_custom_call([(16, 32), (16, 32)], (16, 32))
    rule = str_to_sdy_sharding_rule("... i j k, ... i j k -> ... i j k")
    with self.assertRaisesRegex(
        ValueError,
        "Sharding rule 0th operand has rank 3, but the operation 0th operand has rank 2"):
      self._convert(rule, op)

  def test_conversion_compound_dimension_size_mismatch(self):
    op = self._create_custom_call([(2, 4)], (9,))
    rule = str_to_sdy_sharding_rule("i j -> (i j)")
    with self.assertRaisesRegex(
        ValueError,
        "0th result actual size 9 doesn't match the size 8 derived from the"
        " compound factors"):
      self._convert(rule, op)

  def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self):
    op = self._create_custom_call([(16, 32), (16,)], (16, 32))
    rule = str_to_sdy_sharding_rule("..., ... -> ...")
    with self.assertRaisesRegex(
        ValueError,
        "Ellipsis represents different number of leading dimensions 2 and 1"):
      self._convert(rule, op)

  def test_conversion_compound_then_individual(self):
    op = self._create_custom_call([(8,)], (2, 4))
    rule = str_to_sdy_sharding_rule("(i j) -> i j")
    mlir_rule = self._convert(rule, op)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}>")

  def test_conversion_elementwise_rule_scalar_instance(self):
    op = self._create_custom_call([(), ()], ())
    rule = str_to_sdy_sharding_rule("..., ... -> ...")
    mlir_rule = self._convert(rule, op)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([], [])->([])>")

  def test_conversion_elementwise_rule_2D_instance(self):
    op = self._create_custom_call([(16, 32), (16, 32)], (16, 32))
    rule = str_to_sdy_sharding_rule("..., ... -> ...")
    mlir_rule = self._convert(rule, op)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>")

  def test_conversion_vector_scalar_add_2D_instance(self):
    op = self._create_custom_call([(16, 32), ()], (16, 32))
    rule = str_to_sdy_sharding_rule("..., -> ...")
    mlir_rule = self._convert(rule, op)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>")

  def test_conversion_reshape_rule(self):
    op = self._create_custom_call([(2, 4)], (8,))
    rule = str_to_sdy_sharding_rule("i j -> (i j)")
    mlir_rule = self._convert(rule, op)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>")

  def test_conversion_contracting_dim_matmul(self):
    op = self._create_custom_call([(16, 32), (32, 8)], (16, 8))
    rule = str_to_sdy_sharding_rule("... contracting_dim, contracting_dim k -> ... k")
    mlir_rule = self._convert(rule, op)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>")

  def test_conversion_multiple_batching_groups(self):
    op = self._create_custom_call(
        [(4, 5, 16, 32), (6, 7, 8, 32, 16)], (4, 5, 32, 16))
    rule = str_to_sdy_sharding_rule("... j i, ...1 i j -> ...i j")
    mlir_rule = self._convert(rule, op)
    self.assertEqual(
        str(mlir_rule),
        "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>")
|
|
|
|
|
|
|
|
|
2024-12-05 11:32:43 -08:00
|
|
|
if __name__ == "__main__":
  # Script entry point: run the tests with absltest, using the test loader
  # provided by jax._src.test_util.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|