[PJRT C API] Bump the minimum support libtpu version because there is a breaking change (075d25e0c1).

Also remove skip condition that are no longer needed because of this bump.

PiperOrigin-RevId: 611288492
This commit is contained in:
Jieying Luo 2024-02-28 17:50:37 -08:00 committed by jax authors
parent 550ce44afd
commit 4c57d09590
4 changed files with 1 additions and 18 deletions

View File

@ -28,7 +28,7 @@ jobs:
tpu-type: ["v3-8", "v4-8", "v5e-4"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20231030
LIBTPU_OLDEST_VERSION_DATE: 20240228
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
timeout-minutes: 120

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import functools
import math
from absl.testing import absltest
@ -1080,8 +1079,6 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Memories do not work on CPU and GPU backends yet.")
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
self.skipTest("Memories do not work on Cloud TPU older than 2024/02/23.")
super().setUp()
def test_remat_jaxpr_offloadable(self):

View File

@ -14,7 +14,6 @@
"""Test TPU-specific extensions to pallas_call."""
import datetime
import functools
from absl.testing import absltest
from absl.testing import parameterized
@ -49,8 +48,6 @@ class PallasTPUTest(jtu.JaxTestCase):
super().setUp()
if not self.interpret and jtu.device_under_test() != 'tpu':
self.skipTest('Only interpret mode supported on non-TPU')
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 10)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/10.')
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
@ -346,8 +343,6 @@ class PallasCallDynamicGridTest(PallasTPUTest):
# TODO(apaszke): Add tests for scalar_prefetch too
def test_dynamic_grid_scalar_input(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 14)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/14.')
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
@ -441,9 +436,6 @@ class PallasCallDynamicGridTest(PallasTPUTest):
)
def test_num_programs(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
def kernel(y_ref):
y_ref[0, 0] = pl.num_programs(0)
@ -459,9 +451,6 @@ class PallasCallDynamicGridTest(PallasTPUTest):
self.assertEqual(dynamic_kernel(4), 8)
def test_num_programs_block_spec(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
def kernel(x_ref, y_ref):
y_ref[...] = x_ref[...]

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
import jax
@ -67,8 +66,6 @@ class ShardAlikeTest(jtu.JaxTestCase):
super().setUp()
if xla_extension_version < 227:
self.skipTest('Requires xla_extension_version >= 227')
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
self.skipTest("Requires Cloud TPU older than 2024/02/23.")
def test_basic(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))