mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
015bab4133
commit
89c8046672
@ -483,6 +483,29 @@ class CPPJitTest(jtu.JaxTestCase):
|
||||
jitted_f(1)
|
||||
self.assertIsInstance(jitted_f(2), xla._CppDeviceArray)
|
||||
|
||||
@jtu.skip_on_devices("cpu")
|
||||
def test_explicit_backend(self):
|
||||
f = lambda x: x + 1
|
||||
jitted_f = jit(f, backend=jtu.device_under_test())
|
||||
jitted_f_cpu = jit(f, backend="cpu")
|
||||
|
||||
result = jitted_f(1.)
|
||||
result_cpu = jitted_f_cpu(1.)
|
||||
self.assertEqual(result.device_buffer.platform(), jtu.device_under_test())
|
||||
self.assertEqual(result_cpu.device_buffer.platform(), "cpu")
|
||||
|
||||
@jtu.skip_on_devices("cpu")
|
||||
def test_mismatched_nested_backends(self):
|
||||
@partial(jit, backend=jtu.device_under_test())
|
||||
def f(x):
|
||||
return jit(lambda x: x + 1, backend="cpu")(x)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f"Outer-jit backend specification {jtu.device_under_test()} must match "
|
||||
f"explicit inner-jit backend specification cpu."):
|
||||
f(1.)
|
||||
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user