Fix incorrect backend allowlist in array_interoperability_test.

We intended to only enable this test on CPU and GPU, but we were missing a critical "not".
This commit is contained in:
Peter Hawkins 2023-09-28 10:30:22 -04:00
parent 173a270179
commit d0baa1d11b

View File

@ -62,7 +62,7 @@ all_shapes = nonempty_array_shapes + empty_array_shapes
class DLPackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if jtu.test_device_matches(["cpu", "gpu"]):
if not jtu.test_device_matches(["cpu", "gpu"]):
self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")
@jtu.sample_product(
@ -74,7 +74,7 @@ class DLPackTest(jtu.JaxTestCase):
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
if gpu and jax.test_device_matches(["cpu"]):
if gpu and jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("Skipping GPU test case on CPU")
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)