[Mosaic TPU] Guard tests for new features by the libtpu version

PiperOrigin-RevId: 707875450
Adam Paszke 2024-12-19 05:03:34 -08:00 committed by jax authors
parent 45159494e5
commit ad00ec1dc9
9 changed files with 46 additions and 4 deletions

@@ -111,10 +111,17 @@ jobs:
           JAX_PLATFORMS: tpu,cpu
           PY_COLORS: 1
         run: |
+          # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic TPU does not
+          # guarantee anything about forward compatibility (unless jax.export is used) and the 12
+          # week compatibility window accumulates way too many failures.
+          IGNORE_FLAGS=
+          if [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
+            IGNORE_FLAGS="--ignore=tests/pallas"
+          fi
           # Run single-accelerator tests in parallel
           JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
             --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
-            --maxfail=20 -m "not multiaccelerator" tests examples
+            --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
           # Run Pallas printing tests, which need to run with I/O capturing disabled.
           TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
             tests/pallas/tpu_pallas_test.py::PallasCallPrintTest

@@ -379,7 +379,8 @@ def is_cloud_tpu():
 # built at least `date``.
 # TODO(b/327203806): after libtpu adds a XLA version and the oldest support
 # libtpu contains the XLA version, remove using built time to skip tests.
-def if_cloud_tpu_at_least(date: datetime.date):
+def if_cloud_tpu_at_least(year: int, month: int, day: int):
+  date = datetime.date(year, month, day)
   if not is_cloud_tpu():
     return True
   # The format of Cloud TPU platform_version is like:
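
The hunk above ends right as the comment begins to describe the platform_version string that if_cloud_tpu_at_least parses. As an illustration only, here is a minimal Python sketch of how such a check can work. It assumes the backend's platform_version ends with a line that embeds the libtpu build time as a Unix timestamp in parentheses, and the helper name _cloud_tpu_built_at_least is hypothetical; it is not the actual jax._src.test_util implementation.

import datetime
import re

def _cloud_tpu_built_at_least(platform_version: str,
                              year: int, month: int, day: int) -> bool:
  # Sketch only: look for a single "(unix_timestamp)" token in the version
  # string and compare the corresponding build date with the requested date.
  matches = re.findall(r"\((\d+)\)", platform_version)
  if len(matches) != 1:
    # Unknown format: be permissive and let the test run, mirroring how the
    # guarded tests only skip when the build is known to be too old.
    return True
  build_date = datetime.date.fromtimestamp(int(matches[0]))
  return build_date >= datetime.date(year, month, day)

# Hypothetical usage with a made-up platform_version string:
print(_cloud_tpu_built_at_least(
    "PJRT C API\nTFRT TPU v2\nBuilt on Dec 18 2024 (1734514200) cl/0",
    2024, 12, 19))  # -> False: built one day too early

The real helper may differ in how it extracts the build time; the diff itself only shows that the date is now passed as (year, month, day) and converted with datetime.date.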

@@ -1142,6 +1142,9 @@ class LaxTest(jtu.JaxTestCase):
         raise SkipTest(
             f"The dot algorithm '{algorithm}' is not supported on GPU.")
     if jtu.test_device_matches(["tpu"]):
+      # TODO(apaszke): Remove after 12 weeks have passed.
+      if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+        self.skipTest("Requires libtpu built after 2024-12-19")
       if algorithm not in {
           lax.DotAlgorithmPreset.DEFAULT,
           lax.DotAlgorithmPreset.BF16_BF16_F32,

@@ -900,6 +900,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertEqual(out2.sharding.memory_kind, 'pinned_host')

   def test_compute_host_loop(self):
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+      self.skipTest("Requires libtpu built after 2024-12-19")
     @compute_on('device_host')
     @jax.jit
     def fn():
@@ -1541,6 +1544,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertArraysEqual(y_out, y1 + y1)

   def test_compute_offload_with_linear_layout(self):
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+      self.skipTest("Requires libtpu built after 2024-12-19")
     sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
     p_sharding = jax.sharding.SingleDeviceSharding(
         jax.devices()[0], memory_kind="pinned_host"
@@ -1636,6 +1642,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
     self.assertEqual(count(), 4)

   def test_offload_take_host(self):
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+      self.skipTest("Requires libtpu built after 2024-12-19")
     @compute_on('device_host')
     @jax.jit
     def peer_forward(x, experts, indices, scores):

@@ -67,6 +67,9 @@ class CompatTest(bctu.CompatTestBase):

   @jax.default_matmul_precision("bfloat16")
   def test_mosaic_matmul(self):
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 9, 30):
+      self.skipTest("Requires libtpu built after 2024-09-30")
     dtype = jnp.float32
     def func():
       # Build the inputs here, to reduce the size of the golden inputs.

@@ -36,6 +36,9 @@ class ExportTest(jtu.JaxTestCase):
     super().setUp()

   def test_cross_platform(self):
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+      self.skipTest("Requires libtpu built after 2024-12-19")
     def add_vectors_kernel(x_ref, y_ref, o_ref):
       x, y = x_ref[...], y_ref[...]
       o_ref[...] = x + y

@@ -416,6 +416,9 @@ class IndexerOpsTest(PallasBaseTest):
       case=_INDEXING_TEST_CASES,
   )
   def test_can_load_with_ref_at(self, indexer_type, case):
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+      self.skipTest("Requires libtpu built after 2024-12-19")
     if self.INTERPRET:
       self.skipTest("TODO: fails in interpret mode.")
     in_shape, indexers, out_shape = case

@@ -856,6 +856,9 @@ class OpsTest(PallasBaseTest):
         jnp.cbrt, jnp.cosh, jnp.expm1, jnp.sinh,
     ):
       self.skipTest(f"{fn.__name__} not implemented on TPU")
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+      self.skipTest("Requires libtpu built at least on 2024-12-19")
     if (
         jtu.test_device_matches(["gpu"])
@@ -1488,6 +1491,9 @@ class OpsTest(PallasBaseTest):
       trans_y=[False, True],
   )
   def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+      self.skipTest("Requires libtpu built after 2024-12-19")
     lhs_shape, rhs_shape = lhs_and_rhs_shape
     final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape
@@ -2076,6 +2082,9 @@ class OpsTest(PallasBaseTest):
   ):
     if jtu.test_device_matches(["gpu"]):
       self.skipTest("Not implemented on GPU")
+    # TODO(apaszke): Remove after 12 weeks have passed.
+    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+      self.skipTest("Requires libtpu built after 2024-12-19")
     x = jnp.arange(np.prod(array_shapes), dtype=dtype).reshape(array_shapes)

@@ -288,8 +288,12 @@ class OpsTest(PallasBaseTest):
       reduce_func = [jnp.sum, jnp.max, jnp.min]
   )
   def test_reduction(self, dtype, axis, reduce_func):
-    if dtype == jnp.int32 and axis == 2:
-      self.skipTest("Int32 reduction on minor is not supported.")
+    if dtype == jnp.int32:
+      # TODO(apaszke): Remove after 12 weeks have passed.
+      if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
+        self.skipTest("Requires libtpu built after 2024-12-19")
+      if axis == 2:
+        self.skipTest("Int32 reduction on minor is not supported.")
     # TODO(b/384127570): fix bfloat16 reduction.
     if dtype == jnp.bfloat16 and reduce_func != jnp.sum:
       self.skipTest("b/384127570")