mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15192 from mattjj:issue15190
PiperOrigin-RevId: 519037959
This commit is contained in:
commit
32e712864c
@ -725,7 +725,7 @@ _array_methods = {
|
||||
"ravel": lax_numpy.ravel,
|
||||
"repeat": lax_numpy.repeat,
|
||||
"reshape": _reshape,
|
||||
"round": round,
|
||||
"round": lax_numpy.round,
|
||||
"searchsorted": lax_numpy.searchsorted,
|
||||
"sort": lax_numpy.sort,
|
||||
"squeeze": lax_numpy.squeeze,
|
||||
|
@ -858,6 +858,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jround(jnp.array(1.234, jnp.float32)),
|
||||
check_dtypes=False)
|
||||
|
||||
def testRoundMethod(self):
|
||||
# https://github.com/google/jax/issues/15190
|
||||
(jnp.arange(3.) / 5.).round() # doesn't crash
|
||||
|
||||
@jtu.sample_product(shape=[(5,), (5, 2)])
|
||||
def testOperatorReversed(self, shape):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user