mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #12977 from yejingxin:main
PiperOrigin-RevId: 483812465
This commit is contained in:
commit
0d1e230d97
@ -255,6 +255,9 @@ def is_device_rocm():
|
||||
def is_device_cuda():
|
||||
return xla_bridge.get_backend().platform_version.startswith('cuda')
|
||||
|
||||
def is_cloud_tpu():
|
||||
return 'libtpu' in xla_bridge.get_backend().platform_version
|
||||
|
||||
def is_device_tpu_v4():
|
||||
return jax.devices()[0].device_kind == "TPU v4"
|
||||
|
||||
|
@ -1164,6 +1164,9 @@ class InspectShardingTest(jtu.JaxTestCase):
|
||||
if jaxlib.xla_extension_version < 94:
|
||||
raise unittest.SkipTest("Inspect sharding not supported.")
|
||||
|
||||
if jtu.is_cloud_tpu():
|
||||
raise unittest.SkipTest("Inspect sharding is not supported on libtpu.")
|
||||
|
||||
is_called = False
|
||||
def _cb(sd):
|
||||
nonlocal is_called
|
||||
|
@ -1096,6 +1096,9 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
if xla_extension_version < 95:
|
||||
raise unittest.SkipTest('Must support custom partitioning.')
|
||||
|
||||
if jtu.is_cloud_tpu():
|
||||
raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
|
||||
|
||||
def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
|
||||
self.assertEqual(arg_shardings[0], result_sharding)
|
||||
self.assertEqual(P(('x',)), result_sharding.spec)
|
||||
|
Loading…
x
Reference in New Issue
Block a user