Fix a typo on the dynamic definition of __hash__.

This commit is contained in:
Jean-Baptiste Lespiau 2020-12-13 03:17:32 +01:00 committed by Jean-Baptiste
parent a79055cb11
commit ca72d3dc80
2 changed files with 12 additions and 3 deletions

View File

@ -1255,7 +1255,7 @@ for device_array in [_DeviceArray, _CppDeviceArray]:
def __hash__(self):
raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.")
setattr(device_array, "__eq__", __hash__)
setattr(device_array, "__hash__", __hash__)
# The following methods are dynamically overridden in lax_numpy.py.
def raise_not_implemented():

View File

@ -1757,8 +1757,17 @@ class APITest(jtu.JaxTestCase):
api.vjp(api.pmap(f), x, x)[1]((x, x))
def test_device_array_repr(self):
rep = repr(jnp.ones(()) + 1.)
self.assertStartsWith(rep, 'DeviceArray')
rep = jnp.ones(()) + 1.
self.assertStartsWith(repr(rep), "DeviceArray")
def test_device_array_hash(self):
rep = jnp.ones(()) + 1.
self.assertIsInstance(rep, jax.interpreters.xla._DeviceArray)
msg = "JAX DeviceArray, like numpy.ndarray, is not hashable."
with self.assertRaisesRegex(TypeError, msg):
hash(rep)
with self.assertRaisesRegex(TypeError, msg):
hash(rep.device_buffer)
def test_grad_without_enough_args_error_message(self):
# https://github.com/google/jax/issues/1696