mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move multi_platform_export_test.py out of jax2tf.
This test is now independent of jax2tf. Move it out and rename it export_harnesses_multi_platform_test.py. We disable the test in GitHub CI, because it is very large, pending some changes to ensure it parallelizes well. The test is still running in internal CI. This is matching the current behavior, since jax2tf tests are only run internally. PiperOrigin-RevId: 583603863
This commit is contained in:
parent
19bc9a2223
commit
3601b25899
1
.github/workflows/ci-build.yaml
vendored
1
.github/workflows/ci-build.yaml
vendored
@ -118,6 +118,7 @@ jobs:
|
||||
echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
|
||||
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
|
||||
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
|
||||
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
|
||||
pytest -n auto --tb=short --maxfail=20 tests examples
|
||||
|
||||
|
||||
|
@ -2673,8 +2673,6 @@ for dtype in (np.float32, np.float64):
|
||||
|
||||
def wrap_and_split():
|
||||
key = jax.random.key(42)
|
||||
if config.enable_custom_prng.value:
|
||||
key = jax.random.wrap_key_data(key)
|
||||
result = jax.random.split(key, 2)
|
||||
return jax.random.key_data(result)
|
||||
|
||||
|
20
tests/BUILD
20
tests/BUILD
@ -1287,6 +1287,26 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "export_harnesses_multi_platform_test",
|
||||
srcs = ["export_harnesses_multi_platform_test.py"],
|
||||
disable_configs = [
|
||||
"gpu_a100", # TODO(b/269593297): matmul precision issues
|
||||
],
|
||||
shard_count = {
|
||||
"cpu": 40,
|
||||
"gpu": 20,
|
||||
"tpu": 20,
|
||||
},
|
||||
tags = [
|
||||
"nodebug", # Times out.
|
||||
],
|
||||
deps = [
|
||||
"//jax:internal_test_harnesses",
|
||||
"//jax/experimental/export",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"api_test.py",
|
||||
|
@ -11,11 +11,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for multi-platform and cross-platform JAX export."""
|
||||
"""Tests for multi-platform and cross-platform JAX export.
|
||||
|
||||
This module contains the tests parameterized by test_harnesses. These tests
|
||||
verify that the primitive lowering rules work properly in multi-platform and
|
||||
cross-platform lowering mode. The actual mechanism for multi-platform and
|
||||
cross-platform lowering is tested in export_test.py.
|
||||
"""
|
||||
|
||||
import math
|
||||
import re
|
||||
from typing import Callable, Sequence
|
||||
from typing import Callable
|
||||
|
||||
from absl import logging
|
||||
from absl.testing import absltest
|
||||
@ -26,7 +32,6 @@ import jax
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.experimental.export import export
|
||||
# TODO(necula): This test does not depend on jax2tf, move it out.
|
||||
from jax._src.internal_test_util import test_harnesses
|
||||
|
||||
|
||||
@ -79,6 +84,12 @@ class PrimitiveTest(jtu.JaxTestCase):
|
||||
include_jax_unimpl=False,
|
||||
#one_containing="",
|
||||
)
|
||||
@jtu.ignore_warning(
|
||||
category=UserWarning,
|
||||
message=("Using reduced precision for gradient of reduce-window min/max "
|
||||
"operator to work around missing XLA support for pair-reductions")
|
||||
)
|
||||
@jtu.skip_on_flag("jax_skip_slow_tests", True)
|
||||
def test_prim(self, harness: test_harnesses.Harness):
|
||||
if (jtu.device_under_test() == "gpu"
|
||||
and _known_failures_gpu.search(harness.fullname)):
|
Loading…
x
Reference in New Issue
Block a user