[PJRT C API] Add a helper method to check whether the backend is cloud TPU built after certain date.

Skip tests that are not intended to work with older version libtpu.

PiperOrigin-RevId: 610892754
This commit is contained in:
Jieying Luo 2024-02-27 15:24:12 -08:00 committed by jax authors
parent fdbee314d3
commit 3dbbfefef8
4 changed files with 36 additions and 0 deletions

View File

@ -15,6 +15,7 @@ from __future__ import annotations
from collections.abc import Generator, Iterable, Sequence
from contextlib import contextmanager, ExitStack
import datetime
import inspect
import io
import functools
@ -366,6 +367,24 @@ def is_device_cuda():
def is_cloud_tpu():
return running_in_cloud_tpu_vm
# Returns True if it is not cloud TPU. If it is cloud TPU, returns True if it is
# built at least `date``.
# TODO(b/327203806): after libtpu adds a XLA version and the oldest support
# libtpu contains the XLA version, remove using built time to skip tests.
def if_cloud_tpu_at_least(date: datetime.date):
if not is_cloud_tpu():
return True
# The format of Cloud TPU platform_version is like:
# PJRT C API
# TFRT TPU v2
# Built on Oct 30 2023 03:04:42 (1698660263) cl/577737722
platform_version = xla_bridge.get_backend().platform_version.split('\n')[-1]
results = re.findall(r'\(.*?\)', platform_version)
if len(results) != 1:
return True
build_date = date.fromtimestamp(int(results[0][1:-1]))
return build_date >= date
def pjrt_c_api_version_at_least(major_version: int, minor_version: int):
pjrt_c_api_versions = xla_bridge.backend_pjrt_c_api_version()
if pjrt_c_api_versions is None:

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import functools
import math
from absl.testing import absltest
@ -1079,6 +1080,8 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Memories do not work on CPU and GPU backends yet.")
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
self.skipTest("Memories do not work on Cloud TPU older than 2024/02/23.")
super().setUp()
def test_remat_jaxpr_offloadable(self):

View File

@ -14,6 +14,7 @@
"""Test TPU-specific extensions to pallas_call."""
import datetime
import functools
from absl.testing import absltest
from absl.testing import parameterized
@ -48,6 +49,8 @@ class PallasTPUTest(jtu.JaxTestCase):
super().setUp()
if not self.interpret and jtu.device_under_test() != 'tpu':
self.skipTest('Only interpret mode supported on non-TPU')
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 10)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/10.')
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
@ -343,6 +346,8 @@ class PallasCallDynamicGridTest(PallasTPUTest):
# TODO(apaszke): Add tests for scalar_prefetch too
def test_dynamic_grid_scalar_input(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 14)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/14.')
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
@ -436,6 +441,9 @@ class PallasCallDynamicGridTest(PallasTPUTest):
)
def test_num_programs(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
def kernel(y_ref):
y_ref[0, 0] = pl.num_programs(0)
@ -451,6 +459,9 @@ class PallasCallDynamicGridTest(PallasTPUTest):
self.assertEqual(dynamic_kernel(4), 8)
def test_num_programs_block_spec(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
def kernel(x_ref, y_ref):
y_ref[...] = x_ref[...]

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
import jax
@ -66,6 +67,8 @@ class ShardAlikeTest(jtu.JaxTestCase):
super().setUp()
if xla_extension_version < 227:
self.skipTest('Requires xla_extension_version >= 227')
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
self.skipTest("Requires Cloud TPU older than 2024/02/23.")
def test_basic(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))