only test if omnistaging is enabled

This commit is contained in:
Matthew Johnson 2020-09-21 19:33:14 -07:00
parent 6b60267e55
commit 34e0460961

View File

@ -1798,6 +1798,9 @@ class APITest(jtu.JaxTestCase):
f() # doesn't crash
def test_xla_computation_zeros_doesnt_device_put(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
count = 0
def device_put_and_count(*args, **kwargs):
nonlocal count