Set PYTHONWARNINGS=error in bazel tests.

The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
This commit is contained in:
Peter Hawkins 2024-09-24 12:28:32 -07:00 committed by jax authors
parent d58a09faed
commit 70f91db853
25 changed files with 316 additions and 181 deletions

View File

@ -15,7 +15,7 @@
load(
"//jaxlib:jax.bzl",
"jax_generate_backend_suites",
"jax_test",
"jax_multiplatform_test",
"py_deps",
)
@ -42,7 +42,7 @@ DISABLED_CONFIGS = [
"gpu_pjrt_c_api",
]
jax_test(
jax_multiplatform_test(
name = "matmul_bench",
srcs = ["matmul_bench.py"],
disable_backends = DISABLED_BACKENDS,

View File

@ -16,7 +16,7 @@ load(
"//jaxlib:jax.bzl",
"cuda_library",
"jax_generate_backend_suites",
"jax_test",
"jax_multiplatform_test",
)
licenses(["notice"])
@ -28,7 +28,7 @@ package(
jax_generate_backend_suites()
jax_test(
jax_multiplatform_test(
name = "cuda_custom_call_test",
srcs = ["cuda_custom_call_test.py"],
data = [":foo"],

View File

@ -18,13 +18,16 @@ Includes the flags from saved_model_main.py.
See README.md.
"""
import logging
import warnings
from absl import app
from absl import flags
from jax.experimental.jax2tf.examples import mnist_lib
from jax.experimental.jax2tf.examples import saved_model_main
import tensorflow as tf
import tensorflow_datasets as tfds # type: ignore
import tensorflow_hub as hub # type: ignore
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import tensorflow_hub as hub # type: ignore
FLAGS = flags.FLAGS

View File

@ -41,6 +41,7 @@ class KerasReuseMainTest(tf_test_util.JaxToTfTestCase):
@parameterized.named_parameters(
dict(testcase_name=f"_{model}", model=model)
for model in ["mnist_pure_jax", "mnist_flax"])
@jtu.ignore_warning(message="the imp module is deprecated")
def test_keras_reuse(self, model="mnist_pure_jax"):
FLAGS.model = model
keras_reuse_main.main(None)

View File

@ -27,6 +27,7 @@ import logging
import re
import time
from typing import Any
import warnings
from absl import flags
import flax
@ -70,7 +71,9 @@ def load_mnist(split: tfds.Split, batch_size: int):
if _MOCK_DATA.value:
with tfds.testing.mock_data(num_examples=batch_size):
try:
ds = tfds.load("mnist", split=split)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ds = tfds.load("mnist", split=split)
except Exception as e:
m = re.search(r'metadata files were not found in (.+/)mnist/', str(e))
if m:

View File

@ -88,6 +88,17 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
# bug in TensorFlow.
_ = tf.add(1, 1)
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
@_parameterized_jit
def test_eval_scalar_arg(self, with_jit=True):
@ -862,6 +873,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
"Reloading output of jax2tf into JAX with call_tf"
def setUp(self):
if tf is None:
raise unittest.SkipTest("Test requires tensorflow")
@ -869,6 +881,17 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
# bug in TensorFlow.
_ = tf.add(1, 1)
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_simple(self):
f_jax = jnp.sin
@ -1157,6 +1180,17 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
# bug in TensorFlow.
_ = tf.add(1, 1)
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message=(
"(jax2tf.convert with native_serialization=False is deprecated"
"|Calling from_dlpack with a DLPack tensor is deprecated)"
)
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_alternate(self):
# Alternate sin/cos with sin in TF and cos in JAX

View File

@ -76,6 +76,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
super().setUpClass()
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_empty(self):
f_jax = lambda x, y: x
self.ConvertAndCompare(f_jax, 0.7, 1)
@ -1621,6 +1632,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
res = jax2tf.convert(f_jax, native_serialization=True)(*many_args)
self.assertAllClose(f_jax(*many_args), res)
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def test_nested_convert(self):
# Test call sequence: convert -> call_tf -> convert.
@ -1677,6 +1690,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
@jtu.with_config(jax_enable_custom_prng=True)
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_key_argument(self):
func = lambda key: jax.random.uniform(key, ())
key = jax.random.PRNGKey(0)
@ -1709,6 +1733,9 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
self.use_max_serialization_version = False
super().setUp()
@jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
def test_simple(self):
self.ConvertAndCompare(jnp.sin, 0.7)

View File

@ -30,6 +30,17 @@ jax.config.parse_flags_with_absl()
class SavedModelTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
super().setUp()
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def test_eval(self):
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
model = tf.Module()

View File

@ -334,7 +334,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
check_shape_poly(self, f_jax, arg_descriptors=[x],
polymorphic_shapes=["b"])
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"expr={name}", expr=expr)
@ -941,7 +940,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
xi_yf = (xi, yf)
zb = np.array([True, False], dtype=np.bool_)
def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2]
# results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4]
# results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4]
xi, yf = xi_yf
# Return a tuple:
# (1) float constant, with 0 tangent;
@ -1032,6 +1031,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
self.assertAllClose(f_jax(x), restored_f(x))
@jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
def test_readme_examples(self):
"""Some of the examples from the README."""

View File

@ -61,7 +61,8 @@ def setUpModule():
global topology
if jtu.test_device_matches(["tpu"]):
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
with jtu.ignore_warning(message="the imp module is deprecated"):
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# Do TPU init at beginning since it will wipe out all HBMs.
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
@ -84,6 +85,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
raise unittest.SkipTest("Test requires at least 2 local devices")
self.devices = np.array(jax.devices()[:2]) # use 2 devices
self.warning_ctx = jtu.ignore_warning(
message="jax2tf.convert with native_serialization=False is deprecated"
)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
super().tearDown()
def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
num_replicas=1, num_partitions=2):
"""Log the HLO generated from JAX before and after optimizations"""

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("//jaxlib:jax.bzl", "py_deps")
load("@rules_python//python:defs.bzl", "py_library")
load("//jaxlib:jax.bzl", "jax_py_test", "py_deps")
licenses(["notice"])
@ -48,7 +48,7 @@ py_library(
],
)
py_test(
jax_py_test(
name = "run_matmul",
srcs = ["matmul.py"],
main = "matmul.py",

View File

@ -19,6 +19,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library",
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library")
load("@rules_python//python:defs.bzl", "py_test")
load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource")
@ -222,7 +223,7 @@ def if_building_jaxlib(
})
# buildifier: disable=function-docstring
def jax_test(
def jax_multiplatform_test(
name,
srcs,
args = [],
@ -300,3 +301,12 @@ jax_test_file_visibility = []
def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable
pass
def jax_py_test(
name,
env = {},
**kwargs):
env = dict(env)
if "PYTHONWARNINGS" not in env:
env["PYTHONWARNINGS"] = "error"
py_test(name = name, env = env, **kwargs)

View File

@ -16,7 +16,7 @@
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("//jaxlib:jax.bzl", "if_windows")
load("//jaxlib:jax.bzl", "if_windows", "jax_py_test")
licenses(["notice"]) # Apache 2
@ -52,7 +52,7 @@ py_binary(
],
)
py_test(
jax_py_test(
name = "build_wheel_test",
srcs = ["build_wheel_test.py"],
data = [":build_wheel"],

View File

@ -60,6 +60,11 @@ filterwarnings = [
"default:Special cases found for .* but none were parsed.*:UserWarning",
"default:.*is not JSON-serializable. Using the repr instead.*:UserWarning",
"default:The .* method is good for exploring strategies.*",
# NOTE: this is probably not where you want to add code to suppress a
# warning. Only pytest tests look at this list, whereas Bazel tests also
# check for warnings and do not check this list. Most likely, you should
# add a @jtu.ignore_warning decorator to your test instead.
]
doctest_optionflags = [
"NUMBER",

File diff suppressed because it is too large Load Diff

View File

@ -75,6 +75,8 @@ class DLPackTest(jtu.JaxTestCase):
use_stream=[False, True],
)
@jtu.run_on_devices("gpu")
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
@ -142,6 +144,8 @@ class DLPackTest(jtu.JaxTestCase):
dtype=dlpack_dtypes,
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testTensorFlowToJax(self, shape, dtype):
if (not config.enable_x64.value and
dtype in [jnp.int64, jnp.uint64, jnp.float64]):
@ -184,6 +188,8 @@ class DLPackTest(jtu.JaxTestCase):
self.assertAllClose(np, y.numpy())
@unittest.skipIf(not tf, "Test requires TensorFlow")
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testTensorFlowToJaxInt64(self):
# See https://github.com/jax-ml/jax/issues/11895
x = jax.dlpack.from_dlpack(

View File

@ -176,6 +176,8 @@ class CallToTFTest(jtu.JaxTestCase):
testcase_name=f"_{ad=}",
ad=ad)
for ad in CALL_TF_IMPLEMENTATIONS.keys())
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
category=DeprecationWarning)
def test_impl(self, ad="simple"):
self.supported_only_in_legacy_mode()
call_tf = CALL_TF_IMPLEMENTATIONS[ad]
@ -197,6 +199,8 @@ class CallToTFTest(jtu.JaxTestCase):
ad=ad)
for ad in CALL_TF_IMPLEMENTATIONS.keys()
if ad != "none")
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
category=DeprecationWarning)
def test_grad(self, ad="simple"):
self.supported_only_in_legacy_mode()
call_tf = CALL_TF_IMPLEMENTATIONS[ad]
@ -217,6 +221,8 @@ class CallToTFTest(jtu.JaxTestCase):
self.assertAllClose(jax.grad(f_jax)(x), grad_f,
check_dtypes=False)
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
category=DeprecationWarning)
def test_grad_pytree(self):
self.supported_only_in_legacy_mode()
call_tf = call_tf_full_ad
@ -246,6 +252,8 @@ class CallToTFTest(jtu.JaxTestCase):
testcase_name=f"_degree=_{degree}",
degree=degree)
for degree in [1, 2, 3, 4])
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
category=DeprecationWarning)
def test_higher_order_grad(self, degree=4):
self.supported_only_in_legacy_mode()
call_tf = call_tf_full_ad

View File

@ -2681,6 +2681,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
shape=all_shapes,
dtype=default_dtypes,
)
@jtu.ignore_warning(category=RuntimeWarning, message="overflow")
def testFrexp(self, shape, dtype, rng_factory):
# integer types are converted to float64 in numpy's implementation
if (dtype not in [jnp.bfloat16, np.float16, np.float32]
@ -6270,7 +6271,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin):
args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes)
try:
with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"):
with jtu.ignore_warning(
category=RuntimeWarning, message="(divide by zero|invalid value)"):
_ = func(*args)
except TypeError:
pass
@ -6292,7 +6294,7 @@ class NumpyUfuncTests(jtu.JaxTestCase):
jnp_op = getattr(jnp, name)
np_op = getattr(np, name)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="divide by zero.*")(np_op)
message="(divide by zero|invalid value)")(np_op)
args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes)
with jtu.strict_promotion_if_dtypes_match(arg_dtypes):

View File

@ -110,6 +110,7 @@ class LaxTest(jtu.JaxTestCase):
for shape_group in lax_test_util.compatible_shapes),
dtype=rec.dtypes)
for rec in lax_test_util.lax_ops()))
@jtu.ignore_warning(message="invalid value", category=RuntimeWarning)
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
if (not config.enable_x64.value and op_name == "nextafter"
and dtype == np.float64):

View File

@ -15,7 +15,7 @@
load(
"//jaxlib:jax.bzl",
"jax_generate_backend_suites",
"jax_test",
"jax_multiplatform_test",
"py_deps",
)
@ -43,7 +43,7 @@ DISABLED_CONFIGS = [
"gpu",
]
jax_test(
jax_multiplatform_test(
name = "gpu_test",
srcs = ["gpu_test.py"],
config_tags_overrides = {
@ -63,7 +63,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "matmul_test",
srcs = ["matmul_test.py"],
disable_backends = DISABLED_BACKENDS,
@ -75,7 +75,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
jax_test(
jax_multiplatform_test(
name = "flash_attention",
srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
disable_backends = DISABLED_BACKENDS,
@ -87,7 +87,7 @@ jax_test(
] + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "flash_attention_test",
srcs = ["flash_attention_test.py"],
disable_backends = DISABLED_BACKENDS,

View File

@ -1200,6 +1200,8 @@ class FragmentedArrayTest(TestCase):
m=(64, 128),
n=(8, 16, 32, 64, 80, 128, 256),
)
@jtu.ignore_warning(message="(invalid value|divide by zero)",
category=RuntimeWarning)
def test_binary(self, op, dtype, m=64, n=32):
if isinstance(op, tuple):
op, np_op = op
@ -1294,6 +1296,7 @@ class FragmentedArrayTest(TestCase):
],
approx=[False, True],
)
@jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning)
def test_math(self, ops, approx, m=64, n=32):
op, np_op = ops
def kernel(ctx, dst, _):

View File

@ -15,7 +15,7 @@
load(
"//jaxlib:jax.bzl",
"jax_generate_backend_suites",
"jax_test",
"jax_multiplatform_test",
"py_deps",
)
@ -28,7 +28,7 @@ package(
jax_generate_backend_suites()
jax_test(
jax_multiplatform_test(
name = "pallas_test",
srcs = [
"pallas_test.py",
@ -62,7 +62,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "pallas_jumble_test",
srcs = [
"pallas_jumble_test.py",
@ -85,7 +85,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "ops_test",
srcs = [
"ops_test.py",
@ -125,7 +125,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "indexing_test",
srcs = [
"indexing_test.py",
@ -144,7 +144,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "pallas_vmap_test",
srcs = [
"pallas_vmap_test.py",
@ -176,7 +176,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "mosaic_gpu_test",
srcs = [
"mosaic_gpu_test.py",
@ -213,7 +213,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "export_back_compat_pallas_test",
srcs = ["export_back_compat_pallas_test.py"],
config_tags_overrides = {
@ -244,7 +244,7 @@ jax_test(
],
)
jax_test(
jax_multiplatform_test(
name = "export_pallas_test",
srcs = ["export_pallas_test.py"],
config_tags_overrides = {
@ -272,7 +272,7 @@ jax_test(
],
)
jax_test(
jax_multiplatform_test(
name = "pallas_shape_poly_test",
srcs = ["pallas_shape_poly_test.py"],
config_tags_overrides = {
@ -299,7 +299,7 @@ jax_test(
],
)
jax_test(
jax_multiplatform_test(
name = "pallas_error_handling_test",
srcs = [
"pallas_error_handling_test.py",
@ -317,7 +317,7 @@ jax_test(
] + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "tpu_all_gather_test",
srcs = [
"tpu_all_gather_test.py",
@ -331,7 +331,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
jax_test(
jax_multiplatform_test(
name = "tpu_gmm_test",
srcs = [
"tpu_gmm_test.py",
@ -356,7 +356,7 @@ jax_test(
]),
)
jax_test(
jax_multiplatform_test(
name = "tpu_pallas_test",
srcs = ["tpu_pallas_test.py"],
# The flag is necessary for ``pl.debug_print`` tests to work on TPU.
@ -372,7 +372,7 @@ jax_test(
],
)
jax_test(
jax_multiplatform_test(
name = "tpu_ops_test",
srcs = [
"tpu_ops_test.py",
@ -388,7 +388,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "tpu_pallas_distributed_test",
srcs = ["tpu_pallas_distributed_test.py"],
disable_backends = [
@ -402,7 +402,7 @@ jax_test(
],
)
jax_test(
jax_multiplatform_test(
name = "tpu_pallas_pipeline_test",
srcs = ["tpu_pallas_pipeline_test.py"],
disable_backends = [
@ -422,7 +422,7 @@ jax_test(
] + py_deps("hypothesis"),
)
jax_test(
jax_multiplatform_test(
name = "tpu_pallas_async_test",
srcs = ["tpu_pallas_async_test.py"],
disable_backends = [
@ -436,7 +436,7 @@ jax_test(
],
)
jax_test(
jax_multiplatform_test(
name = "tpu_pallas_mesh_test",
srcs = ["tpu_pallas_mesh_test.py"],
disable_backends = [
@ -454,7 +454,7 @@ jax_test(
],
)
jax_test(
jax_multiplatform_test(
name = "tpu_pallas_random_test",
srcs = [
"tpu_pallas_random_test.py",
@ -472,7 +472,7 @@ jax_test(
] + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "tpu_paged_attention_kernel_test",
srcs = ["tpu_paged_attention_kernel_test.py"],
disable_backends = [
@ -490,7 +490,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "tpu_splash_attention_kernel_test",
srcs = [
"tpu_splash_attention_kernel_test.py",
@ -510,7 +510,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
jax_test(
jax_multiplatform_test(
name = "tpu_splash_attention_mask_test",
srcs = [
"tpu_splash_attention_mask_test.py",
@ -523,7 +523,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
jax_test(
jax_multiplatform_test(
name = "gpu_attention_test",
srcs = [
"gpu_attention_test.py",
@ -556,7 +556,7 @@ jax_test(
] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_test(
jax_multiplatform_test(
name = "gpu_ops_test",
srcs = [
"gpu_ops_test.py",

View File

@ -3209,8 +3209,12 @@ class EagerPmapMixin:
self.jit_disabled = config.disable_jit.value
config.update('jax_disable_jit', True)
config.update('jax_eager_pmap', True)
self.warning_ctx = jtu.ignore_warning(
message="Some donated buffers were not usable", category=UserWarning)
self.warning_ctx.__enter__()
def tearDown(self):
self.warning_ctx.__exit__(None, None, None)
config.update('jax_eager_pmap', self.eager_pmap_enabled)
config.update('jax_disable_jit', self.jit_disabled)
super().tearDown()

View File

@ -108,6 +108,8 @@ class DLPackTest(jtu.JaxTestCase):
else:
self.assertAllClose(np, y.cpu().numpy())
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testTorchToJaxInt64(self):
# See https://github.com/jax-ml/jax/issues/11895
x = jax.dlpack.from_dlpack(
@ -116,6 +118,8 @@ class DLPackTest(jtu.JaxTestCase):
self.assertEqual(x.dtype, dtype_expected)
@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
category=DeprecationWarning)
def testTorchToJax(self, shape, dtype):
if not config.enable_x64.value and dtype in [
jnp.int64,

View File

@ -973,6 +973,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertArraysAllClose(out.todense(), expected_out)
self.assertEqual(out.nse, expected_nse)
@jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available")
def test_bcoo_spdot_general_ad_bug(self):
# Regression test for https://github.com/jax-ml/jax/issues/10163
A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]])