mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #17828 from hawkinsp:tpu
PiperOrigin-RevId: 569210363
This commit is contained in:
commit
c490a063c8
@ -62,7 +62,7 @@ all_shapes = nonempty_array_shapes + empty_array_shapes
|
||||
class DLPackTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if jtu.test_device_matches(["cpu", "gpu"]):
|
||||
if not jtu.test_device_matches(["cpu", "gpu"]):
|
||||
self.skipTest(f"DLPack not supported on {jtu.device_under_test()}")
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -74,7 +74,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
if gpu and jax.test_device_matches(["cpu"]):
|
||||
if gpu and jtu.test_device_matches(["cpu"]):
|
||||
raise unittest.SkipTest("Skipping GPU test case on CPU")
|
||||
device = jax.devices("gpu" if gpu else "cpu")[0]
|
||||
x = jax.device_put(np, device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user