Fix exception handling logic in C++ dispatch code.

The dispatch code was always raising its own exception when an exception
occurred during hashing of static arguments, even if the exception which
occurred was something like a KeyboardInterrupt.

fixes #9082

PiperOrigin-RevId: 420292886
This commit is contained in:
Matthew Johnson 2022-01-07 07:58:30 -08:00 committed by jax authors
parent 576630eb40
commit 3a2fb1844c

View File

@ -495,6 +495,18 @@ class CPPJitTest(jtu.BufferDonationTestCase):
re.escape("static arguments should be comparable using __eq__")):
jitted_f(1, HashableWithoutEq())
@unittest.skipIf(jax._src.lib._xla_extension_version < 50,
"requires jaxlib >= 0.1.76")
def test_cpp_jit_raises_other_exceptions_when_hashing_fails(self):
class A:
def __hash__(self):
raise ValueError
f = jax.jit(lambda x: x + 1, static_argnums=(0,))
a = A()
with self.assertRaisesRegex(ValueError, '^$'): # no extra message
f(a)
def test_cpp_jitted_function_returns_PyBuffer(self):
if self.jit != api._cpp_jit:
raise unittest.SkipTest("this test only applies to _cpp_jit")