mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
only test if omnistaging is enabled
This commit is contained in:
parent
6b60267e55
commit
34e0460961
@ -1798,6 +1798,9 @@ class APITest(jtu.JaxTestCase):
|
||||
f() # doesn't crash
|
||||
|
||||
def test_xla_computation_zeros_doesnt_device_put(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test is omnistaging-specific")
|
||||
|
||||
count = 0
|
||||
def device_put_and_count(*args, **kwargs):
|
||||
nonlocal count
|
||||
|
Loading…
x
Reference in New Issue
Block a user