Add assert raises check.

This commit is contained in:
zafarali 2021-07-21 16:12:52 -04:00
parent 8143773b79
commit 00a273d1fe

View File

@ -2077,6 +2077,8 @@ class APITest(jtu.JaxTestCase):
rep = jnp.ones(()) + 1.
self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
self.assertNotIsInstance(rep, collections.Hashable)
with self.assertRaisesRegex(TypeError, 'unhashable type'):
hash(rep)
def test_grad_without_enough_args_error_message(self):
# https://github.com/google/jax/issues/1696