Add unit tests for jit backend argument

Inspired by #5188
This commit is contained in:
Skye Wanderman-Milne 2020-12-21 10:39:59 -08:00
parent 015bab4133
commit 89c8046672

View File

@ -483,6 +483,29 @@ class CPPJitTest(jtu.JaxTestCase):
jitted_f(1)
self.assertIsInstance(jitted_f(2), xla._CppDeviceArray)
@jtu.skip_on_devices("cpu")
def test_explicit_backend(self):
f = lambda x: x + 1
jitted_f = jit(f, backend=jtu.device_under_test())
jitted_f_cpu = jit(f, backend="cpu")
result = jitted_f(1.)
result_cpu = jitted_f_cpu(1.)
self.assertEqual(result.device_buffer.platform(), jtu.device_under_test())
self.assertEqual(result_cpu.device_buffer.platform(), "cpu")
@jtu.skip_on_devices("cpu")
def test_mismatched_nested_backends(self):
@partial(jit, backend=jtu.device_under_test())
def f(x):
return jit(lambda x: x + 1, backend="cpu")(x)
with self.assertRaisesRegex(
ValueError,
f"Outer-jit backend specification {jtu.device_under_test()} must match "
f"explicit inner-jit backend specification cpu."):
f(1.)
class PythonJitTest(CPPJitTest):