mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic TPU] Guard tests for new features by the libtpu version
PiperOrigin-RevId: 707875450
This commit is contained in:
parent
45159494e5
commit
ad00ec1dc9
9
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
9
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
@ -111,10 +111,17 @@ jobs:
|
||||
JAX_PLATFORMS: tpu,cpu
|
||||
PY_COLORS: 1
|
||||
run: |
|
||||
# We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic TPU does not
|
||||
# guarantee anything about forward compatibility (unless jax.export is used) and the 12
|
||||
# week compatibility window accumulates way too many failures.
|
||||
IGNORE_FLAGS=
|
||||
if [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
|
||||
IGNORE_FLAGS="--ignore=tests/pallas"
|
||||
fi
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
|
||||
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
|
||||
--maxfail=20 -m "not multiaccelerator" tests examples
|
||||
--maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
|
||||
# Run Pallas printing tests, which need to run with I/O capturing disabled.
|
||||
TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
|
||||
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
|
||||
|
@ -379,7 +379,8 @@ def is_cloud_tpu():
|
||||
# built at least `date``.
|
||||
# TODO(b/327203806): after libtpu adds a XLA version and the oldest support
|
||||
# libtpu contains the XLA version, remove using built time to skip tests.
|
||||
def if_cloud_tpu_at_least(date: datetime.date):
|
||||
def if_cloud_tpu_at_least(year: int, month: int, day: int):
|
||||
date = datetime.date(year, month, day)
|
||||
if not is_cloud_tpu():
|
||||
return True
|
||||
# The format of Cloud TPU platform_version is like:
|
||||
|
@ -1142,6 +1142,9 @@ class LaxTest(jtu.JaxTestCase):
|
||||
raise SkipTest(
|
||||
f"The dot algorithm '{algorithm}' is not supported on GPU.")
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
if algorithm not in {
|
||||
lax.DotAlgorithmPreset.DEFAULT,
|
||||
lax.DotAlgorithmPreset.BF16_BF16_F32,
|
||||
|
@ -900,6 +900,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(out2.sharding.memory_kind, 'pinned_host')
|
||||
|
||||
def test_compute_host_loop(self):
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
@compute_on('device_host')
|
||||
@jax.jit
|
||||
def fn():
|
||||
@ -1541,6 +1544,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertArraysEqual(y_out, y1 + y1)
|
||||
|
||||
def test_compute_offload_with_linear_layout(self):
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
|
||||
p_sharding = jax.sharding.SingleDeviceSharding(
|
||||
jax.devices()[0], memory_kind="pinned_host"
|
||||
@ -1636,6 +1642,9 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
self.assertEqual(count(), 4)
|
||||
|
||||
def test_offload_take_host(self):
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
@compute_on('device_host')
|
||||
@jax.jit
|
||||
def peer_forward(x, experts, indices, scores):
|
||||
|
@ -67,6 +67,9 @@ class CompatTest(bctu.CompatTestBase):
|
||||
|
||||
@jax.default_matmul_precision("bfloat16")
|
||||
def test_mosaic_matmul(self):
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 9, 30):
|
||||
self.skipTest("Requires libtpu built after 2024-09-30")
|
||||
dtype = jnp.float32
|
||||
def func():
|
||||
# Build the inputs here, to reduce the size of the golden inputs.
|
||||
|
@ -36,6 +36,9 @@ class ExportTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
|
||||
def test_cross_platform(self):
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
def add_vectors_kernel(x_ref, y_ref, o_ref):
|
||||
x, y = x_ref[...], y_ref[...]
|
||||
o_ref[...] = x + y
|
||||
|
@ -416,6 +416,9 @@ class IndexerOpsTest(PallasBaseTest):
|
||||
case=_INDEXING_TEST_CASES,
|
||||
)
|
||||
def test_can_load_with_ref_at(self, indexer_type, case):
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
if self.INTERPRET:
|
||||
self.skipTest("TODO: fails in interpret mode.")
|
||||
in_shape, indexers, out_shape = case
|
||||
|
@ -856,6 +856,9 @@ class OpsTest(PallasBaseTest):
|
||||
jnp.cbrt, jnp.cosh, jnp.expm1, jnp.sinh,
|
||||
):
|
||||
self.skipTest(f"{fn.__name__} not implemented on TPU")
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built at least on 2024-12-19")
|
||||
|
||||
if (
|
||||
jtu.test_device_matches(["gpu"])
|
||||
@ -1488,6 +1491,9 @@ class OpsTest(PallasBaseTest):
|
||||
trans_y=[False, True],
|
||||
)
|
||||
def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
lhs_shape, rhs_shape = lhs_and_rhs_shape
|
||||
|
||||
final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape
|
||||
@ -2076,6 +2082,9 @@ class OpsTest(PallasBaseTest):
|
||||
):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Not implemented on GPU")
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
|
||||
x = jnp.arange(np.prod(array_shapes), dtype=dtype).reshape(array_shapes)
|
||||
|
||||
|
@ -288,8 +288,12 @@ class OpsTest(PallasBaseTest):
|
||||
reduce_func = [jnp.sum, jnp.max, jnp.min]
|
||||
)
|
||||
def test_reduction(self, dtype, axis, reduce_func):
|
||||
if dtype == jnp.int32 and axis == 2:
|
||||
self.skipTest("Int32 reduction on minor is not supported.")
|
||||
if dtype == jnp.int32:
|
||||
# TODO(apaszke): Remove after 12 weeks have passed.
|
||||
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
||||
self.skipTest("Requires libtpu built after 2024-12-19")
|
||||
if axis == 2:
|
||||
self.skipTest("Int32 reduction on minor is not supported.")
|
||||
# TODO(b/384127570): fix bfloat16 reduction.
|
||||
if dtype == jnp.bfloat16 and reduce_func != jnp.sum:
|
||||
self.skipTest("b/384127570")
|
||||
|
Loading…
x
Reference in New Issue
Block a user