[shape_poly] Fix the lowering for symbolic dimension expressions for division

The symbolic dimension expression use the Python semantics for division and remainder, while StableHLO is slightly different.

PiperOrigin-RevId: 510056597
This commit is contained in:
George Necula 2023-02-15 23:50:44 -08:00 committed by jax authors
parent d0b42f2ce8
commit 454e4de524
2 changed files with 52 additions and 11 deletions

View File

@ -605,8 +605,22 @@ class DimExprEvaluator:
def __divmod__(self, divisor: Union[np.int32, np.int64, DimExprEvaluator]):
if not isinstance(divisor, DimExprEvaluator):
divisor = DimExprEvaluator(ir_constant(divisor))
return (DimExprEvaluator(hlo.DivOp(self.value, divisor.value).result),
DimExprEvaluator(hlo.RemOp(self.value, divisor.value).result))
# Quotient
raw_quotient = hlo.DivOp(self.value, divisor.value)
raw_remainder = hlo.RemOp(self.value, divisor.value)
ops_different_sign = compare_hlo(hlo.SignOp(self.value),
hlo.SignOp(divisor.value),
"NE", "SIGNED")
rem_ne_zero = compare_hlo(raw_remainder, ir_constant(np.int64(0)),
"NE", "SIGNED")
must_adjust = hlo.AndOp(ops_different_sign, rem_ne_zero)
quotient = hlo.SelectOp(must_adjust,
hlo.SubtractOp(raw_quotient, ir_constant(np.int64(1))),
raw_quotient)
# Remainder
remainder = hlo.SubtractOp(self.value, hlo.MulOp(divisor.value, quotient))
return (DimExprEvaluator(quotient.result),
DimExprEvaluator(remainder.result))
def __rdivmod__(self, dividend: Union[np.int32, np.int64]):
return DimExprEvaluator(ir_constant(dividend)).__divmod__(self)

View File

@ -640,15 +640,42 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="(b, _, _)")(x),
f_jax(x))
def test_non_trivial_dim_expr(self):
check_shape_poly(
self,
lambda x: (x[0] + x.shape[0] + x.shape[0] * x.shape[0] + (5 * x.shape[0]) +
x.shape[0] // 2 + (5 + x.shape[0]) // x.shape[0] +
17 // x.shape[0] +
x.shape[0] % 3 + 17 % x.shape[0]),
arg_descriptors=[RandArg((3,), np.int64)],
poly_axes=[0])
@parameterized.named_parameters([
dict(testcase_name=f"_expr={name}", expr=expr)
for name, expr in [
("d + 2", lambda d: d + 2),
("2 - d", lambda d: 2 - d),
("d * 2", lambda d: d * 2),
("d * d", lambda d: d * d),
("(- d) * d", lambda d: (- d) * d),
("d * d - d", lambda d: d * d - d),
# Division
("d // 2", lambda d: d // 2),
("(d + 1) // 2", lambda d: (d + 1) // 2),
("d // -2", lambda d: d // -2),
("(d + 1) // -2", lambda d: (d + 1) // -2),
("(-d) // 2", lambda d: (-d) // 2),
("(-d - 1) // 2", lambda d: (-d - 1) // 2),
("(-d) // -2", lambda d: (-d) // -2),
("(-d - 1) // -2", lambda d: (-d - 1) // -2),
# Remainder
("d % 2", lambda d: d % 2),
("(d + 1) % 2", lambda d: (d + 1) % 2),
("d % -2", lambda d: d % -2),
("(d + 1) % -2", lambda d: (d + 1) % -2),
("(-d) % 2", lambda d: (-d) % 2),
("(-d - 1) % 2", lambda d: (-d - 1) % 2),
("(-d) % -2", lambda d: (-d) % -2),
("(-d - 1) % -2", lambda d: (-d - 1) % -2),
]
])
def test_non_trivial_dim_expr(self, expr=lambda d: d % -2):
# Check the lowering for shape expressions
check_shape_poly(
self,
lambda x: x[0] * 0 + expr(x.shape[0]),
arg_descriptors=[RandArg((3,), np.int64)],
poly_axes=[0])
def test_static_shape_result(self):
"""The result has static shape."""