diff --git a/tests/api_test.py b/tests/api_test.py index e8bb91235..a2b28d63b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1798,6 +1798,9 @@ class APITest(jtu.JaxTestCase): f() # doesn't crash def test_xla_computation_zeros_doesnt_device_put(self): + if not config.omnistaging_enabled: + raise unittest.SkipTest("test is omnistaging-specific") + count = 0 def device_put_and_count(*args, **kwargs): nonlocal count