mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix incorrect backend allowlist in array_interoperability_test.
We intended to only enable this test on CPU and GPU, but we were missing a critical "not".
This commit is contained in:
parent
173a270179
commit
d0baa1d11b
@ -62,7 +62,7 @@ all_shapes = nonempty_array_shapes + empty_array_shapes
|
||||
class DLPackTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if jtu.test_device_matches(["cpu", "gpu"]):
|
||||
if not jtu.test_device_matches(["cpu", "gpu"]):
|
||||
self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -74,7 +74,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
if gpu and jax.test_device_matches(["cpu"]):
|
||||
if gpu and jtu.test_device_matches(["cpu"]):
|
||||
raise unittest.SkipTest("Skipping GPU test case on CPU")
|
||||
device = jax.devices("gpu" if gpu else "cpu")[0]
|
||||
x = jax.device_put(np, device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user