From 3601b25899ee2241efe8eac6e5318bb366010488 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 18 Nov 2023 02:52:06 -0800 Subject: [PATCH] Move multi_platform_export_test.py out of jax2tf. This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py. We disable the test in GitHub CI, because it is very large, pending some changes to ensure it parallelizes well. The test is still running in internal CI. This is matching the current behavior, since jax2tf tests are only run internally. PiperOrigin-RevId: 583603863 --- .github/workflows/ci-build.yaml | 1 + jax/_src/internal_test_util/test_harnesses.py | 2 -- tests/BUILD | 20 +++++++++++++++++++ .../export_harnesses_multi_platform_test.py | 17 +++++++++++++--- 4 files changed, 35 insertions(+), 5 deletions(-) rename jax/experimental/jax2tf/tests/multi_platform_export_test.py => tests/export_harnesses_multi_platform_test.py (89%) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 3d8b58e4c..4533e09c9 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -118,6 +118,7 @@ jobs: echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG" echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE" echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" + echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" pytest -n auto --tb=short --maxfail=20 tests examples diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 4072f9040..728e80135 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -2673,8 +2673,6 @@ for dtype in (np.float32, np.float64): def wrap_and_split(): key = jax.random.key(42) - if config.enable_custom_prng.value: - key = jax.random.wrap_key_data(key) result = jax.random.split(key, 2) return jax.random.key_data(result) diff --git a/tests/BUILD b/tests/BUILD index 79f91b5b2..2ef0a96c4 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1287,6 +1287,26 @@ jax_test( ], ) +jax_test( + name = "export_harnesses_multi_platform_test", + srcs = ["export_harnesses_multi_platform_test.py"], + disable_configs = [ + "gpu_a100", # TODO(b/269593297): matmul precision issues + ], + shard_count = { + "cpu": 40, + "gpu": 20, + "tpu": 20, + }, + tags = [ + "nodebug", # Times out. + ], + deps = [ + "//jax:internal_test_harnesses", + "//jax/experimental/export", + ], +) + exports_files( [ "api_test.py", diff --git a/jax/experimental/jax2tf/tests/multi_platform_export_test.py b/tests/export_harnesses_multi_platform_test.py similarity index 89% rename from jax/experimental/jax2tf/tests/multi_platform_export_test.py rename to tests/export_harnesses_multi_platform_test.py index 1913773eb..e388c3c35 100644 --- a/jax/experimental/jax2tf/tests/multi_platform_export_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -11,11 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for multi-platform and cross-platform JAX export.""" +"""Tests for multi-platform and cross-platform JAX export. + +This module contains the tests parameterized by test_harnesses. These tests +verify that the primitive lowering rules work properly in multi-platform and +cross-platform lowering mode. The actual mechanism for multi-platform and +cross-platform lowering is tested in export_test.py. +""" import math import re -from typing import Callable, Sequence +from typing import Callable from absl import logging from absl.testing import absltest @@ -26,7 +32,6 @@ import jax from jax import lax from jax._src import test_util as jtu from jax.experimental.export import export -# TODO(necula): This test does not depend on jax2tf, move it out. from jax._src.internal_test_util import test_harnesses @@ -79,6 +84,12 @@ class PrimitiveTest(jtu.JaxTestCase): include_jax_unimpl=False, #one_containing="", ) + @jtu.ignore_warning( + category=UserWarning, + message=("Using reduced precision for gradient of reduce-window min/max " + "operator to work around missing XLA support for pair-reductions") + ) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_prim(self, harness: test_harnesses.Harness): if (jtu.device_under_test() == "gpu" and _known_failures_gpu.search(harness.fullname)):