mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix a typo on the dynamic definition of __hash__.
This commit is contained in:
parent
a79055cb11
commit
ca72d3dc80
@ -1255,7 +1255,7 @@ for device_array in [_DeviceArray, _CppDeviceArray]:
|
||||
def __hash__(self):
|
||||
raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
|
||||
|
||||
setattr(device_array, "__eq__", __hash__)
|
||||
setattr(device_array, "__hash__", __hash__)
|
||||
|
||||
# The following methods are dynamically overridden in lax_numpy.py.
|
||||
def raise_not_implemented():
|
||||
|
@ -1757,8 +1757,17 @@ class APITest(jtu.JaxTestCase):
|
||||
api.vjp(api.pmap(f), x, x)[1]((x, x))
|
||||
|
||||
def test_device_array_repr(self):
|
||||
rep = repr(jnp.ones(()) + 1.)
|
||||
self.assertStartsWith(rep, 'DeviceArray')
|
||||
rep = jnp.ones(()) + 1.
|
||||
self.assertStartsWith(repr(rep), "DeviceArray")
|
||||
|
||||
def test_device_array_hash(self):
|
||||
rep = jnp.ones(()) + 1.
|
||||
self.assertIsInstance(rep, jax.interpreters.xla._DeviceArray)
|
||||
msg = "JAX DeviceArray, like numpy.ndarray, is not hashable."
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
hash(rep)
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
hash(rep.device_buffer)
|
||||
|
||||
def test_grad_without_enough_args_error_message(self):
|
||||
# https://github.com/google/jax/issues/1696
|
||||
|
Loading…
x
Reference in New Issue
Block a user