[JAX] Include the function name in the error message when hash/equality of a static argument fails.

PiperOrigin-RevId: 395550728
This commit is contained in:
Peter Hawkins 2021-09-08 13:50:08 -07:00 committed by jax authors
parent b0e0e46109
commit 7f277068a9

View File

@ -439,10 +439,11 @@ class CPPJitTest(jtu.BufferDonationTestCase):
jitted_f(1, 1)
msg = ("Non-hashable static arguments are not supported. An error occured "
"while trying to hash an object of type <class 'numpy.ndarray'>, 1. "
"The error was:\nTypeError: unhashable type: 'numpy.ndarray'")
".*while trying to hash an object of type "
"<class 'numpy\\.ndarray'>, 1. The error was:\nTypeError: "
"unhashable type: 'numpy\\.ndarray'")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
with self.assertRaisesRegex(ValueError, msg):
jitted_f(1, np.asarray(1))
class HashableWithoutEq: