Merge pull request #7403 from gnecula:tf_random

PiperOrigin-RevId: 387562021
This commit is contained in:
jax authors 2021-07-29 05:23:48 -07:00
commit c4602475ca
2 changed files with 50 additions and 5 deletions

View File

@ -320,6 +320,19 @@ class _DimPolynomial(dict):
def __rfloordiv__(self, other):
return _ensure_poly(other).__floordiv__(self)
def __truediv__(self, divisor: DimSize):
# Used for "/"
q, r = self.divmod(divisor)
if r != 0:
raise InconclusiveDimensionOperation(
f"Dimension polynomial '{self}' is not a multiple of '{divisor}'")
return q
def __rtruediv__(self, dividend: DimSize):
# Used for "/", when dividend is not a _DimPolynomial
raise InconclusiveDimensionOperation(
f"Division of '{dividend}' by dimension polynomial '{self}' is not supported")
def __mod__(self, divisor: DimSize) -> int:
return self.divmod(divisor)[1]

View File

@ -245,14 +245,45 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
(a, b, None, None),
(3 * a, 2, None, None),
(2 * a * b + b * b, a + b, None, None),
(3, a, None, None),
])
def test_poly_divmod(self, dividend, quotient, divisor, remainder):
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
if quotient is None:
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"Dimension polynomial .* is not a multiple of .*"):
dividend.divmod(divisor)
divmod(dividend, divisor)
else:
self.assertEqual((quotient, remainder), dividend.divmod(divisor))
self.assertEqual((quotient, remainder), divmod(dividend, divisor))
@parameterized.named_parameters(
dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}",
dividend=dividend, divisor=divisor, quotient=quotient)
for dividend, divisor, quotient in [
(a, 1, a),
(3 * a, 3, a),
(3 * a + 3, 3, a + 1),
(3 * a + 2, 3, None),
(3 * a + 5, 3, None),
(3 * a - 2, 3, None),
(3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b),
(a * a - b * b, a + b, a - b),
(a, b, None),
(3 * a, 2, None),
(2 * a * b + b * b, a + b, None),
])
def test_poly_truediv(self, *, dividend, divisor, quotient):
if quotient is None:
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"Dimension polynomial .* is not a multiple of .*"):
dividend / divisor
else:
self.assertEqual(quotient, dividend / divisor)
def test_poly_truediv_error(self):
a, = shape_poly.parse_spec("a,", (2,))
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"Division of '3' by dimension polynomial .* is not supported"):
3 / a
def test_dilate_shape(self):
"""0 if d == 0 else 1 + dilation * (d - 1))"""
@ -844,8 +875,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))
with self.assertRaisesRegex(
TypeError,
re.escape("unsupported operand type(s) for /: 'TensorFlowTracer' and '_DimPolynomial'")):
core.InconclusiveDimensionOperation,
re.compile("Division of .* by dimension polynomial .* is not supported",
re.DOTALL)):
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
polymorphic_shapes=["(v, _)"])(np.ones((4, 4)))