From 70f91db853f5aa6fc353063a1ca1d5a36a73c379 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Sep 2024 12:28:32 -0700 Subject: [PATCH] Set PYTHONWARNINGS=error in bazel tests. The goal of this change is to catch PRs that introduce new warnings sooner. To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable. Add code to suppress some new warnings uncovered in CI. PiperOrigin-RevId: 678352286 --- benchmarks/mosaic/BUILD | 4 +- docs/cuda_custom_call/BUILD | 4 +- .../jax2tf/examples/keras_reuse_main.py | 5 +- .../jax2tf/examples/keras_reuse_main_test.py | 1 + jax/experimental/jax2tf/examples/mnist_lib.py | 5 +- jax/experimental/jax2tf/tests/call_tf_test.py | 34 +++ jax/experimental/jax2tf/tests/jax2tf_test.py | 27 ++ .../jax2tf/tests/savedmodel_test.py | 11 + .../jax2tf/tests/shape_poly_test.py | 6 +- .../jax2tf/tests/sharding_test.py | 12 +- jax/experimental/mosaic/gpu/examples/BUILD | 6 +- jaxlib/jax.bzl | 12 +- jaxlib/tools/BUILD.bazel | 4 +- pyproject.toml | 5 + tests/BUILD | 268 +++++++++--------- tests/array_interoperability_test.py | 6 + tests/host_callback_to_tf_test.py | 8 + tests/lax_numpy_test.py | 6 +- tests/lax_test.py | 1 + tests/mosaic/BUILD | 10 +- tests/mosaic/gpu_test.py | 3 + tests/pallas/BUILD | 50 ++-- tests/pmap_test.py | 4 + tests/pytorch_interoperability_test.py | 4 + tests/sparse_bcoo_bcsr_test.py | 1 + 25 files changed, 316 insertions(+), 181 deletions(-) diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 72aae09af..4345e620a 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -42,7 +42,7 @@ DISABLED_CONFIGS = [ "gpu_pjrt_c_api", ] -jax_test( +jax_multiplatform_test( name = "matmul_bench", srcs = ["matmul_bench.py"], disable_backends = DISABLED_BACKENDS, diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD index 0591eed1f..0089b6b9f 100644 --- a/docs/cuda_custom_call/BUILD +++ b/docs/cuda_custom_call/BUILD @@ -16,7 +16,7 @@ load( "//jaxlib:jax.bzl", "cuda_library", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", ) licenses(["notice"]) @@ -28,7 +28,7 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "cuda_custom_call_test", srcs = ["cuda_custom_call_test.py"], data = [":foo"], diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main.py b/jax/experimental/jax2tf/examples/keras_reuse_main.py index 77f882af6..1806e8c45 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main.py @@ -18,13 +18,16 @@ Includes the flags from saved_model_main.py. See README.md. """ import logging +import warnings from absl import app from absl import flags from jax.experimental.jax2tf.examples import mnist_lib from jax.experimental.jax2tf.examples import saved_model_main import tensorflow as tf import tensorflow_datasets as tfds # type: ignore -import tensorflow_hub as hub # type: ignore +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import tensorflow_hub as hub # type: ignore FLAGS = flags.FLAGS diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py index 293484291..e34282a76 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py @@ -41,6 +41,7 @@ class KerasReuseMainTest(tf_test_util.JaxToTfTestCase): @parameterized.named_parameters( dict(testcase_name=f"_{model}", model=model) for model in ["mnist_pure_jax", "mnist_flax"]) + @jtu.ignore_warning(message="the imp module is deprecated") def test_keras_reuse(self, model="mnist_pure_jax"): FLAGS.model = model keras_reuse_main.main(None) diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py index 41173c79a..77432f9eb 100644 --- a/jax/experimental/jax2tf/examples/mnist_lib.py +++ b/jax/experimental/jax2tf/examples/mnist_lib.py @@ -27,6 +27,7 @@ import logging import re import time from typing import Any +import warnings from absl import flags import flax @@ -70,7 +71,9 @@ def load_mnist(split: tfds.Split, batch_size: int): if _MOCK_DATA.value: with tfds.testing.mock_data(num_examples=batch_size): try: - ds = tfds.load("mnist", split=split) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + ds = tfds.load("mnist", split=split) except Exception as e: m = re.search(r'metadata files were not found in (.+/)mnist/', str(e)) if m: diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index e10c3fbfd..492dfad4c 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -88,6 +88,17 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() @_parameterized_jit def test_eval_scalar_arg(self, with_jit=True): @@ -862,6 +873,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase): "Reloading output of jax2tf into JAX with call_tf" + def setUp(self): if tf is None: raise unittest.SkipTest("Test requires tensorflow") @@ -869,6 +881,17 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() def test_simple(self): f_jax = jnp.sin @@ -1157,6 +1180,17 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() def test_alternate(self): # Alternate sin/cos with sin in TF and cos in JAX diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index ef7a5ee2c..6411dc581 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -76,6 +76,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): super().setUpClass() + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_empty(self): f_jax = lambda x, y: x self.ConvertAndCompare(f_jax, 0.7, 1) @@ -1621,6 +1632,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): res = jax2tf.convert(f_jax, native_serialization=True)(*many_args) self.assertAllClose(f_jax(*many_args), res) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def test_nested_convert(self): # Test call sequence: convert -> call_tf -> convert. @@ -1677,6 +1690,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): @jtu.with_config(jax_enable_custom_prng=True) class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_key_argument(self): func = lambda key: jax.random.uniform(key, ()) key = jax.random.PRNGKey(0) @@ -1709,6 +1733,9 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): self.use_max_serialization_version = False super().setUp() + @jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) def test_simple(self): self.ConvertAndCompare(jnp.sin, 0.7) diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index bc19915d1..aee158833 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -30,6 +30,17 @@ jax.config.parse_flags_with_absl() class SavedModelTest(tf_test_util.JaxToTfTestCase): + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_eval(self): f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) model = tf.Module() diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index a9ee17762..07bd9b5ae 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -334,7 +334,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): check_shape_poly(self, f_jax, arg_descriptors=[x], polymorphic_shapes=["b"]) - @jtu.parameterized_filterable( kwargs=[ dict(testcase_name=f"expr={name}", expr=expr) @@ -941,7 +940,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): xi_yf = (xi, yf) zb = np.array([True, False], dtype=np.bool_) def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2] - # results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4] + # results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4] xi, yf = xi_yf # Return a tuple: # (1) float constant, with 0 tangent; @@ -1032,6 +1031,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): f_tf, input_signature=[tf.TensorSpec([None], x.dtype)]) self.assertAllClose(f_jax(x), restored_f(x)) + @jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) def test_readme_examples(self): """Some of the examples from the README.""" diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 9009c1586..247135395 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -61,7 +61,8 @@ def setUpModule(): global topology if jtu.test_device_matches(["tpu"]): - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + with jtu.ignore_warning(message="the imp module is deprecated"): + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') tf.config.experimental_connect_to_cluster(resolver) # Do TPU init at beginning since it will wipe out all HBMs. topology = tf.tpu.experimental.initialize_tpu_system(resolver) @@ -84,6 +85,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): raise unittest.SkipTest("Test requires at least 2 local devices") self.devices = np.array(jax.devices()[:2]) # use 2 devices + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def log_jax_hlo(self, f_jax, args: Sequence[Any], *, num_replicas=1, num_partitions=2): """Log the HLO generated from JAX before and after optimizations""" diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 6f5af51fb..57f78cb2c 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_library", "py_test") -load("//jaxlib:jax.bzl", "py_deps") +load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "jax_py_test", "py_deps") licenses(["notice"]) @@ -48,7 +48,7 @@ py_library( ], ) -py_test( +jax_py_test( name = "run_matmul", srcs = ["matmul.py"], main = "matmul.py", diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index cf9047cc4..65ec572c7 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -19,6 +19,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") +load("@rules_python//python:defs.bzl", "py_test") load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") @@ -222,7 +223,7 @@ def if_building_jaxlib( }) # buildifier: disable=function-docstring -def jax_test( +def jax_multiplatform_test( name, srcs, args = [], @@ -300,3 +301,12 @@ jax_test_file_visibility = [] def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable pass + +def jax_py_test( + name, + env = {}, + **kwargs): + env = dict(env) + if "PYTHONWARNINGS" not in env: + env["PYTHONWARNINGS"] = "error" + py_test(name = name, env = env, **kwargs) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 4642af120..4553dc1e3 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -16,7 +16,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//jaxlib:jax.bzl", "if_windows") +load("//jaxlib:jax.bzl", "if_windows", "jax_py_test") licenses(["notice"]) # Apache 2 @@ -52,7 +52,7 @@ py_binary( ], ) -py_test( +jax_py_test( name = "build_wheel_test", srcs = ["build_wheel_test.py"], data = [":build_wheel"], diff --git a/pyproject.toml b/pyproject.toml index b629762fe..9ce13ea50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,11 @@ filterwarnings = [ "default:Special cases found for .* but none were parsed.*:UserWarning", "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", "default:The .* method is good for exploring strategies.*", + + # NOTE: this is probably not where you want to add code to suppress a + # warning. Only pytest tests look at this list, whereas Bazel tests also + # check for warnings and do not check this list. Most likely, you should + # add a @jtu.ignore_warning decorator to your test instead. ] doctest_optionflags = [ "NUMBER", diff --git a/tests/BUILD b/tests/BUILD index e64889cc3..49dbf0512 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_test") load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", + "jax_py_test", "jax_test_file_visibility", "py_deps", "pytype_test", @@ -31,29 +31,29 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], shard_count = 10, ) -jax_test( +jax_multiplatform_test( name = "device_test", srcs = ["device_test.py"], ) -jax_test( +jax_multiplatform_test( name = "dynamic_api_test", srcs = ["dynamic_api_test.py"], shard_count = 2, ) -jax_test( +jax_multiplatform_test( name = "api_util_test", srcs = ["api_util_test.py"], ) -py_test( +jax_py_test( name = "array_api_test", srcs = ["array_api_test.py"], deps = [ @@ -63,7 +63,7 @@ py_test( ] + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], disable_backends = ["tpu"], @@ -71,7 +71,7 @@ jax_test( deps = py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "batching_test", srcs = ["batching_test.py"], shard_count = { @@ -79,12 +79,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "config_test", srcs = ["config_test.py"], ) -jax_test( +jax_multiplatform_test( name = "core_test", srcs = ["core_test.py"], shard_count = { @@ -93,17 +93,17 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_object_test", srcs = ["custom_object_test.py"], ) -jax_test( +jax_multiplatform_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], ) -py_test( +jax_py_test( name = "multiprocess_gpu_test", srcs = ["multiprocess_gpu_test.py"], args = [ @@ -116,12 +116,12 @@ py_test( ] + py_deps("portpicker"), ) -jax_test( +jax_multiplatform_test( name = "dtypes_test", srcs = ["dtypes_test.py"], ) -jax_test( +jax_multiplatform_test( name = "errors_test", srcs = ["errors_test.py"], # No need to test all other configs. @@ -130,13 +130,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "extend_test", srcs = ["extend_test.py"], deps = ["//jax:extend"], ) -jax_test( +jax_multiplatform_test( name = "fft_test", srcs = ["fft_test.py"], backend_tags = { @@ -152,12 +152,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "generated_fun_test", srcs = ["generated_fun_test.py"], ) -jax_test( +jax_multiplatform_test( name = "gpu_memory_flags_test_no_preallocation", srcs = ["gpu_memory_flags_test.py"], disable_backends = [ @@ -170,7 +170,7 @@ jax_test( main = "gpu_memory_flags_test.py", ) -jax_test( +jax_multiplatform_test( name = "gpu_memory_flags_test", srcs = ["gpu_memory_flags_test.py"], disable_backends = [ @@ -182,7 +182,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lobpcg_test", srcs = ["lobpcg_test.py"], env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, @@ -196,7 +196,7 @@ jax_test( ] + py_deps("matplotlib"), ) -jax_test( +jax_multiplatform_test( name = "svd_test", srcs = ["svd_test.py"], shard_count = { @@ -206,7 +206,7 @@ jax_test( }, ) -py_test( +jax_py_test( name = "xla_interpreter_test", srcs = ["xla_interpreter_test.py"], deps = [ @@ -215,7 +215,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "memories_test", srcs = ["memories_test.py"], shard_count = { @@ -226,7 +226,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pjit_test", srcs = ["pjit_test.py"], backend_tags = { @@ -249,7 +249,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "layout_test", srcs = ["layout_test.py"], backend_tags = { @@ -258,7 +258,7 @@ jax_test( tags = ["multiaccelerator"], ) -jax_test( +jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], deps = [ @@ -266,7 +266,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pgle_test", srcs = ["pgle_test.py"], backend_tags = { @@ -286,7 +286,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "mock_gpu_test", srcs = ["mock_gpu_test.py"], disable_backends = [ @@ -301,7 +301,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "array_test", srcs = ["array_test.py"], backend_tags = { @@ -314,7 +314,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "aot_test", srcs = ["aot_test.py"], tags = ["multiaccelerator"], @@ -323,7 +323,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "image_test", srcs = ["image_test.py"], shard_count = { @@ -335,7 +335,7 @@ jax_test( deps = py_deps("pil") + py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], deps = [ @@ -343,13 +343,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "jax_jit_test", srcs = ["jax_jit_test.py"], main = "jax_jit_test.py", ) -py_test( +jax_py_test( name = "jax_to_ir_test", srcs = ["jax_to_ir_test.py"], deps = [ @@ -359,7 +359,7 @@ py_test( ] + py_deps("tensorflow_core"), ) -py_test( +jax_py_test( name = "jaxpr_util_test", srcs = ["jaxpr_util_test.py"], deps = [ @@ -369,7 +369,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "jet_test", srcs = ["jet_test.py"], shard_count = { @@ -382,7 +382,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "lax_control_flow_test", srcs = ["lax_control_flow_test.py"], shard_count = { @@ -392,7 +392,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], shard_count = { @@ -402,7 +402,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_linear_solve_test", srcs = ["custom_linear_solve_test.py"], shard_count = { @@ -412,7 +412,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { @@ -429,7 +429,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_operators_test", srcs = ["lax_numpy_operators_test.py"], shard_count = { @@ -439,7 +439,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_reducers_test", srcs = ["lax_numpy_reducers_test.py"], shard_count = { @@ -449,7 +449,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_indexing_test", srcs = ["lax_numpy_indexing_test.py"], shard_count = { @@ -459,7 +459,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_einsum_test", srcs = ["lax_numpy_einsum_test.py"], shard_count = { @@ -469,7 +469,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_ufuncs_test", srcs = ["lax_numpy_ufuncs_test.py"], shard_count = { @@ -479,12 +479,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_vectorize_test", srcs = ["lax_numpy_vectorize_test.py"], ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_test", srcs = ["lax_scipy_test.py"], shard_count = { @@ -495,7 +495,7 @@ jax_test( deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_sparse_test", srcs = ["lax_scipy_sparse_test.py"], backend_tags = { @@ -508,7 +508,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_special_functions_test", srcs = ["lax_scipy_special_functions_test.py"], backend_tags = { @@ -522,7 +522,7 @@ jax_test( deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { @@ -535,7 +535,7 @@ jax_test( ] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_test", srcs = ["lax_test.py"], backend_tags = { @@ -552,7 +552,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "lax_metal_test", srcs = ["lax_metal_test.py"], disable_backends = [ @@ -567,7 +567,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "lax_autodiff_test", srcs = ["lax_autodiff_test.py"], shard_count = { @@ -577,7 +577,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_vmap_test", srcs = ["lax_vmap_test.py"], shard_count = { @@ -588,7 +588,7 @@ jax_test( deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_vmap_op_test", srcs = ["lax_vmap_op_test.py"], shard_count = { @@ -599,7 +599,7 @@ jax_test( deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), ) -py_test( +jax_py_test( name = "lazy_loader_test", srcs = [ "lazy_loader_test.py", @@ -610,7 +610,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "deprecation_test", srcs = [ "deprecation_test.py", @@ -621,7 +621,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "linalg_test", srcs = ["linalg_test.py"], backend_tags = { @@ -640,12 +640,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], ) -jax_test( +jax_multiplatform_test( name = "metadata_test", srcs = ["metadata_test.py"], disable_backends = [ @@ -654,7 +654,7 @@ jax_test( ], ) -py_test( +jax_py_test( name = "monitoring_test", srcs = ["monitoring_test.py"], deps = [ @@ -663,12 +663,12 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], ) -jax_test( +jax_multiplatform_test( name = "multi_device_test", srcs = ["multi_device_test.py"], disable_backends = [ @@ -677,7 +677,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "nn_test", srcs = ["nn_test.py"], backend_tags = { @@ -695,13 +695,13 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "optimizers_test", srcs = ["optimizers_test.py"], deps = ["//jax:optimizers"], ) -jax_test( +jax_multiplatform_test( name = "pickle_test", srcs = ["pickle_test.py"], deps = [ @@ -709,7 +709,7 @@ jax_test( ] + py_deps("cloudpickle") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "pmap_test", srcs = ["pmap_test.py"], backend_tags = { @@ -729,7 +729,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "polynomial_test", srcs = ["polynomial_test.py"], # No implementation of nonsymmetric Eigendecomposition. @@ -749,7 +749,7 @@ jax_test( tags = ["nomsan"], ) -jax_test( +jax_multiplatform_test( name = "heap_profiler_test", srcs = ["heap_profiler_test.py"], disable_backends = [ @@ -758,7 +758,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "profiler_test", srcs = ["profiler_test.py"], disable_backends = [ @@ -767,7 +767,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], disable_backends = ["tpu"], @@ -786,7 +786,7 @@ jax_test( deps = py_deps("torch"), ) -jax_test( +jax_multiplatform_test( name = "qdwh_test", srcs = ["qdwh_test.py"], backend_tags = { @@ -799,7 +799,7 @@ jax_test( shard_count = 10, ) -jax_test( +jax_multiplatform_test( name = "random_test", srcs = ["random_test.py"], backend_tags = { @@ -821,7 +821,7 @@ jax_test( tags = ["noasan"], # Times out ) -jax_test( +jax_multiplatform_test( name = "random_lax_test", srcs = ["random_lax_test.py"], backend_tags = { @@ -847,7 +847,7 @@ jax_test( ) # TODO(b/199564969): remove once we always enable_custom_prng -jax_test( +jax_multiplatform_test( name = "random_test_with_custom_prng", srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], @@ -872,7 +872,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "scipy_fft_test", srcs = ["scipy_fft_test.py"], backend_tags = { @@ -885,22 +885,22 @@ jax_test( shard_count = 4, ) -jax_test( +jax_multiplatform_test( name = "scipy_interpolate_test", srcs = ["scipy_interpolate_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_ndimage_test", srcs = ["scipy_ndimage_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_optimize_test", srcs = ["scipy_optimize_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_signal_test", srcs = ["scipy_signal_test.py"], backend_tags = { @@ -925,13 +925,13 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "scipy_spatial_test", srcs = ["scipy_spatial_test.py"], deps = py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "scipy_stats_test", srcs = ["scipy_stats_test.py"], backend_tags = { @@ -948,7 +948,7 @@ jax_test( ], # Times out ) -jax_test( +jax_multiplatform_test( name = "sparse_test", srcs = ["sparse_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -981,7 +981,7 @@ jax_test( ] + py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "sparse_bcoo_bcsr_test", srcs = ["sparse_bcoo_bcsr_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -1014,7 +1014,7 @@ jax_test( ] + py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "sparse_nm_test", srcs = ["sparse_nm_test.py"], config_tags_overrides = { @@ -1037,7 +1037,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "sparsify_test", srcs = ["sparsify_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -1061,12 +1061,12 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "stack_test", srcs = ["stack_test.py"], ) -jax_test( +jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], shard_count = { @@ -1075,7 +1075,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], shard_count = { @@ -1085,18 +1085,18 @@ jax_test( deps = ["//jax:stax"], ) -jax_test( +jax_multiplatform_test( name = "linear_search_test", srcs = ["third_party/scipy/line_search_test.py"], main = "third_party/scipy/line_search_test.py", ) -jax_test( +jax_multiplatform_test( name = "blocked_sampler_test", srcs = ["blocked_sampler_test.py"], ) -py_test( +jax_py_test( name = "tree_util_test", srcs = ["tree_util_test.py"], deps = [ @@ -1114,7 +1114,7 @@ pytype_test( ], ) -py_test( +jax_py_test( name = "util_test", srcs = ["util_test.py"], deps = [ @@ -1123,7 +1123,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "version_test", srcs = ["version_test.py"], deps = [ @@ -1132,7 +1132,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "xla_bridge_test", srcs = ["xla_bridge_test.py"], data = ["testdata/example_pjrt_plugin_config.json"], @@ -1143,7 +1143,7 @@ py_test( ] + py_deps("absl/logging"), ) -py_test( +jax_py_test( name = "lru_cache_test", srcs = ["lru_cache_test.py"], deps = [ @@ -1153,7 +1153,7 @@ py_test( ] + py_deps("filelock"), ) -jax_test( +jax_multiplatform_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.py"], deps = [ @@ -1162,7 +1162,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "cache_key_test", srcs = ["cache_key_test.py"], deps = [ @@ -1171,7 +1171,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "ode_test", srcs = ["ode_test.py"], shard_count = { @@ -1180,7 +1180,7 @@ jax_test( deps = ["//jax:ode"], ) -jax_test( +jax_multiplatform_test( name = "host_callback_outfeed_test", srcs = ["host_callback_test.py"], args = ["--jax_host_callback_outfeed=true"], @@ -1197,7 +1197,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "host_callback_test", srcs = ["host_callback_test.py"], args = ["--jax_host_callback_outfeed=false"], @@ -1213,7 +1213,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "host_callback_to_tf_test", srcs = ["host_callback_to_tf_test.py"], tags = ["noasan"], # Linking TF causes a linker OOM. @@ -1223,12 +1223,12 @@ jax_test( ] + py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], ) -jax_test( +jax_multiplatform_test( name = "x64_context_test", srcs = ["x64_context_test.py"], deps = [ @@ -1236,13 +1236,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "ann_test", srcs = ["ann_test.py"], shard_count = 10, ) -py_test( +jax_py_test( name = "mesh_utils_test", srcs = ["mesh_utils_test.py"], deps = [ @@ -1252,17 +1252,17 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "transfer_guard_test", srcs = ["transfer_guard_test.py"], ) -jax_test( +jax_multiplatform_test( name = "name_stack_test", srcs = ["name_stack_test.py"], ) -jax_test( +jax_multiplatform_test( name = "jaxpr_effects_test", srcs = ["jaxpr_effects_test.py"], backend_tags = { @@ -1275,7 +1275,7 @@ jax_test( tags = ["multiaccelerator"], ) -jax_test( +jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], enable_configs = [ @@ -1284,7 +1284,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "python_callback_test", srcs = ["python_callback_test.py"], backend_tags = { @@ -1296,7 +1296,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], enable_configs = [ @@ -1305,7 +1305,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "state_test", srcs = ["state_test.py"], # Use fewer cases to prevent timeouts. @@ -1327,12 +1327,12 @@ jax_test( deps = py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "mutable_array_test", srcs = ["mutable_array_test.py"], ) -jax_test( +jax_multiplatform_test( name = "for_loop_test", srcs = ["for_loop_test.py"], shard_count = { @@ -1342,7 +1342,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], shard_count = { @@ -1362,12 +1362,12 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "clear_backends_test", srcs = ["clear_backends_test.py"], ) -jax_test( +jax_multiplatform_test( name = "attrs_test", srcs = ["attrs_test.py"], deps = [ @@ -1375,7 +1375,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "experimental_rnn_test", srcs = ["experimental_rnn_test.py"], disable_backends = [ @@ -1391,7 +1391,7 @@ jax_test( ], ) -py_test( +jax_py_test( name = "mosaic_test", srcs = ["mosaic_test.py"], deps = [ @@ -1401,7 +1401,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "source_info_test", srcs = ["source_info_test.py"], deps = [ @@ -1410,7 +1410,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "package_structure_test", srcs = ["package_structure_test.py"], deps = [ @@ -1419,12 +1419,12 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "logging_test", srcs = ["logging_test.py"], ) -jax_test( +jax_multiplatform_test( name = "export_test", srcs = ["export_test.py"], enable_configs = [ @@ -1436,7 +1436,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "shape_poly_test", srcs = ["shape_poly_test.py"], disable_configs = [ @@ -1461,7 +1461,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "export_harnesses_multi_platform_test", srcs = ["export_harnesses_multi_platform_test.py"], disable_configs = [ @@ -1484,7 +1484,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "export_back_compat_test", srcs = ["export_back_compat_test.py"], tags = [], @@ -1494,7 +1494,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], disable_backends = [ @@ -1507,13 +1507,13 @@ jax_test( tags = ["multiaccelerator"], ) -jax_test( +jax_multiplatform_test( name = "xla_metadata_test", srcs = ["xla_metadata_test.py"], deps = ["//jax:experimental"], ) -py_test( +jax_py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], deps = [ @@ -1522,7 +1522,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "sourcemap_test", srcs = ["sourcemap_test.py"], deps = [ @@ -1531,7 +1531,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "cudnn_fusion_test", srcs = ["cudnn_fusion_test.py"], disable_backends = [ diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 356024153..02f5ad527 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -75,6 +75,8 @@ class DLPackTest(jtu.JaxTestCase): use_stream=[False, True], ) @jtu.run_on_devices("gpu") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) @@ -142,6 +144,8 @@ class DLPackTest(jtu.JaxTestCase): dtype=dlpack_dtypes, ) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTensorFlowToJax(self, shape, dtype): if (not config.enable_x64.value and dtype in [jnp.int64, jnp.uint64, jnp.float64]): @@ -184,6 +188,8 @@ class DLPackTest(jtu.JaxTestCase): self.assertAllClose(np, y.numpy()) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTensorFlowToJaxInt64(self): # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py index fe80c90ac..3a36ce129 100644 --- a/tests/host_callback_to_tf_test.py +++ b/tests/host_callback_to_tf_test.py @@ -176,6 +176,8 @@ class CallToTFTest(jtu.JaxTestCase): testcase_name=f"_{ad=}", ad=ad) for ad in CALL_TF_IMPLEMENTATIONS.keys()) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_impl(self, ad="simple"): self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] @@ -197,6 +199,8 @@ class CallToTFTest(jtu.JaxTestCase): ad=ad) for ad in CALL_TF_IMPLEMENTATIONS.keys() if ad != "none") + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_grad(self, ad="simple"): self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] @@ -217,6 +221,8 @@ class CallToTFTest(jtu.JaxTestCase): self.assertAllClose(jax.grad(f_jax)(x), grad_f, check_dtypes=False) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_grad_pytree(self): self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad @@ -246,6 +252,8 @@ class CallToTFTest(jtu.JaxTestCase): testcase_name=f"_degree=_{degree}", degree=degree) for degree in [1, 2, 3, 4]) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_higher_order_grad(self, degree=4): self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 371a13f0c..a10a73697 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2681,6 +2681,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): shape=all_shapes, dtype=default_dtypes, ) + @jtu.ignore_warning(category=RuntimeWarning, message="overflow") def testFrexp(self, shape, dtype, rng_factory): # integer types are converted to float64 in numpy's implementation if (dtype not in [jnp.bfloat16, np.float16, np.float32] @@ -6270,7 +6271,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]: for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin): args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes) try: - with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"): + with jtu.ignore_warning( + category=RuntimeWarning, message="(divide by zero|invalid value)"): _ = func(*args) except TypeError: pass @@ -6292,7 +6294,7 @@ class NumpyUfuncTests(jtu.JaxTestCase): jnp_op = getattr(jnp, name) np_op = getattr(np, name) np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) + message="(divide by zero|invalid value)")(np_op) args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes) with jtu.strict_promotion_if_dtypes_match(arg_dtypes): diff --git a/tests/lax_test.py b/tests/lax_test.py index 3f43773a8..d82b35c6b 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -110,6 +110,7 @@ class LaxTest(jtu.JaxTestCase): for shape_group in lax_test_util.compatible_shapes), dtype=rec.dtypes) for rec in lax_test_util.lax_ops())) + @jtu.ignore_warning(message="invalid value", category=RuntimeWarning) def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol): if (not config.enable_x64.value and op_name == "nextafter" and dtype == np.float64): diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 6e5c94982..9eadc08d4 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -43,7 +43,7 @@ DISABLED_CONFIGS = [ "gpu", ] -jax_test( +jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], config_tags_overrides = { @@ -63,7 +63,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "matmul_test", srcs = ["matmul_test.py"], disable_backends = DISABLED_BACKENDS, @@ -75,7 +75,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "flash_attention", srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], disable_backends = DISABLED_BACKENDS, @@ -87,7 +87,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "flash_attention_test", srcs = ["flash_attention_test.py"], disable_backends = DISABLED_BACKENDS, diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 2eacf7c99..30f830c31 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1200,6 +1200,8 @@ class FragmentedArrayTest(TestCase): m=(64, 128), n=(8, 16, 32, 64, 80, 128, 256), ) + @jtu.ignore_warning(message="(invalid value|divide by zero)", + category=RuntimeWarning) def test_binary(self, op, dtype, m=64, n=32): if isinstance(op, tuple): op, np_op = op @@ -1294,6 +1296,7 @@ class FragmentedArrayTest(TestCase): ], approx=[False, True], ) + @jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning) def test_math(self, ops, approx, m=64, n=32): op, np_op = ops def kernel(ctx, dst, _): diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 6804d9167..e535f1f59 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -28,7 +28,7 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "pallas_test", srcs = [ "pallas_test.py", @@ -62,7 +62,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "pallas_jumble_test", srcs = [ "pallas_jumble_test.py", @@ -85,7 +85,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "ops_test", srcs = [ "ops_test.py", @@ -125,7 +125,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "indexing_test", srcs = [ "indexing_test.py", @@ -144,7 +144,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "pallas_vmap_test", srcs = [ "pallas_vmap_test.py", @@ -176,7 +176,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "mosaic_gpu_test", srcs = [ "mosaic_gpu_test.py", @@ -213,7 +213,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "export_back_compat_pallas_test", srcs = ["export_back_compat_pallas_test.py"], config_tags_overrides = { @@ -244,7 +244,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "export_pallas_test", srcs = ["export_pallas_test.py"], config_tags_overrides = { @@ -272,7 +272,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pallas_shape_poly_test", srcs = ["pallas_shape_poly_test.py"], config_tags_overrides = { @@ -299,7 +299,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pallas_error_handling_test", srcs = [ "pallas_error_handling_test.py", @@ -317,7 +317,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_all_gather_test", srcs = [ "tpu_all_gather_test.py", @@ -331,7 +331,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "tpu_gmm_test", srcs = [ "tpu_gmm_test.py", @@ -356,7 +356,7 @@ jax_test( ]), ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_test", srcs = ["tpu_pallas_test.py"], # The flag is necessary for ``pl.debug_print`` tests to work on TPU. @@ -372,7 +372,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_ops_test", srcs = [ "tpu_ops_test.py", @@ -388,7 +388,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_distributed_test", srcs = ["tpu_pallas_distributed_test.py"], disable_backends = [ @@ -402,7 +402,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_pipeline_test", srcs = ["tpu_pallas_pipeline_test.py"], disable_backends = [ @@ -422,7 +422,7 @@ jax_test( ] + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_async_test", srcs = ["tpu_pallas_async_test.py"], disable_backends = [ @@ -436,7 +436,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_mesh_test", srcs = ["tpu_pallas_mesh_test.py"], disable_backends = [ @@ -454,7 +454,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_random_test", srcs = [ "tpu_pallas_random_test.py", @@ -472,7 +472,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], disable_backends = [ @@ -490,7 +490,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_splash_attention_kernel_test", srcs = [ "tpu_splash_attention_kernel_test.py", @@ -510,7 +510,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "tpu_splash_attention_mask_test", srcs = [ "tpu_splash_attention_mask_test.py", @@ -523,7 +523,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "gpu_attention_test", srcs = [ "gpu_attention_test.py", @@ -556,7 +556,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "gpu_ops_test", srcs = [ "gpu_ops_test.py", diff --git a/tests/pmap_test.py b/tests/pmap_test.py index d7dcc7ba3..9a8d0b912 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3209,8 +3209,12 @@ class EagerPmapMixin: self.jit_disabled = config.disable_jit.value config.update('jax_disable_jit', True) config.update('jax_eager_pmap', True) + self.warning_ctx = jtu.ignore_warning( + message="Some donated buffers were not usable", category=UserWarning) + self.warning_ctx.__enter__() def tearDown(self): + self.warning_ctx.__exit__(None, None, None) config.update('jax_eager_pmap', self.eager_pmap_enabled) config.update('jax_disable_jit', self.jit_disabled) super().tearDown() diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index 2e0fc3223..e41c4329b 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -108,6 +108,8 @@ class DLPackTest(jtu.JaxTestCase): else: self.assertAllClose(np, y.cpu().numpy()) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTorchToJaxInt64(self): # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( @@ -116,6 +118,8 @@ class DLPackTest(jtu.JaxTestCase): self.assertEqual(x.dtype, dtype_expected) @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTorchToJax(self, shape, dtype): if not config.enable_x64.value and dtype in [ jnp.int64, diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index 38fde72f0..12088db7f 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -973,6 +973,7 @@ class BCOOTest(sptu.SparseTestCase): self.assertArraysAllClose(out.todense(), expected_out) self.assertEqual(out.nse, expected_nse) + @jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available") def test_bcoo_spdot_general_ad_bug(self): # Regression test for https://github.com/jax-ml/jax/issues/10163 A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]])