mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Fix the lowering for symbolic dimension expressions for division
The symbolic dimension expression use the Python semantics for division and remainder, while StableHLO is slightly different. PiperOrigin-RevId: 510056597
This commit is contained in:
parent
d0b42f2ce8
commit
454e4de524
@ -605,8 +605,22 @@ class DimExprEvaluator:
|
||||
def __divmod__(self, divisor: Union[np.int32, np.int64, DimExprEvaluator]):
|
||||
if not isinstance(divisor, DimExprEvaluator):
|
||||
divisor = DimExprEvaluator(ir_constant(divisor))
|
||||
return (DimExprEvaluator(hlo.DivOp(self.value, divisor.value).result),
|
||||
DimExprEvaluator(hlo.RemOp(self.value, divisor.value).result))
|
||||
# Quotient
|
||||
raw_quotient = hlo.DivOp(self.value, divisor.value)
|
||||
raw_remainder = hlo.RemOp(self.value, divisor.value)
|
||||
ops_different_sign = compare_hlo(hlo.SignOp(self.value),
|
||||
hlo.SignOp(divisor.value),
|
||||
"NE", "SIGNED")
|
||||
rem_ne_zero = compare_hlo(raw_remainder, ir_constant(np.int64(0)),
|
||||
"NE", "SIGNED")
|
||||
must_adjust = hlo.AndOp(ops_different_sign, rem_ne_zero)
|
||||
quotient = hlo.SelectOp(must_adjust,
|
||||
hlo.SubtractOp(raw_quotient, ir_constant(np.int64(1))),
|
||||
raw_quotient)
|
||||
# Remainder
|
||||
remainder = hlo.SubtractOp(self.value, hlo.MulOp(divisor.value, quotient))
|
||||
return (DimExprEvaluator(quotient.result),
|
||||
DimExprEvaluator(remainder.result))
|
||||
|
||||
def __rdivmod__(self, dividend: Union[np.int32, np.int64]):
|
||||
return DimExprEvaluator(ir_constant(dividend)).__divmod__(self)
|
||||
|
@ -640,15 +640,42 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="(b, _, _)")(x),
|
||||
f_jax(x))
|
||||
|
||||
def test_non_trivial_dim_expr(self):
|
||||
check_shape_poly(
|
||||
self,
|
||||
lambda x: (x[0] + x.shape[0] + x.shape[0] * x.shape[0] + (5 * x.shape[0]) +
|
||||
x.shape[0] // 2 + (5 + x.shape[0]) // x.shape[0] +
|
||||
17 // x.shape[0] +
|
||||
x.shape[0] % 3 + 17 % x.shape[0]),
|
||||
arg_descriptors=[RandArg((3,), np.int64)],
|
||||
poly_axes=[0])
|
||||
@parameterized.named_parameters([
|
||||
dict(testcase_name=f"_expr={name}", expr=expr)
|
||||
for name, expr in [
|
||||
("d + 2", lambda d: d + 2),
|
||||
("2 - d", lambda d: 2 - d),
|
||||
("d * 2", lambda d: d * 2),
|
||||
("d * d", lambda d: d * d),
|
||||
("(- d) * d", lambda d: (- d) * d),
|
||||
("d * d - d", lambda d: d * d - d),
|
||||
# Division
|
||||
("d // 2", lambda d: d // 2),
|
||||
("(d + 1) // 2", lambda d: (d + 1) // 2),
|
||||
("d // -2", lambda d: d // -2),
|
||||
("(d + 1) // -2", lambda d: (d + 1) // -2),
|
||||
("(-d) // 2", lambda d: (-d) // 2),
|
||||
("(-d - 1) // 2", lambda d: (-d - 1) // 2),
|
||||
("(-d) // -2", lambda d: (-d) // -2),
|
||||
("(-d - 1) // -2", lambda d: (-d - 1) // -2),
|
||||
# Remainder
|
||||
("d % 2", lambda d: d % 2),
|
||||
("(d + 1) % 2", lambda d: (d + 1) % 2),
|
||||
("d % -2", lambda d: d % -2),
|
||||
("(d + 1) % -2", lambda d: (d + 1) % -2),
|
||||
("(-d) % 2", lambda d: (-d) % 2),
|
||||
("(-d - 1) % 2", lambda d: (-d - 1) % 2),
|
||||
("(-d) % -2", lambda d: (-d) % -2),
|
||||
("(-d - 1) % -2", lambda d: (-d - 1) % -2),
|
||||
]
|
||||
])
|
||||
def test_non_trivial_dim_expr(self, expr=lambda d: d % -2):
|
||||
# Check the lowering for shape expressions
|
||||
check_shape_poly(
|
||||
self,
|
||||
lambda x: x[0] * 0 + expr(x.shape[0]),
|
||||
arg_descriptors=[RandArg((3,), np.int64)],
|
||||
poly_axes=[0])
|
||||
|
||||
def test_static_shape_result(self):
|
||||
"""The result has static shape."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user