Merge pull request #15192 from mattjj:issue15190

PiperOrigin-RevId: 519037959
This commit is contained in:
jax authors 2023-03-23 20:48:33 -07:00
commit 32e712864c
2 changed files with 5 additions and 1 deletions

View File

@ -725,7 +725,7 @@ _array_methods = {
"ravel": lax_numpy.ravel,
"repeat": lax_numpy.repeat,
"reshape": _reshape,
"round": round,
"round": lax_numpy.round,
"searchsorted": lax_numpy.searchsorted,
"sort": lax_numpy.sort,
"squeeze": lax_numpy.squeeze,

View File

@ -858,6 +858,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jround(jnp.array(1.234, jnp.float32)),
check_dtypes=False)
def testRoundMethod(self):
# https://github.com/google/jax/issues/15190
(jnp.arange(3.) / 5.).round() # doesn't crash
@jtu.sample_product(shape=[(5,), (5, 2)])
def testOperatorReversed(self, shape):
rng = jtu.rand_default(self.rng())