Move multi_platform_export_test.py out of jax2tf.

This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py.

We disable the test in GitHub CI, because it is very large, pending
some changes to ensure it parallelizes well. The test is still
running in internal CI. This is matching the current behavior, since
jax2tf tests are only run internally.

PiperOrigin-RevId: 583603863
This commit is contained in:
George Necula 2023-11-18 02:52:06 -08:00 committed by jax authors
parent 19bc9a2223
commit 3601b25899
4 changed files with 35 additions and 5 deletions

View File

@ -118,6 +118,7 @@ jobs:
echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short --maxfail=20 tests examples

View File

@ -2673,8 +2673,6 @@ for dtype in (np.float32, np.float64):
def wrap_and_split():
key = jax.random.key(42)
if config.enable_custom_prng.value:
key = jax.random.wrap_key_data(key)
result = jax.random.split(key, 2)
return jax.random.key_data(result)

View File

@ -1287,6 +1287,26 @@ jax_test(
],
)
jax_test(
name = "export_harnesses_multi_platform_test",
srcs = ["export_harnesses_multi_platform_test.py"],
disable_configs = [
"gpu_a100", # TODO(b/269593297): matmul precision issues
],
shard_count = {
"cpu": 40,
"gpu": 20,
"tpu": 20,
},
tags = [
"nodebug", # Times out.
],
deps = [
"//jax:internal_test_harnesses",
"//jax/experimental/export",
],
)
exports_files(
[
"api_test.py",

View File

@ -11,11 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multi-platform and cross-platform JAX export."""
"""Tests for multi-platform and cross-platform JAX export.
This module contains the tests parameterized by test_harnesses. These tests
verify that the primitive lowering rules work properly in multi-platform and
cross-platform lowering mode. The actual mechanism for multi-platform and
cross-platform lowering is tested in export_test.py.
"""
import math
import re
from typing import Callable, Sequence
from typing import Callable
from absl import logging
from absl.testing import absltest
@ -26,7 +32,6 @@ import jax
from jax import lax
from jax._src import test_util as jtu
from jax.experimental.export import export
# TODO(necula): This test does not depend on jax2tf, move it out.
from jax._src.internal_test_util import test_harnesses
@ -79,6 +84,12 @@ class PrimitiveTest(jtu.JaxTestCase):
include_jax_unimpl=False,
#one_containing="",
)
@jtu.ignore_warning(
category=UserWarning,
message=("Using reduced precision for gradient of reduce-window min/max "
"operator to work around missing XLA support for pair-reductions")
)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_prim(self, harness: test_harnesses.Harness):
if (jtu.device_under_test() == "gpu"
and _known_failures_gpu.search(harness.fullname)):