mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #7403 from gnecula:tf_random
PiperOrigin-RevId: 387562021
This commit is contained in:
commit
c4602475ca
@ -320,6 +320,19 @@ class _DimPolynomial(dict):
|
||||
def __rfloordiv__(self, other):
|
||||
return _ensure_poly(other).__floordiv__(self)
|
||||
|
||||
def __truediv__(self, divisor: DimSize):
|
||||
# Used for "/"
|
||||
q, r = self.divmod(divisor)
|
||||
if r != 0:
|
||||
raise InconclusiveDimensionOperation(
|
||||
f"Dimension polynomial '{self}' is not a multiple of '{divisor}'")
|
||||
return q
|
||||
|
||||
def __rtruediv__(self, dividend: DimSize):
|
||||
# Used for "/", when dividend is not a _DimPolynomial
|
||||
raise InconclusiveDimensionOperation(
|
||||
f"Division of '{dividend}' by dimension polynomial '{self}' is not supported")
|
||||
|
||||
def __mod__(self, divisor: DimSize) -> int:
|
||||
return self.divmod(divisor)[1]
|
||||
|
||||
|
@ -245,14 +245,45 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
(a, b, None, None),
|
||||
(3 * a, 2, None, None),
|
||||
(2 * a * b + b * b, a + b, None, None),
|
||||
(3, a, None, None),
|
||||
])
|
||||
def test_poly_divmod(self, dividend, quotient, divisor, remainder):
|
||||
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
|
||||
if quotient is None:
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial .* is not a multiple of .*"):
|
||||
dividend.divmod(divisor)
|
||||
divmod(dividend, divisor)
|
||||
else:
|
||||
self.assertEqual((quotient, remainder), dividend.divmod(divisor))
|
||||
self.assertEqual((quotient, remainder), divmod(dividend, divisor))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}",
|
||||
dividend=dividend, divisor=divisor, quotient=quotient)
|
||||
for dividend, divisor, quotient in [
|
||||
(a, 1, a),
|
||||
(3 * a, 3, a),
|
||||
(3 * a + 3, 3, a + 1),
|
||||
(3 * a + 2, 3, None),
|
||||
(3 * a + 5, 3, None),
|
||||
(3 * a - 2, 3, None),
|
||||
(3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b),
|
||||
(a * a - b * b, a + b, a - b),
|
||||
(a, b, None),
|
||||
(3 * a, 2, None),
|
||||
(2 * a * b + b * b, a + b, None),
|
||||
])
|
||||
def test_poly_truediv(self, *, dividend, divisor, quotient):
|
||||
if quotient is None:
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial .* is not a multiple of .*"):
|
||||
dividend / divisor
|
||||
else:
|
||||
self.assertEqual(quotient, dividend / divisor)
|
||||
|
||||
def test_poly_truediv_error(self):
|
||||
a, = shape_poly.parse_spec("a,", (2,))
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Division of '3' by dimension polynomial .* is not supported"):
|
||||
3 / a
|
||||
|
||||
def test_dilate_shape(self):
|
||||
"""0 if d == 0 else 1 + dilation * (d - 1))"""
|
||||
@ -844,8 +875,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
polymorphic_shapes=["(b1, b2, ...)"])(np.ones((4, 5, 6)))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
re.escape("unsupported operand type(s) for /: 'TensorFlowTracer' and '_DimPolynomial'")):
|
||||
core.InconclusiveDimensionOperation,
|
||||
re.compile("Division of .* by dimension polynomial .* is not supported",
|
||||
re.DOTALL)):
|
||||
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
|
||||
polymorphic_shapes=["(v, _)"])(np.ones((4, 4)))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user