Add regression test for #4780

This commit is contained in:
Jake VanderPlas 2023-03-16 09:05:23 -07:00
parent 860630b367
commit f6bedb13f7

View File

@ -1199,6 +1199,13 @@ class CPPJitTest(jtu.BufferDonationTestCase):
_ = f(x)
self.assertEqual(f._cache_size(), 1)
def test_jit_nan_times_zero(self):
# https://github.com/google/jax/issues/4780
def f(x):
return 1 + x * 0
self.assertAllClose(f(np.nan), np.nan)
self.assertAllClose(self.jit(f)(np.nan), np.nan)
class PythonJitTest(CPPJitTest):