diff --git a/tests/api_test.py b/tests/api_test.py index 727efe830..1ba31e7bf 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -757,17 +757,13 @@ class CPPJitTest(jtu.BufferDonationTestCase): self.assertAllClose(result_2, x + 3) self.assertAllClose(result_cpu_2, x + 4) - @parameterized.named_parameters( - ('jit', jax.jit), - ('pjit', pjit.pjit) - ) @jtu.skip_on_devices("cpu") - def test_mismatched_nested_backends(self, module): - @partial(module, backend=jtu.device_under_test()) + def test_mismatched_nested_backends(self): + @partial(jax.jit, backend=jtu.device_under_test()) def f(x): - return module(lambda x: x + 1, backend="cpu")(x) + return jax.jit(lambda x: x + 1, backend="cpu")(x) - if module is pjit.pjit: + if jax.config.jax_jit_pjit_api_merge: msg = 'Devices of all `Array` inputs and outputs should be the same' else: msg = ("Outer-jit backend specification .* must match explicit inner-jit "