Merge pull request #12977 from yejingxin:main

PiperOrigin-RevId: 483812465
This commit is contained in:
jax authors 2022-10-25 16:58:14 -07:00
commit 0d1e230d97
3 changed files with 9 additions and 0 deletions

View File

@ -255,6 +255,9 @@ def is_device_rocm():
def is_device_cuda():
return xla_bridge.get_backend().platform_version.startswith('cuda')
def is_cloud_tpu():
return 'libtpu' in xla_bridge.get_backend().platform_version
def is_device_tpu_v4():
return jax.devices()[0].device_kind == "TPU v4"

View File

@ -1164,6 +1164,9 @@ class InspectShardingTest(jtu.JaxTestCase):
if jaxlib.xla_extension_version < 94:
raise unittest.SkipTest("Inspect sharding not supported.")
if jtu.is_cloud_tpu():
raise unittest.SkipTest("Inspect sharding is not supported on libtpu.")
is_called = False
def _cb(sd):
nonlocal is_called

View File

@ -1096,6 +1096,9 @@ class PJitTest(jtu.BufferDonationTestCase):
if xla_extension_version < 95:
raise unittest.SkipTest('Must support custom partitioning.')
if jtu.is_cloud_tpu():
raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
self.assertEqual(arg_shardings[0], result_sharding)
self.assertEqual(P(('x',)), result_sharding.spec)