Make sure we take libTPU version into account in the Pallas lowering

Also, strengthen the presubmit to make sure we catch more errors.

PiperOrigin-RevId: 726061633
This commit is contained in:
Adam Paszke 2025-02-12 08:15:15 -08:00 committed by jax authors
parent e14466a8fb
commit f1ab7514db
3 changed files with 49 additions and 9 deletions

View File

@ -50,11 +50,21 @@ echo "Running TPU tests..."
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests/pallas/tpu_ops_test.py
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" \
tests/pallas/ops_test.py \
tests/pallas/export_back_compat_pallas_test.py \
tests/pallas/export_pallas_test.py \
tests/pallas/tpu_ops_test.py \
tests/pallas/tpu_pallas_test.py \
tests/pallas/tpu_pallas_random_test.py \
tests/pallas/tpu_pallas_async_test.py \
tests/pallas/tpu_pallas_state_test.py
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests/pjit_test.py
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
tests/pjit_test.py \
tests/pallas/tpu_pallas_distributed_test.py

View File

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
import re
from jax import version
from jax._src import config
from jax._src import hardware_utils
@ -98,3 +100,22 @@ def cloud_tpu_init() -> None:
'jax_pjrt_client_create_options',
f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
)
def is_cloud_tpu_older_than(year: int, month: int, day: int):
# We import locally because the functions above must run before the runtime
# modules are imported.
from jax._src import xla_bridge # type: ignore
date = datetime.date(year, month, day)
if not running_in_cloud_tpu_vm:
return False
# The format of Cloud TPU platform_version is like:
# PJRT C API
# TFRT TPU v2
# Built on Oct 30 2023 03:04:42 (1698660263) cl/577737722
platform_version = xla_bridge.get_backend().platform_version.split('\n')[-1]
results = re.findall(r'\(.*?\)', platform_version)
if len(results) != 1:
return True
build_date = date.fromtimestamp(int(results[0][1:-1]))
return build_date < date

View File

@ -39,6 +39,7 @@ from jax._src import prng
from jax._src import source_info_util
from jax._src import state
from jax._src import traceback_util
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
@ -597,6 +598,11 @@ def lower_jaxpr_to_module(
for_verification: bool = False,
dynamic_shape_replacement_enabled: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
# NOTE: We should bump this periodically
if is_cloud_tpu_older_than(2025, 1, 10):
raise RuntimeError(
"Pallas TPU requires a libTPU version that's at most a month old"
)
if dynamic_shape_replacement_enabled:
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
@ -3399,12 +3405,15 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
src, _ = _transform_ref(src, src_aval.dtype, src_aval.shape, src_transforms)
dst, _ = _transform_ref(dst, dst_aval.dtype, ref_block_shape, transforms)
sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
if ctx.forward_compatible:
# TODO(mvoz): Remove once a month has passed. b/395630795
if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 2, 12):
# TODO(mvoz): Remove once six months have passed. b/395630795
if hasattr(src_aval, "memory_space"):
src_memory_space = _memory_space_to_mosaic_attribute(src_aval.memory_space)
smem_space = ir.Attribute.parse("#tpu.memory_space<smem>")
src_is_smem = src_memory_space == smem_space
wait_ref = src if src_is_smem else dst
else:
wait_ref = dst
# Legacy instruction emits only an sfence if the target/dst ref is in smem.
# So, we pass the src ref to the wait instruction if it is in smem to
# ensure legacy cases are correct, while technically keeping API compat.