Merge pull request #17828 from hawkinsp:tpu

PiperOrigin-RevId: 569210363
This commit is contained in:
jax authors 2023-09-28 09:47:23 -07:00
commit c490a063c8

View File

@ -62,7 +62,7 @@ all_shapes = nonempty_array_shapes + empty_array_shapes
class DLPackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if jtu.test_device_matches(["cpu", "gpu"]):
if not jtu.test_device_matches(["cpu", "gpu"]):
self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")
@jtu.sample_product(
@ -74,7 +74,7 @@ class DLPackTest(jtu.JaxTestCase):
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
if gpu and jax.test_device_matches(["cpu"]):
if gpu and jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("Skipping GPU test case on CPU")
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)