mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add regression test for #4780
This commit is contained in:
parent
860630b367
commit
f6bedb13f7
@ -1199,6 +1199,13 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
_ = f(x)
|
||||
self.assertEqual(f._cache_size(), 1)
|
||||
|
||||
def test_jit_nan_times_zero(self):
|
||||
# https://github.com/google/jax/issues/4780
|
||||
def f(x):
|
||||
return 1 + x * 0
|
||||
self.assertAllClose(f(np.nan), np.nan)
|
||||
self.assertAllClose(self.jit(f)(np.nan), np.nan)
|
||||
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user