mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00
[PJRT C API] Bump the minimum support libtpu version because there is a breaking change (075d25e0c1
).
Also remove skip condition that are no longer needed because of this bump. PiperOrigin-RevId: 611288492
This commit is contained in:
parent
550ce44afd
commit
4c57d09590
2
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
2
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
@ -28,7 +28,7 @@ jobs:
|
||||
tpu-type: ["v3-8", "v4-8", "v5e-4"]
|
||||
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
|
||||
env:
|
||||
LIBTPU_OLDEST_VERSION_DATE: 20231030
|
||||
LIBTPU_OLDEST_VERSION_DATE: 20240228
|
||||
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
|
||||
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
|
||||
timeout-minutes: 120
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import math
|
||||
from absl.testing import absltest
|
||||
@ -1080,8 +1079,6 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Memories do not work on CPU and GPU backends yet.")
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
|
||||
self.skipTest("Memories do not work on Cloud TPU older than 2024/02/23.")
|
||||
super().setUp()
|
||||
|
||||
def test_remat_jaxpr_offloadable(self):
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
"""Test TPU-specific extensions to pallas_call."""
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -49,8 +48,6 @@ class PallasTPUTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
if not self.interpret and jtu.device_under_test() != 'tpu':
|
||||
self.skipTest('Only interpret mode supported on non-TPU')
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 10)):
|
||||
self.skipTest('Does not work on Cloud TPU older than 2024/02/10.')
|
||||
|
||||
def pallas_call(self, *args, **kwargs):
|
||||
return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
|
||||
@ -346,8 +343,6 @@ class PallasCallDynamicGridTest(PallasTPUTest):
|
||||
|
||||
# TODO(apaszke): Add tests for scalar_prefetch too
|
||||
def test_dynamic_grid_scalar_input(self):
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 14)):
|
||||
self.skipTest('Does not work on Cloud TPU older than 2024/02/14.')
|
||||
shape = (8, 128)
|
||||
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
|
||||
|
||||
@ -441,9 +436,6 @@ class PallasCallDynamicGridTest(PallasTPUTest):
|
||||
)
|
||||
|
||||
def test_num_programs(self):
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
|
||||
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
|
||||
|
||||
def kernel(y_ref):
|
||||
y_ref[0, 0] = pl.num_programs(0)
|
||||
|
||||
@ -459,9 +451,6 @@ class PallasCallDynamicGridTest(PallasTPUTest):
|
||||
self.assertEqual(dynamic_kernel(4), 8)
|
||||
|
||||
def test_num_programs_block_spec(self):
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
|
||||
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')
|
||||
|
||||
def kernel(x_ref, y_ref):
|
||||
y_ref[...] = x_ref[...]
|
||||
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import os
|
||||
|
||||
import jax
|
||||
@ -67,8 +66,6 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
if xla_extension_version < 227:
|
||||
self.skipTest('Requires xla_extension_version >= 227')
|
||||
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
|
||||
self.skipTest("Requires Cloud TPU older than 2024/02/23.")
|
||||
|
||||
def test_basic(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user