mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Don't run test_mismatched_nested_backends
test with pjit and jit because jax_jit_pjit_api_merge
will do that for us.
PiperOrigin-RevId: 504168144
This commit is contained in:
parent
fb9b5ec1e4
commit
1641c8f141
@ -757,17 +757,13 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(result_2, x + 3)
|
||||
self.assertAllClose(result_cpu_2, x + 4)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('jit', jax.jit),
|
||||
('pjit', pjit.pjit)
|
||||
)
|
||||
@jtu.skip_on_devices("cpu")
|
||||
def test_mismatched_nested_backends(self, module):
|
||||
@partial(module, backend=jtu.device_under_test())
|
||||
def test_mismatched_nested_backends(self):
|
||||
@partial(jax.jit, backend=jtu.device_under_test())
|
||||
def f(x):
|
||||
return module(lambda x: x + 1, backend="cpu")(x)
|
||||
return jax.jit(lambda x: x + 1, backend="cpu")(x)
|
||||
|
||||
if module is pjit.pjit:
|
||||
if jax.config.jax_jit_pjit_api_merge:
|
||||
msg = 'Devices of all `Array` inputs and outputs should be the same'
|
||||
else:
|
||||
msg = ("Outer-jit backend specification .* must match explicit inner-jit "
|
||||
|
Loading…
x
Reference in New Issue
Block a user