jax_jit_test: inherit from JaxTestCase

This commit is contained in:
Jake VanderPlas 2022-01-24 15:32:59 -08:00
parent b8372b0ca2
commit 8cacad2e23

View File

@ -39,7 +39,7 @@ def _cpp_device_put(value, device):
return jaxlib.jax_jit.device_put(value, config.x64_enabled, device)
class JaxJitTest(parameterized.TestCase):
class JaxJitTest(jtu.JaxTestCase):
def test_is_float_0(self):
self.assertTrue(