Set PYTHONWARNINGS=error in bazel tests.

The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
This commit is contained in:
Peter Hawkins 2024-09-24 12:28:32 -07:00 committed by jax authors
parent d58a09faed
commit 70f91db853
25 changed files with 316 additions and 181 deletions

View File

@ -15,7 +15,7 @@
load( load(
"//jaxlib:jax.bzl", "//jaxlib:jax.bzl",
"jax_generate_backend_suites", "jax_generate_backend_suites",
"jax_test", "jax_multiplatform_test",
"py_deps", "py_deps",
) )
@ -42,7 +42,7 @@ DISABLED_CONFIGS = [
"gpu_pjrt_c_api", "gpu_pjrt_c_api",
] ]
jax_test( jax_multiplatform_test(
name = "matmul_bench", name = "matmul_bench",
srcs = ["matmul_bench.py"], srcs = ["matmul_bench.py"],
disable_backends = DISABLED_BACKENDS, disable_backends = DISABLED_BACKENDS,

View File

@ -16,7 +16,7 @@ load(
"//jaxlib:jax.bzl", "//jaxlib:jax.bzl",
"cuda_library", "cuda_library",
"jax_generate_backend_suites", "jax_generate_backend_suites",
"jax_test", "jax_multiplatform_test",
) )
licenses(["notice"]) licenses(["notice"])
@ -28,7 +28,7 @@ package(
jax_generate_backend_suites() jax_generate_backend_suites()
jax_test( jax_multiplatform_test(
name = "cuda_custom_call_test", name = "cuda_custom_call_test",
srcs = ["cuda_custom_call_test.py"], srcs = ["cuda_custom_call_test.py"],
data = [":foo"], data = [":foo"],

View File

@ -18,13 +18,16 @@ Includes the flags from saved_model_main.py.
See README.md. See README.md.
""" """
import logging import logging
import warnings
from absl import app from absl import app
from absl import flags from absl import flags
from jax.experimental.jax2tf.examples import mnist_lib from jax.experimental.jax2tf.examples import mnist_lib
from jax.experimental.jax2tf.examples import saved_model_main from jax.experimental.jax2tf.examples import saved_model_main
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds # type: ignore import tensorflow_datasets as tfds # type: ignore
import tensorflow_hub as hub # type: ignore with warnings.catch_warnings():
warnings.simplefilter("ignore")
import tensorflow_hub as hub # type: ignore
FLAGS = flags.FLAGS FLAGS = flags.FLAGS

View File

@ -41,6 +41,7 @@ class KerasReuseMainTest(tf_test_util.JaxToTfTestCase):
@parameterized.named_parameters( @parameterized.named_parameters(
dict(testcase_name=f"_{model}", model=model) dict(testcase_name=f"_{model}", model=model)
for model in ["mnist_pure_jax", "mnist_flax"]) for model in ["mnist_pure_jax", "mnist_flax"])
@jtu.ignore_warning(message="the imp module is deprecated")
def test_keras_reuse(self, model="mnist_pure_jax"): def test_keras_reuse(self, model="mnist_pure_jax"):
FLAGS.model = model FLAGS.model = model
keras_reuse_main.main(None) keras_reuse_main.main(None)

View File

@ -27,6 +27,7 @@ import logging
import re import re
import time import time
from typing import Any from typing import Any
import warnings
from absl import flags from absl import flags
import flax import flax
@ -70,7 +71,9 @@ def load_mnist(split: tfds.Split, batch_size: int):
if _MOCK_DATA.value: if _MOCK_DATA.value:
with tfds.testing.mock_data(num_examples=batch_size): with tfds.testing.mock_data(num_examples=batch_size):
try: try:
ds = tfds.load("mnist", split=split) with warnings.catch_warnings():
warnings.simplefilter("ignore")
ds = tfds.load("mnist", split=split)
except Exception as e: except Exception as e:
m = re.search(r'metadata files were not found in (.+/)mnist/', str(e)) m = re.search(r'metadata files were not found in (.+/)mnist/', str(e))
if m: if m:

View File

@ -88,6 +88,17 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
# bug in TensorFlow. # bug in TensorFlow.
_ = tf.add(1, 1) _ = tf.add(1, 1)
super().setUp() super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
@_parameterized_jit @_parameterized_jit
def test_eval_scalar_arg(self, with_jit=True): def test_eval_scalar_arg(self, with_jit=True):
@ -862,6 +873,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase): class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
"Reloading output of jax2tf into JAX with call_tf" "Reloading output of jax2tf into JAX with call_tf"
def setUp(self): def setUp(self):
if tf is None: if tf is None:
raise unittest.SkipTest("Test requires tensorflow") raise unittest.SkipTest("Test requires tensorflow")
@ -869,6 +881,17 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
# bug in TensorFlow. # bug in TensorFlow.
_ = tf.add(1, 1) _ = tf.add(1, 1)
super().setUp() super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_simple(self): def test_simple(self):
f_jax = jnp.sin f_jax = jnp.sin
@ -1157,6 +1180,17 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
# bug in TensorFlow. # bug in TensorFlow.
_ = tf.add(1, 1) _ = tf.add(1, 1)
super().setUp() super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_alternate(self): def test_alternate(self):
# Alternate sin/cos with sin in TF and cos in JAX # Alternate sin/cos with sin in TF and cos in JAX

View File

@ -76,6 +76,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
super().setUpClass() super().setUpClass()
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_empty(self): def test_empty(self):
f_jax = lambda x, y: x f_jax = lambda x, y: x
self.ConvertAndCompare(f_jax, 0.7, 1) self.ConvertAndCompare(f_jax, 0.7, 1)
@ -1621,6 +1632,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
res = jax2tf.convert(f_jax, native_serialization=True)(*many_args) res = jax2tf.convert(f_jax, native_serialization=True)(*many_args)
self.assertAllClose(f_jax(*many_args), res) self.assertAllClose(f_jax(*many_args), res)
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def test_nested_convert(self): def test_nested_convert(self):
# Test call sequence: convert -> call_tf -> convert. # Test call sequence: convert -> call_tf -> convert.
@ -1677,6 +1690,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
@jtu.with_config(jax_enable_custom_prng=True) @jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_key_argument(self): def test_key_argument(self):
func = lambda key: jax.random.uniform(key, ()) func = lambda key: jax.random.uniform(key, ())
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
@ -1709,6 +1733,9 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
self.use_max_serialization_version = False self.use_max_serialization_version = False
super().setUp() super().setUp()
@jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
def test_simple(self): def test_simple(self):
self.ConvertAndCompare(jnp.sin, 0.7) self.ConvertAndCompare(jnp.sin, 0.7)

View File

@ -30,6 +30,17 @@ jax.config.parse_flags_with_absl()
class SavedModelTest(tf_test_util.JaxToTfTestCase): class SavedModelTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_eval(self): def test_eval(self):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
model = tf.Module() model = tf.Module()

View File

@ -334,7 +334,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
check_shape_poly(self, f_jax, arg_descriptors=[x], check_shape_poly(self, f_jax, arg_descriptors=[x],
polymorphic_shapes=["b"]) polymorphic_shapes=["b"])
@jtu.parameterized_filterable( @jtu.parameterized_filterable(
kwargs=[ kwargs=[
dict(testcase_name=f"expr={name}", expr=expr) dict(testcase_name=f"expr={name}", expr=expr)
@ -941,7 +940,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
xi_yf = (xi, yf) xi_yf = (xi, yf)
zb = np.array([True, False], dtype=np.bool_) zb = np.array([True, False], dtype=np.bool_)
def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2] def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2]
# results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4] # results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4]
xi, yf = xi_yf xi, yf = xi_yf
# Return a tuple: # Return a tuple:
# (1) float constant, with 0 tangent; # (1) float constant, with 0 tangent;
@ -1032,6 +1031,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
f_tf, input_signature=[tf.TensorSpec([None], x.dtype)]) f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
self.assertAllClose(f_jax(x), restored_f(x)) self.assertAllClose(f_jax(x), restored_f(x))
@jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
def test_readme_examples(self): def test_readme_examples(self):
"""Some of the examples from the README.""" """Some of the examples from the README."""

View File

@ -61,7 +61,8 @@ def setUpModule():
global topology global topology
if jtu.test_device_matches(["tpu"]): if jtu.test_device_matches(["tpu"]):
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') with jtu.ignore_warning(message="the imp module is deprecated"):
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver) tf.config.experimental_connect_to_cluster(resolver)
# Do TPU init at beginning since it will wipe out all HBMs. # Do TPU init at beginning since it will wipe out all HBMs.
topology = tf.tpu.experimental.initialize_tpu_system(resolver) topology = tf.tpu.experimental.initialize_tpu_system(resolver)
@ -84,6 +85,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
raise unittest.SkipTest("Test requires at least 2 local devices") raise unittest.SkipTest("Test requires at least 2 local devices")
self.devices = np.array(jax.devices()[:2]) # use 2 devices self.devices = np.array(jax.devices()[:2]) # use 2 devices
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def log_jax_hlo(self, f_jax, args: Sequence[Any], *, def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
num_replicas=1, num_partitions=2): num_replicas=1, num_partitions=2):
"""Log the HLO generated from JAX before and after optimizations""" """Log the HLO generated from JAX before and after optimizations"""

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
load("@rules_python//python:defs.bzl", "py_library", "py_test") load("@rules_python//python:defs.bzl", "py_library")
load("//jaxlib:jax.bzl", "py_deps") load("//jaxlib:jax.bzl", "jax_py_test", "py_deps")
licenses(["notice"]) licenses(["notice"])
@ -48,7 +48,7 @@ py_library(
], ],
) )
py_test( jax_py_test(
name = "run_matmul", name = "run_matmul",
srcs = ["matmul.py"], srcs = ["matmul.py"],
main = "matmul.py", main = "matmul.py",

View File

@ -19,6 +19,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library",
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library")
load("@rules_python//python:defs.bzl", "py_test")
load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource")
@ -222,7 +223,7 @@ def if_building_jaxlib(
}) })
# buildifier: disable=function-docstring # buildifier: disable=function-docstring
def jax_test( def jax_multiplatform_test(
name, name,
srcs, srcs,
args = [], args = [],
@ -300,3 +301,12 @@ jax_test_file_visibility = []
def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable
pass pass
def jax_py_test(
name,
env = {},
**kwargs):
env = dict(env)
if "PYTHONWARNINGS" not in env:
env["PYTHONWARNINGS"] = "error"
py_test(name = name, env = env, **kwargs)

View File

@ -16,7 +16,7 @@
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("//jaxlib:jax.bzl", "if_windows") load("//jaxlib:jax.bzl", "if_windows", "jax_py_test")
licenses(["notice"]) # Apache 2 licenses(["notice"]) # Apache 2
@ -52,7 +52,7 @@ py_binary(
], ],
) )
py_test( jax_py_test(
name = "build_wheel_test", name = "build_wheel_test",
srcs = ["build_wheel_test.py"], srcs = ["build_wheel_test.py"],
data = [":build_wheel"], data = [":build_wheel"],

View File

@ -60,6 +60,11 @@ filterwarnings = [
"default:Special cases found for .* but none were parsed.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning",
"default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning",
"default:The .* method is good for exploring strategies.*", "default:The .* method is good for exploring strategies.*",
# NOTE: this is probably not where you want to add code to suppress a
# warning. Only pytest tests look at this list, whereas Bazel tests also
# check for warnings and do not check this list. Most likely, you should
# add a @jtu.ignore_warning decorator to your test instead.
] ]
doctest_optionflags = [ doctest_optionflags = [
"NUMBER", "NUMBER",

File diff suppressed because it is too large Load Diff

View File

@ -75,6 +75,8 @@ class DLPackTest(jtu.JaxTestCase):
use_stream=[False, True], use_stream=[False, True],
) )
@jtu.run_on_devices("gpu") @jtu.run_on_devices("gpu")
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testJaxRoundTrip(self, shape, dtype, copy, use_stream): def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
np = rng(shape, dtype) np = rng(shape, dtype)
@ -142,6 +144,8 @@ class DLPackTest(jtu.JaxTestCase):
dtype=dlpack_dtypes, dtype=dlpack_dtypes,
) )
@unittest.skipIf(not tf, "Test requires TensorFlow") @unittest.skipIf(not tf, "Test requires TensorFlow")
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testTensorFlowToJax(self, shape, dtype): def testTensorFlowToJax(self, shape, dtype):
if (not config.enable_x64.value and if (not config.enable_x64.value and
dtype in [jnp.int64, jnp.uint64, jnp.float64]): dtype in [jnp.int64, jnp.uint64, jnp.float64]):
@ -184,6 +188,8 @@ class DLPackTest(jtu.JaxTestCase):
self.assertAllClose(np, y.numpy()) self.assertAllClose(np, y.numpy())
@unittest.skipIf(not tf, "Test requires TensorFlow") @unittest.skipIf(not tf, "Test requires TensorFlow")
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testTensorFlowToJaxInt64(self): def testTensorFlowToJaxInt64(self):
# See https://github.com/jax-ml/jax/issues/11895 # See https://github.com/jax-ml/jax/issues/11895
x = jax.dlpack.from_dlpack( x = jax.dlpack.from_dlpack(

View File

@ -176,6 +176,8 @@ class CallToTFTest(jtu.JaxTestCase):
testcase_name=f"_{ad=}", testcase_name=f"_{ad=}",
ad=ad) ad=ad)
for ad in CALL_TF_IMPLEMENTATIONS.keys()) for ad in CALL_TF_IMPLEMENTATIONS.keys())
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
category=DeprecationWarning)
def test_impl(self, ad="simple"): def test_impl(self, ad="simple"):
self.supported_only_in_legacy_mode() self.supported_only_in_legacy_mode()
call_tf = CALL_TF_IMPLEMENTATIONS[ad] call_tf = CALL_TF_IMPLEMENTATIONS[ad]
@ -197,6 +199,8 @@ class CallToTFTest(jtu.JaxTestCase):
ad=ad) ad=ad)
for ad in CALL_TF_IMPLEMENTATIONS.keys() for ad in CALL_TF_IMPLEMENTATIONS.keys()
if ad != "none") if ad != "none")
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
category=DeprecationWarning)
def test_grad(self, ad="simple"): def test_grad(self, ad="simple"):
self.supported_only_in_legacy_mode() self.supported_only_in_legacy_mode()
call_tf = CALL_TF_IMPLEMENTATIONS[ad] call_tf = CALL_TF_IMPLEMENTATIONS[ad]
@ -217,6 +221,8 @@ class CallToTFTest(jtu.JaxTestCase):
self.assertAllClose(jax.grad(f_jax)(x), grad_f, self.assertAllClose(jax.grad(f_jax)(x), grad_f,
check_dtypes=False) check_dtypes=False)
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
category=DeprecationWarning)
def test_grad_pytree(self): def test_grad_pytree(self):
self.supported_only_in_legacy_mode() self.supported_only_in_legacy_mode()
call_tf = call_tf_full_ad call_tf = call_tf_full_ad
@ -246,6 +252,8 @@ class CallToTFTest(jtu.JaxTestCase):
testcase_name=f"_degree=_{degree}", testcase_name=f"_degree=_{degree}",
degree=degree) degree=degree)
for degree in [1, 2, 3, 4]) for degree in [1, 2, 3, 4])
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
category=DeprecationWarning)
def test_higher_order_grad(self, degree=4): def test_higher_order_grad(self, degree=4):
self.supported_only_in_legacy_mode() self.supported_only_in_legacy_mode()
call_tf = call_tf_full_ad call_tf = call_tf_full_ad

View File

@ -2681,6 +2681,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
shape=all_shapes, shape=all_shapes,
dtype=default_dtypes, dtype=default_dtypes,
) )
@jtu.ignore_warning(category=RuntimeWarning, message="overflow")
def testFrexp(self, shape, dtype, rng_factory): def testFrexp(self, shape, dtype, rng_factory):
# integer types are converted to float64 in numpy's implementation # integer types are converted to float64 in numpy's implementation
if (dtype not in [jnp.bfloat16, np.float16, np.float32] if (dtype not in [jnp.bfloat16, np.float16, np.float32]
@ -6270,7 +6271,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin): for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin):
args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes) args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes)
try: try:
with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"): with jtu.ignore_warning(
category=RuntimeWarning, message="(divide by zero|invalid value)"):
_ = func(*args) _ = func(*args)
except TypeError: except TypeError:
pass pass
@ -6292,7 +6294,7 @@ class NumpyUfuncTests(jtu.JaxTestCase):
jnp_op = getattr(jnp, name) jnp_op = getattr(jnp, name)
np_op = getattr(np, name) np_op = getattr(np, name)
np_op = jtu.ignore_warning(category=RuntimeWarning, np_op = jtu.ignore_warning(category=RuntimeWarning,
message="divide by zero.*")(np_op) message="(divide by zero|invalid value)")(np_op)
args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes) args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes)
with jtu.strict_promotion_if_dtypes_match(arg_dtypes): with jtu.strict_promotion_if_dtypes_match(arg_dtypes):

View File

@ -110,6 +110,7 @@ class LaxTest(jtu.JaxTestCase):
for shape_group in lax_test_util.compatible_shapes), for shape_group in lax_test_util.compatible_shapes),
dtype=rec.dtypes) dtype=rec.dtypes)
for rec in lax_test_util.lax_ops())) for rec in lax_test_util.lax_ops()))
@jtu.ignore_warning(message="invalid value", category=RuntimeWarning)
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol): def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
if (not config.enable_x64.value and op_name == "nextafter" if (not config.enable_x64.value and op_name == "nextafter"
and dtype == np.float64): and dtype == np.float64):

View File

@ -15,7 +15,7 @@
load( load(
"//jaxlib:jax.bzl", "//jaxlib:jax.bzl",
"jax_generate_backend_suites", "jax_generate_backend_suites",
"jax_test", "jax_multiplatform_test",
"py_deps", "py_deps",
) )
@ -43,7 +43,7 @@ DISABLED_CONFIGS = [
"gpu", "gpu",
] ]
jax_test( jax_multiplatform_test(
name = "gpu_test", name = "gpu_test",
srcs = ["gpu_test.py"], srcs = ["gpu_test.py"],
config_tags_overrides = { config_tags_overrides = {
@ -63,7 +63,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "matmul_test", name = "matmul_test",
srcs = ["matmul_test.py"], srcs = ["matmul_test.py"],
disable_backends = DISABLED_BACKENDS, disable_backends = DISABLED_BACKENDS,
@ -75,7 +75,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
) )
jax_test( jax_multiplatform_test(
name = "flash_attention", name = "flash_attention",
srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
disable_backends = DISABLED_BACKENDS, disable_backends = DISABLED_BACKENDS,
@ -87,7 +87,7 @@ jax_test(
] + py_deps("numpy"), ] + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "flash_attention_test", name = "flash_attention_test",
srcs = ["flash_attention_test.py"], srcs = ["flash_attention_test.py"],
disable_backends = DISABLED_BACKENDS, disable_backends = DISABLED_BACKENDS,

View File

@ -1200,6 +1200,8 @@ class FragmentedArrayTest(TestCase):
m=(64, 128), m=(64, 128),
n=(8, 16, 32, 64, 80, 128, 256), n=(8, 16, 32, 64, 80, 128, 256),
) )
@jtu.ignore_warning(message="(invalid value|divide by zero)",
category=RuntimeWarning)
def test_binary(self, op, dtype, m=64, n=32): def test_binary(self, op, dtype, m=64, n=32):
if isinstance(op, tuple): if isinstance(op, tuple):
op, np_op = op op, np_op = op
@ -1294,6 +1296,7 @@ class FragmentedArrayTest(TestCase):
], ],
approx=[False, True], approx=[False, True],
) )
@jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning)
def test_math(self, ops, approx, m=64, n=32): def test_math(self, ops, approx, m=64, n=32):
op, np_op = ops op, np_op = ops
def kernel(ctx, dst, _): def kernel(ctx, dst, _):

View File

@ -15,7 +15,7 @@
load( load(
"//jaxlib:jax.bzl", "//jaxlib:jax.bzl",
"jax_generate_backend_suites", "jax_generate_backend_suites",
"jax_test", "jax_multiplatform_test",
"py_deps", "py_deps",
) )
@ -28,7 +28,7 @@ package(
jax_generate_backend_suites() jax_generate_backend_suites()
jax_test( jax_multiplatform_test(
name = "pallas_test", name = "pallas_test",
srcs = [ srcs = [
"pallas_test.py", "pallas_test.py",
@ -62,7 +62,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "pallas_jumble_test", name = "pallas_jumble_test",
srcs = [ srcs = [
"pallas_jumble_test.py", "pallas_jumble_test.py",
@ -85,7 +85,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "ops_test", name = "ops_test",
srcs = [ srcs = [
"ops_test.py", "ops_test.py",
@ -125,7 +125,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "indexing_test", name = "indexing_test",
srcs = [ srcs = [
"indexing_test.py", "indexing_test.py",
@ -144,7 +144,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "pallas_vmap_test", name = "pallas_vmap_test",
srcs = [ srcs = [
"pallas_vmap_test.py", "pallas_vmap_test.py",
@ -176,7 +176,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "mosaic_gpu_test", name = "mosaic_gpu_test",
srcs = [ srcs = [
"mosaic_gpu_test.py", "mosaic_gpu_test.py",
@ -213,7 +213,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "export_back_compat_pallas_test", name = "export_back_compat_pallas_test",
srcs = ["export_back_compat_pallas_test.py"], srcs = ["export_back_compat_pallas_test.py"],
config_tags_overrides = { config_tags_overrides = {
@ -244,7 +244,7 @@ jax_test(
], ],
) )
jax_test( jax_multiplatform_test(
name = "export_pallas_test", name = "export_pallas_test",
srcs = ["export_pallas_test.py"], srcs = ["export_pallas_test.py"],
config_tags_overrides = { config_tags_overrides = {
@ -272,7 +272,7 @@ jax_test(
], ],
) )
jax_test( jax_multiplatform_test(
name = "pallas_shape_poly_test", name = "pallas_shape_poly_test",
srcs = ["pallas_shape_poly_test.py"], srcs = ["pallas_shape_poly_test.py"],
config_tags_overrides = { config_tags_overrides = {
@ -299,7 +299,7 @@ jax_test(
], ],
) )
jax_test( jax_multiplatform_test(
name = "pallas_error_handling_test", name = "pallas_error_handling_test",
srcs = [ srcs = [
"pallas_error_handling_test.py", "pallas_error_handling_test.py",
@ -317,7 +317,7 @@ jax_test(
] + py_deps("numpy"), ] + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "tpu_all_gather_test", name = "tpu_all_gather_test",
srcs = [ srcs = [
"tpu_all_gather_test.py", "tpu_all_gather_test.py",
@ -331,7 +331,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
) )
jax_test( jax_multiplatform_test(
name = "tpu_gmm_test", name = "tpu_gmm_test",
srcs = [ srcs = [
"tpu_gmm_test.py", "tpu_gmm_test.py",
@ -356,7 +356,7 @@ jax_test(
]), ]),
) )
jax_test( jax_multiplatform_test(
name = "tpu_pallas_test", name = "tpu_pallas_test",
srcs = ["tpu_pallas_test.py"], srcs = ["tpu_pallas_test.py"],
# The flag is necessary for ``pl.debug_print`` tests to work on TPU. # The flag is necessary for ``pl.debug_print`` tests to work on TPU.
@ -372,7 +372,7 @@ jax_test(
], ],
) )
jax_test( jax_multiplatform_test(
name = "tpu_ops_test", name = "tpu_ops_test",
srcs = [ srcs = [
"tpu_ops_test.py", "tpu_ops_test.py",
@ -388,7 +388,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "tpu_pallas_distributed_test", name = "tpu_pallas_distributed_test",
srcs = ["tpu_pallas_distributed_test.py"], srcs = ["tpu_pallas_distributed_test.py"],
disable_backends = [ disable_backends = [
@ -402,7 +402,7 @@ jax_test(
], ],
) )
jax_test( jax_multiplatform_test(
name = "tpu_pallas_pipeline_test", name = "tpu_pallas_pipeline_test",
srcs = ["tpu_pallas_pipeline_test.py"], srcs = ["tpu_pallas_pipeline_test.py"],
disable_backends = [ disable_backends = [
@ -422,7 +422,7 @@ jax_test(
] + py_deps("hypothesis"), ] + py_deps("hypothesis"),
) )
jax_test( jax_multiplatform_test(
name = "tpu_pallas_async_test", name = "tpu_pallas_async_test",
srcs = ["tpu_pallas_async_test.py"], srcs = ["tpu_pallas_async_test.py"],
disable_backends = [ disable_backends = [
@ -436,7 +436,7 @@ jax_test(
], ],
) )
jax_test( jax_multiplatform_test(
name = "tpu_pallas_mesh_test", name = "tpu_pallas_mesh_test",
srcs = ["tpu_pallas_mesh_test.py"], srcs = ["tpu_pallas_mesh_test.py"],
disable_backends = [ disable_backends = [
@ -454,7 +454,7 @@ jax_test(
], ],
) )
jax_test( jax_multiplatform_test(
name = "tpu_pallas_random_test", name = "tpu_pallas_random_test",
srcs = [ srcs = [
"tpu_pallas_random_test.py", "tpu_pallas_random_test.py",
@ -472,7 +472,7 @@ jax_test(
] + py_deps("numpy"), ] + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "tpu_paged_attention_kernel_test", name = "tpu_paged_attention_kernel_test",
srcs = ["tpu_paged_attention_kernel_test.py"], srcs = ["tpu_paged_attention_kernel_test.py"],
disable_backends = [ disable_backends = [
@ -490,7 +490,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "tpu_splash_attention_kernel_test", name = "tpu_splash_attention_kernel_test",
srcs = [ srcs = [
"tpu_splash_attention_kernel_test.py", "tpu_splash_attention_kernel_test.py",
@ -510,7 +510,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
) )
jax_test( jax_multiplatform_test(
name = "tpu_splash_attention_mask_test", name = "tpu_splash_attention_mask_test",
srcs = [ srcs = [
"tpu_splash_attention_mask_test.py", "tpu_splash_attention_mask_test.py",
@ -523,7 +523,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
) )
jax_test( jax_multiplatform_test(
name = "gpu_attention_test", name = "gpu_attention_test",
srcs = [ srcs = [
"gpu_attention_test.py", "gpu_attention_test.py",
@ -556,7 +556,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"), ] + py_deps("absl/testing") + py_deps("numpy"),
) )
jax_test( jax_multiplatform_test(
name = "gpu_ops_test", name = "gpu_ops_test",
srcs = [ srcs = [
"gpu_ops_test.py", "gpu_ops_test.py",

View File

@ -3209,8 +3209,12 @@ class EagerPmapMixin:
self.jit_disabled = config.disable_jit.value self.jit_disabled = config.disable_jit.value
config.update('jax_disable_jit', True) config.update('jax_disable_jit', True)
config.update('jax_eager_pmap', True) config.update('jax_eager_pmap', True)
self.warning_ctx = jtu.ignore_warning(
message="Some donated buffers were not usable", category=UserWarning)
self.warning_ctx.__enter__()
def tearDown(self): def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
config.update('jax_eager_pmap', self.eager_pmap_enabled) config.update('jax_eager_pmap', self.eager_pmap_enabled)
config.update('jax_disable_jit', self.jit_disabled) config.update('jax_disable_jit', self.jit_disabled)
super().tearDown() super().tearDown()

View File

@ -108,6 +108,8 @@ class DLPackTest(jtu.JaxTestCase):
else: else:
self.assertAllClose(np, y.cpu().numpy()) self.assertAllClose(np, y.cpu().numpy())
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testTorchToJaxInt64(self): def testTorchToJaxInt64(self):
# See https://github.com/jax-ml/jax/issues/11895 # See https://github.com/jax-ml/jax/issues/11895
x = jax.dlpack.from_dlpack( x = jax.dlpack.from_dlpack(
@ -116,6 +118,8 @@ class DLPackTest(jtu.JaxTestCase):
self.assertEqual(x.dtype, dtype_expected) self.assertEqual(x.dtype, dtype_expected)
@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testTorchToJax(self, shape, dtype): def testTorchToJax(self, shape, dtype):
if not config.enable_x64.value and dtype in [ if not config.enable_x64.value and dtype in [
jnp.int64, jnp.int64,

View File

@ -973,6 +973,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertArraysAllClose(out.todense(), expected_out) self.assertArraysAllClose(out.todense(), expected_out)
self.assertEqual(out.nse, expected_nse) self.assertEqual(out.nse, expected_nse)
@jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available")
def test_bcoo_spdot_general_ad_bug(self): def test_bcoo_spdot_general_ad_bug(self):
# Regression test for https://github.com/jax-ml/jax/issues/10163 # Regression test for https://github.com/jax-ml/jax/issues/10163
A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]])