diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 94d1eb987..3707efe88 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -50,11 +50,21 @@ echo "Running TPU tests..." # Run single-accelerator tests in parallel JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ ---deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ ---maxfail=20 -m "not multiaccelerator" tests/pallas/tpu_ops_test.py + --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ + --maxfail=20 -m "not multiaccelerator" \ + tests/pallas/ops_test.py \ + tests/pallas/export_back_compat_pallas_test.py \ + tests/pallas/export_pallas_test.py \ + tests/pallas/tpu_ops_test.py \ + tests/pallas/tpu_pallas_test.py \ + tests/pallas/tpu_pallas_random_test.py \ + tests/pallas/tpu_pallas_async_test.py \ + tests/pallas/tpu_pallas_state_test.py # Run Pallas printing tests, which need to run with I/O capturing disabled. TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest # Run multi-accelerator across all chips -"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests/pjit_test.py \ No newline at end of file +"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \ + tests/pjit_test.py \ + tests/pallas/tpu_pallas_distributed_test.py \ No newline at end of file diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index a2f137686..826d14b5b 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
def is_cloud_tpu_older_than(year: int, month: int, day: int) -> bool:
  """Reports whether the Cloud TPU runtime was built before the given date.

  Args:
    year: Year of the cutoff date.
    month: Month of the cutoff date.
    day: Day of the month of the cutoff date.

  Returns:
    ``False`` when not running in a Cloud TPU VM. Otherwise ``True`` if the
    build timestamp found on the last line of the backend's
    ``platform_version`` falls on a UTC day strictly before the cutoff, or if
    no usable timestamp can be extracted (an unrecognised version string is
    conservatively treated as old).

  Raises:
    ValueError: If ``year``, ``month`` and ``day`` do not form a valid date.
  """
  # Built before the early return so an invalid cutoff is reported everywhere,
  # not only on Cloud TPU VMs.
  cutoff = datetime.date(year, month, day)
  if not running_in_cloud_tpu_vm:
    return False
  # We import locally because the functions above must run before the runtime
  # modules are imported.
  from jax._src import xla_bridge  # type: ignore
  # The format of Cloud TPU platform_version is like:
  #   PJRT C API
  #   TFRT TPU v2
  #   Built on Oct 30 2023 03:04:42 (1698660263) cl/577737722
  # Trailing whitespace is stripped first so that a final newline cannot hide
  # the "Built on" line behind an empty last element.
  platform_version = xla_bridge.get_backend().platform_version
  last_line = platform_version.rstrip().split('\n')[-1]
  stamps = re.findall(r'\((.*?)\)', last_line)
  if len(stamps) != 1:
    return True
  try:
    # Interpret the epoch seconds in UTC so the answer does not depend on the
    # host's local timezone.
    built_at = datetime.datetime.fromtimestamp(
        int(stamps[0]), tz=datetime.timezone.utc
    )
  except (ValueError, OverflowError, OSError):
    # Parenthesised text that is not a usable timestamp is handled like any
    # other unrecognised version string.
    return True
  return built_at.date() < cutoff
+3405,15 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, src, _ = _transform_ref(src, src_aval.dtype, src_aval.shape, src_transforms) dst, _ = _transform_ref(dst, dst_aval.dtype, ref_block_shape, transforms) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) - if ctx.forward_compatible: - # TODO(mvoz): Remove once a month has passed. b/395630795 - src_memory_space = _memory_space_to_mosaic_attribute(src_aval.memory_space) - smem_space = ir.Attribute.parse("#tpu.memory_space") - src_is_smem = src_memory_space == smem_space - wait_ref = src if src_is_smem else dst + if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 2, 12): + # TODO(mvoz): Remove once six months have passed. b/395630795 + if hasattr(src_aval, "memory_space"): + src_memory_space = _memory_space_to_mosaic_attribute(src_aval.memory_space) + smem_space = ir.Attribute.parse("#tpu.memory_space") + src_is_smem = src_memory_space == smem_space + wait_ref = src if src_is_smem else dst + else: + wait_ref = dst # Legacy instruction emits only an sfence if the target/dst ref is in smem. # So, we pass the src ref to the wait instruction if it is in smem to # ensure legacy cases are correct, while technically keeping API compat.