mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[PJRT C API] Add a helper method to check whether the backend is cloud TPU built after certain date.
Skip tests that are not intended to work with older version libtpu. PiperOrigin-RevId: 610892754
This commit is contained in:
parent
fdbee314d3
commit
3dbbfefef8
@ -15,6 +15,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from contextlib import contextmanager, ExitStack
|
||||
import datetime
|
||||
import inspect
|
||||
import io
|
||||
import functools
|
||||
@ -366,6 +367,24 @@ def is_device_cuda():
|
||||
def is_cloud_tpu():
|
||||
return running_in_cloud_tpu_vm
|
||||
|
||||
# Returns True if it is not cloud TPU. If it is cloud TPU, returns True if it is
|
||||
# built at least `date``.
|
||||
# TODO(b/327203806): after libtpu adds a XLA version and the oldest support
|
||||
# libtpu contains the XLA version, remove using built time to skip tests.
|
||||
def if_cloud_tpu_at_least(date: datetime.date):
|
||||
if not is_cloud_tpu():
|
||||
return True
|
||||
# The format of Cloud TPU platform_version is like:
|
||||
# PJRT C API
|
||||
# TFRT TPU v2
|
||||
# Built on Oct 30 2023 03:04:42 (1698660263) cl/577737722
|
||||
platform_version = xla_bridge.get_backend().platform_version.split('\n')[-1]
|
||||
results = re.findall(r'\(.*?\)', platform_version)
|
||||
if len(results) != 1:
|
||||
return True
|
||||
build_date = date.fromtimestamp(int(results[0][1:-1]))
|
||||
return build_date >= date
|
||||
|
||||
def pjrt_c_api_version_at_least(major_version: int, minor_version: int):
|
||||
pjrt_c_api_versions = xla_bridge.backend_pjrt_c_api_version()
|
||||
if pjrt_c_api_versions is None:
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import math
|
||||
from absl.testing import absltest
|
||||
@ -1079,6 +1080,8 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Memories do not work on CPU and GPU backends yet.")
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
|
||||
self.skipTest("Memories do not work on Cloud TPU older than 2024/02/23.")
|
||||
super().setUp()
|
||||
|
||||
def test_remat_jaxpr_offloadable(self):
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
"""Test TPU-specific extensions to pallas_call."""
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -48,6 +49,8 @@ class PallasTPUTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
if not self.interpret and jtu.device_under_test() != 'tpu':
|
||||
self.skipTest('Only interpret mode supported on non-TPU')
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 10)):
|
||||
self.skipTest('Does not work on Cloud TPU older than 2024/02/10.')
|
||||
|
||||
def pallas_call(self, *args, **kwargs):
|
||||
return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
|
||||
@ -343,6 +346,8 @@ class PallasCallDynamicGridTest(PallasTPUTest):
|
||||
|
||||
# TODO(apaszke): Add tests for scalar_prefetch too
|
||||
def test_dynamic_grid_scalar_input(self):
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 14)):
|
||||
self.skipTest('Does not work on Cloud TPU older than 2024/02/14.')
|
||||
shape = (8, 128)
|
||||
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
|
||||
|
||||
@ -436,6 +441,9 @@ class PallasCallDynamicGridTest(PallasTPUTest):
|
||||
)
|
||||
|
||||
def test_num_programs(self):
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
|
||||
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
|
||||
|
||||
def kernel(y_ref):
|
||||
y_ref[0, 0] = pl.num_programs(0)
|
||||
|
||||
@ -451,6 +459,9 @@ class PallasCallDynamicGridTest(PallasTPUTest):
|
||||
self.assertEqual(dynamic_kernel(4), 8)
|
||||
|
||||
def test_num_programs_block_spec(self):
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
|
||||
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
|
||||
|
||||
def kernel(x_ref, y_ref):
|
||||
y_ref[...] = x_ref[...]
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import os
|
||||
|
||||
import jax
|
||||
@ -66,6 +67,8 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
if xla_extension_version < 227:
|
||||
self.skipTest('Requires xla_extension_version >= 227')
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
|
||||
self.skipTest("Requires Cloud TPU older than 2024/02/23.")
|
||||
|
||||
def test_basic(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user