mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make sure we take libTPU version into account in the Pallas lowering
Also, strengthen the presubmit to make sure we catch more errors. PiperOrigin-RevId: 726061633
This commit is contained in:
parent
e14466a8fb
commit
f1ab7514db
@ -50,11 +50,21 @@ echo "Running TPU tests..."
|
||||
|
||||
# Run single-accelerator tests in parallel
|
||||
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
|
||||
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
|
||||
--maxfail=20 -m "not multiaccelerator" tests/pallas/tpu_ops_test.py
|
||||
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
|
||||
--maxfail=20 -m "not multiaccelerator" \
|
||||
tests/pallas/ops_test.py \
|
||||
tests/pallas/export_back_compat_pallas_test.py \
|
||||
tests/pallas/export_pallas_test.py \
|
||||
tests/pallas/tpu_ops_test.py \
|
||||
tests/pallas/tpu_pallas_test.py \
|
||||
tests/pallas/tpu_pallas_random_test.py \
|
||||
tests/pallas/tpu_pallas_async_test.py \
|
||||
tests/pallas/tpu_pallas_state_test.py
|
||||
|
||||
# Run Pallas printing tests, which need to run with I/O capturing disabled.
|
||||
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
|
||||
|
||||
# Run multi-accelerator across all chips
|
||||
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests/pjit_test.py
|
||||
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" \
|
||||
tests/pjit_test.py \
|
||||
tests/pallas/tpu_pallas_distributed_test.py
|
@ -12,7 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
from jax import version
|
||||
from jax._src import config
|
||||
from jax._src import hardware_utils
|
||||
@ -98,3 +100,22 @@ def cloud_tpu_init() -> None:
|
||||
'jax_pjrt_client_create_options',
|
||||
f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
|
||||
)
|
||||
|
||||
|
||||
def is_cloud_tpu_older_than(year: int, month: int, day: int):
|
||||
# We import locally because the functions above must run before the runtime
|
||||
# modules are imported.
|
||||
from jax._src import xla_bridge # type: ignore
|
||||
date = datetime.date(year, month, day)
|
||||
if not running_in_cloud_tpu_vm:
|
||||
return False
|
||||
# The format of Cloud TPU platform_version is like:
|
||||
# PJRT C API
|
||||
# TFRT TPU v2
|
||||
# Built on Oct 30 2023 03:04:42 (1698660263) cl/577737722
|
||||
platform_version = xla_bridge.get_backend().platform_version.split('\n')[-1]
|
||||
results = re.findall(r'\(.*?\)', platform_version)
|
||||
if len(results) != 1:
|
||||
return True
|
||||
build_date = date.fromtimestamp(int(results[0][1:-1]))
|
||||
return build_date < date
|
||||
|
@ -39,6 +39,7 @@ from jax._src import prng
|
||||
from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src import traceback_util
|
||||
from jax._src.cloud_tpu_init import is_cloud_tpu_older_than
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
@ -597,6 +598,11 @@ def lower_jaxpr_to_module(
|
||||
for_verification: bool = False,
|
||||
dynamic_shape_replacement_enabled: bool = False,
|
||||
) -> tuple[Module, tuple[Any, ...]]:
|
||||
# NOTE: We should bump this periodically
|
||||
if is_cloud_tpu_older_than(2025, 1, 10):
|
||||
raise RuntimeError(
|
||||
"Pallas TPU requires a libTPU version that's at most a month old"
|
||||
)
|
||||
if dynamic_shape_replacement_enabled:
|
||||
_mosaic_lowering_dynamic_shape_env = LoweringDynamicShapeEnv()
|
||||
|
||||
@ -3399,12 +3405,15 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
|
||||
src, _ = _transform_ref(src, src_aval.dtype, src_aval.shape, src_transforms)
|
||||
dst, _ = _transform_ref(dst, dst_aval.dtype, ref_block_shape, transforms)
|
||||
sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
|
||||
if ctx.forward_compatible:
|
||||
# TODO(mvoz): Remove once a month has passed. b/395630795
|
||||
if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 2, 12):
|
||||
# TODO(mvoz): Remove once six months have passed. b/395630795
|
||||
if hasattr(src_aval, "memory_space"):
|
||||
src_memory_space = _memory_space_to_mosaic_attribute(src_aval.memory_space)
|
||||
smem_space = ir.Attribute.parse("#tpu.memory_space<smem>")
|
||||
src_is_smem = src_memory_space == smem_space
|
||||
wait_ref = src if src_is_smem else dst
|
||||
else:
|
||||
wait_ref = dst
|
||||
# Legacy instruction emits only an sfence if the target/dst ref is in smem.
|
||||
# So, we pass the src ref to the wait instruction if it is in smem to
|
||||
# ensure legacy cases are correct, while technically keeping API compat.
|
||||
|
Loading…
x
Reference in New Issue
Block a user