mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[JAX] Include the function name in the error message when hash/equality of a static argument fails.
PiperOrigin-RevId: 395550728
This commit is contained in:
parent
b0e0e46109
commit
7f277068a9
@ -439,10 +439,11 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
jitted_f(1, 1)
|
||||
|
||||
msg = ("Non-hashable static arguments are not supported. An error occured "
|
||||
"while trying to hash an object of type <class 'numpy.ndarray'>, 1. "
|
||||
"The error was:\nTypeError: unhashable type: 'numpy.ndarray'")
|
||||
".*while trying to hash an object of type "
|
||||
"<class 'numpy\\.ndarray'>, 1. The error was:\nTypeError: "
|
||||
"unhashable type: 'numpy\\.ndarray'")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, re.escape(msg)):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jitted_f(1, np.asarray(1))
|
||||
|
||||
class HashableWithoutEq:
|
||||
|
Loading…
x
Reference in New Issue
Block a user