mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix exception handling logic in C++ dispatch code.
The dispatch code was always raising its own exception when an exception occurred during hashing of static arguments, even if the exception which occurred was something like a KeyboardInterrupt. fixes #9082 PiperOrigin-RevId: 420292886
This commit is contained in:
parent
576630eb40
commit
3a2fb1844c
@ -495,6 +495,18 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
re.escape("static arguments should be comparable using __eq__")):
|
||||
jitted_f(1, HashableWithoutEq())
|
||||
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version < 50,
|
||||
"requires jaxlib >= 0.1.76")
|
||||
def test_cpp_jit_raises_other_exceptions_when_hashing_fails(self):
|
||||
class A:
|
||||
def __hash__(self):
|
||||
raise ValueError
|
||||
|
||||
f = jax.jit(lambda x: x + 1, static_argnums=(0,))
|
||||
a = A()
|
||||
with self.assertRaisesRegex(ValueError, '^$'): # no extra message
|
||||
f(a)
|
||||
|
||||
def test_cpp_jitted_function_returns_PyBuffer(self):
|
||||
if self.jit != api._cpp_jit:
|
||||
raise unittest.SkipTest("this test only applies to _cpp_jit")
|
||||
|
Loading…
x
Reference in New Issue
Block a user