mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner. To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable. Add code to suppress some new warnings uncovered in CI. PiperOrigin-RevId: 678352286
This commit is contained in:
parent
d58a09faed
commit
70f91db853
@ -15,7 +15,7 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"jax_generate_backend_suites",
|
||||
"jax_test",
|
||||
"jax_multiplatform_test",
|
||||
"py_deps",
|
||||
)
|
||||
|
||||
@ -42,7 +42,7 @@ DISABLED_CONFIGS = [
|
||||
"gpu_pjrt_c_api",
|
||||
]
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "matmul_bench",
|
||||
srcs = ["matmul_bench.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
|
@ -16,7 +16,7 @@ load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"cuda_library",
|
||||
"jax_generate_backend_suites",
|
||||
"jax_test",
|
||||
"jax_multiplatform_test",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
@ -28,7 +28,7 @@ package(
|
||||
|
||||
jax_generate_backend_suites()
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "cuda_custom_call_test",
|
||||
srcs = ["cuda_custom_call_test.py"],
|
||||
data = [":foo"],
|
||||
|
@ -18,13 +18,16 @@ Includes the flags from saved_model_main.py.
|
||||
See README.md.
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from jax.experimental.jax2tf.examples import mnist_lib
|
||||
from jax.experimental.jax2tf.examples import saved_model_main
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds # type: ignore
|
||||
import tensorflow_hub as hub # type: ignore
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
import tensorflow_hub as hub # type: ignore
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
@ -41,6 +41,7 @@ class KerasReuseMainTest(tf_test_util.JaxToTfTestCase):
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"_{model}", model=model)
|
||||
for model in ["mnist_pure_jax", "mnist_flax"])
|
||||
@jtu.ignore_warning(message="the imp module is deprecated")
|
||||
def test_keras_reuse(self, model="mnist_pure_jax"):
|
||||
FLAGS.model = model
|
||||
keras_reuse_main.main(None)
|
||||
|
@ -27,6 +27,7 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
import warnings
|
||||
from absl import flags
|
||||
|
||||
import flax
|
||||
@ -70,7 +71,9 @@ def load_mnist(split: tfds.Split, batch_size: int):
|
||||
if _MOCK_DATA.value:
|
||||
with tfds.testing.mock_data(num_examples=batch_size):
|
||||
try:
|
||||
ds = tfds.load("mnist", split=split)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
ds = tfds.load("mnist", split=split)
|
||||
except Exception as e:
|
||||
m = re.search(r'metadata files were not found in (.+/)mnist/', str(e))
|
||||
if m:
|
||||
|
@ -88,6 +88,17 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
# bug in TensorFlow.
|
||||
_ = tf.add(1, 1)
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message=(
|
||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||
)
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
super().tearDown()
|
||||
|
||||
@_parameterized_jit
|
||||
def test_eval_scalar_arg(self, with_jit=True):
|
||||
@ -862,6 +873,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
|
||||
"Reloading output of jax2tf into JAX with call_tf"
|
||||
|
||||
def setUp(self):
|
||||
if tf is None:
|
||||
raise unittest.SkipTest("Test requires tensorflow")
|
||||
@ -869,6 +881,17 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
|
||||
# bug in TensorFlow.
|
||||
_ = tf.add(1, 1)
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message=(
|
||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||
)
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
super().tearDown()
|
||||
|
||||
def test_simple(self):
|
||||
f_jax = jnp.sin
|
||||
@ -1157,6 +1180,17 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
||||
# bug in TensorFlow.
|
||||
_ = tf.add(1, 1)
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message=(
|
||||
"(jax2tf.convert with native_serialization=False is deprecated"
|
||||
"|Calling from_dlpack with a DLPack tensor is deprecated)"
|
||||
)
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
super().tearDown()
|
||||
|
||||
def test_alternate(self):
|
||||
# Alternate sin/cos with sin in TF and cos in JAX
|
||||
|
@ -76,6 +76,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
super().setUpClass()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
super().tearDown()
|
||||
|
||||
def test_empty(self):
|
||||
f_jax = lambda x, y: x
|
||||
self.ConvertAndCompare(f_jax, 0.7, 1)
|
||||
@ -1621,6 +1632,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
res = jax2tf.convert(f_jax, native_serialization=True)(*many_args)
|
||||
self.assertAllClose(f_jax(*many_args), res)
|
||||
|
||||
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
|
||||
category=DeprecationWarning)
|
||||
def test_nested_convert(self):
|
||||
# Test call sequence: convert -> call_tf -> convert.
|
||||
|
||||
@ -1677,6 +1690,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@jtu.with_config(jax_enable_custom_prng=True)
|
||||
class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
super().tearDown()
|
||||
|
||||
def test_key_argument(self):
|
||||
func = lambda key: jax.random.uniform(key, ())
|
||||
key = jax.random.PRNGKey(0)
|
||||
@ -1709,6 +1733,9 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
|
||||
self.use_max_serialization_version = False
|
||||
super().setUp()
|
||||
|
||||
@jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
)
|
||||
def test_simple(self):
|
||||
self.ConvertAndCompare(jnp.sin, 0.7)
|
||||
|
||||
|
@ -30,6 +30,17 @@ jax.config.parse_flags_with_absl()
|
||||
|
||||
class SavedModelTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
super().tearDown()
|
||||
|
||||
def test_eval(self):
|
||||
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
|
||||
model = tf.Module()
|
||||
|
@ -334,7 +334,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
check_shape_poly(self, f_jax, arg_descriptors=[x],
|
||||
polymorphic_shapes=["b"])
|
||||
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"expr={name}", expr=expr)
|
||||
@ -941,7 +940,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
xi_yf = (xi, yf)
|
||||
zb = np.array([True, False], dtype=np.bool_)
|
||||
def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2]
|
||||
# results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4]
|
||||
# results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4]
|
||||
xi, yf = xi_yf
|
||||
# Return a tuple:
|
||||
# (1) float constant, with 0 tangent;
|
||||
@ -1032,6 +1031,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
|
||||
self.assertAllClose(f_jax(x), restored_f(x))
|
||||
|
||||
@jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
)
|
||||
def test_readme_examples(self):
|
||||
"""Some of the examples from the README."""
|
||||
|
||||
|
@ -61,7 +61,8 @@ def setUpModule():
|
||||
|
||||
global topology
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||
with jtu.ignore_warning(message="the imp module is deprecated"):
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
# Do TPU init at beginning since it will wipe out all HBMs.
|
||||
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
@ -84,6 +85,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
raise unittest.SkipTest("Test requires at least 2 local devices")
|
||||
self.devices = np.array(jax.devices()[:2]) # use 2 devices
|
||||
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="jax2tf.convert with native_serialization=False is deprecated"
|
||||
)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
super().tearDown()
|
||||
|
||||
def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
|
||||
num_replicas=1, num_partitions=2):
|
||||
"""Log the HLO generated from JAX before and after optimizations"""
|
||||
|
@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library", "py_test")
|
||||
load("//jaxlib:jax.bzl", "py_deps")
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("//jaxlib:jax.bzl", "jax_py_test", "py_deps")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
@ -48,7 +48,7 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
jax_py_test(
|
||||
name = "run_matmul",
|
||||
srcs = ["matmul.py"],
|
||||
main = "matmul.py",
|
||||
|
@ -19,6 +19,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library",
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
|
||||
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
|
||||
load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library")
|
||||
load("@rules_python//python:defs.bzl", "py_test")
|
||||
load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
|
||||
load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource")
|
||||
|
||||
@ -222,7 +223,7 @@ def if_building_jaxlib(
|
||||
})
|
||||
|
||||
# buildifier: disable=function-docstring
|
||||
def jax_test(
|
||||
def jax_multiplatform_test(
|
||||
name,
|
||||
srcs,
|
||||
args = [],
|
||||
@ -300,3 +301,12 @@ jax_test_file_visibility = []
|
||||
|
||||
def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable
|
||||
pass
|
||||
|
||||
def jax_py_test(
|
||||
name,
|
||||
env = {},
|
||||
**kwargs):
|
||||
env = dict(env)
|
||||
if "PYTHONWARNINGS" not in env:
|
||||
env["PYTHONWARNINGS"] = "error"
|
||||
py_test(name = name, env = env, **kwargs)
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
||||
load("//jaxlib:jax.bzl", "if_windows")
|
||||
load("//jaxlib:jax.bzl", "if_windows", "jax_py_test")
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
|
||||
@ -52,7 +52,7 @@ py_binary(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
jax_py_test(
|
||||
name = "build_wheel_test",
|
||||
srcs = ["build_wheel_test.py"],
|
||||
data = [":build_wheel"],
|
||||
|
@ -60,6 +60,11 @@ filterwarnings = [
|
||||
"default:Special cases found for .* but none were parsed.*:UserWarning",
|
||||
"default:.*is not JSON-serializable. Using the repr instead.*:UserWarning",
|
||||
"default:The .* method is good for exploring strategies.*",
|
||||
|
||||
# NOTE: this is probably not where you want to add code to suppress a
|
||||
# warning. Only pytest tests look at this list, whereas Bazel tests also
|
||||
# check for warnings and do not check this list. Most likely, you should
|
||||
# add a @jtu.ignore_warning decorator to your test instead.
|
||||
]
|
||||
doctest_optionflags = [
|
||||
"NUMBER",
|
||||
|
268
tests/BUILD
268
tests/BUILD
File diff suppressed because it is too large
Load Diff
@ -75,6 +75,8 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
use_stream=[False, True],
|
||||
)
|
||||
@jtu.run_on_devices("gpu")
|
||||
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
|
||||
category=DeprecationWarning)
|
||||
def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np = rng(shape, dtype)
|
||||
@ -142,6 +144,8 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
dtype=dlpack_dtypes,
|
||||
)
|
||||
@unittest.skipIf(not tf, "Test requires TensorFlow")
|
||||
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
|
||||
category=DeprecationWarning)
|
||||
def testTensorFlowToJax(self, shape, dtype):
|
||||
if (not config.enable_x64.value and
|
||||
dtype in [jnp.int64, jnp.uint64, jnp.float64]):
|
||||
@ -184,6 +188,8 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(np, y.numpy())
|
||||
|
||||
@unittest.skipIf(not tf, "Test requires TensorFlow")
|
||||
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
|
||||
category=DeprecationWarning)
|
||||
def testTensorFlowToJaxInt64(self):
|
||||
# See https://github.com/jax-ml/jax/issues/11895
|
||||
x = jax.dlpack.from_dlpack(
|
||||
|
@ -176,6 +176,8 @@ class CallToTFTest(jtu.JaxTestCase):
|
||||
testcase_name=f"_{ad=}",
|
||||
ad=ad)
|
||||
for ad in CALL_TF_IMPLEMENTATIONS.keys())
|
||||
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
|
||||
category=DeprecationWarning)
|
||||
def test_impl(self, ad="simple"):
|
||||
self.supported_only_in_legacy_mode()
|
||||
call_tf = CALL_TF_IMPLEMENTATIONS[ad]
|
||||
@ -197,6 +199,8 @@ class CallToTFTest(jtu.JaxTestCase):
|
||||
ad=ad)
|
||||
for ad in CALL_TF_IMPLEMENTATIONS.keys()
|
||||
if ad != "none")
|
||||
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
|
||||
category=DeprecationWarning)
|
||||
def test_grad(self, ad="simple"):
|
||||
self.supported_only_in_legacy_mode()
|
||||
call_tf = CALL_TF_IMPLEMENTATIONS[ad]
|
||||
@ -217,6 +221,8 @@ class CallToTFTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(jax.grad(f_jax)(x), grad_f,
|
||||
check_dtypes=False)
|
||||
|
||||
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
|
||||
category=DeprecationWarning)
|
||||
def test_grad_pytree(self):
|
||||
self.supported_only_in_legacy_mode()
|
||||
call_tf = call_tf_full_ad
|
||||
@ -246,6 +252,8 @@ class CallToTFTest(jtu.JaxTestCase):
|
||||
testcase_name=f"_degree=_{degree}",
|
||||
degree=degree)
|
||||
for degree in [1, 2, 3, 4])
|
||||
@jtu.ignore_warning(message="The host_callback APIs are deprecated",
|
||||
category=DeprecationWarning)
|
||||
def test_higher_order_grad(self, degree=4):
|
||||
self.supported_only_in_legacy_mode()
|
||||
call_tf = call_tf_full_ad
|
||||
|
@ -2681,6 +2681,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
shape=all_shapes,
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
@jtu.ignore_warning(category=RuntimeWarning, message="overflow")
|
||||
def testFrexp(self, shape, dtype, rng_factory):
|
||||
# integer types are converted to float64 in numpy's implementation
|
||||
if (dtype not in [jnp.bfloat16, np.float16, np.float32]
|
||||
@ -6270,7 +6271,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
|
||||
for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin):
|
||||
args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes)
|
||||
try:
|
||||
with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"):
|
||||
with jtu.ignore_warning(
|
||||
category=RuntimeWarning, message="(divide by zero|invalid value)"):
|
||||
_ = func(*args)
|
||||
except TypeError:
|
||||
pass
|
||||
@ -6292,7 +6294,7 @@ class NumpyUfuncTests(jtu.JaxTestCase):
|
||||
jnp_op = getattr(jnp, name)
|
||||
np_op = getattr(np, name)
|
||||
np_op = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="divide by zero.*")(np_op)
|
||||
message="(divide by zero|invalid value)")(np_op)
|
||||
args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes)
|
||||
|
||||
with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
|
||||
|
@ -110,6 +110,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
for shape_group in lax_test_util.compatible_shapes),
|
||||
dtype=rec.dtypes)
|
||||
for rec in lax_test_util.lax_ops()))
|
||||
@jtu.ignore_warning(message="invalid value", category=RuntimeWarning)
|
||||
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
|
||||
if (not config.enable_x64.value and op_name == "nextafter"
|
||||
and dtype == np.float64):
|
||||
|
@ -15,7 +15,7 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"jax_generate_backend_suites",
|
||||
"jax_test",
|
||||
"jax_multiplatform_test",
|
||||
"py_deps",
|
||||
)
|
||||
|
||||
@ -43,7 +43,7 @@ DISABLED_CONFIGS = [
|
||||
"gpu",
|
||||
]
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "gpu_test",
|
||||
srcs = ["gpu_test.py"],
|
||||
config_tags_overrides = {
|
||||
@ -63,7 +63,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "matmul_test",
|
||||
srcs = ["matmul_test.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
@ -75,7 +75,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "flash_attention",
|
||||
srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
@ -87,7 +87,7 @@ jax_test(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "flash_attention_test",
|
||||
srcs = ["flash_attention_test.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
|
@ -1200,6 +1200,8 @@ class FragmentedArrayTest(TestCase):
|
||||
m=(64, 128),
|
||||
n=(8, 16, 32, 64, 80, 128, 256),
|
||||
)
|
||||
@jtu.ignore_warning(message="(invalid value|divide by zero)",
|
||||
category=RuntimeWarning)
|
||||
def test_binary(self, op, dtype, m=64, n=32):
|
||||
if isinstance(op, tuple):
|
||||
op, np_op = op
|
||||
@ -1294,6 +1296,7 @@ class FragmentedArrayTest(TestCase):
|
||||
],
|
||||
approx=[False, True],
|
||||
)
|
||||
@jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning)
|
||||
def test_math(self, ops, approx, m=64, n=32):
|
||||
op, np_op = ops
|
||||
def kernel(ctx, dst, _):
|
||||
|
@ -15,7 +15,7 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"jax_generate_backend_suites",
|
||||
"jax_test",
|
||||
"jax_multiplatform_test",
|
||||
"py_deps",
|
||||
)
|
||||
|
||||
@ -28,7 +28,7 @@ package(
|
||||
|
||||
jax_generate_backend_suites()
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "pallas_test",
|
||||
srcs = [
|
||||
"pallas_test.py",
|
||||
@ -62,7 +62,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "pallas_jumble_test",
|
||||
srcs = [
|
||||
"pallas_jumble_test.py",
|
||||
@ -85,7 +85,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "ops_test",
|
||||
srcs = [
|
||||
"ops_test.py",
|
||||
@ -125,7 +125,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "indexing_test",
|
||||
srcs = [
|
||||
"indexing_test.py",
|
||||
@ -144,7 +144,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "pallas_vmap_test",
|
||||
srcs = [
|
||||
"pallas_vmap_test.py",
|
||||
@ -176,7 +176,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "mosaic_gpu_test",
|
||||
srcs = [
|
||||
"mosaic_gpu_test.py",
|
||||
@ -213,7 +213,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "export_back_compat_pallas_test",
|
||||
srcs = ["export_back_compat_pallas_test.py"],
|
||||
config_tags_overrides = {
|
||||
@ -244,7 +244,7 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "export_pallas_test",
|
||||
srcs = ["export_pallas_test.py"],
|
||||
config_tags_overrides = {
|
||||
@ -272,7 +272,7 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "pallas_shape_poly_test",
|
||||
srcs = ["pallas_shape_poly_test.py"],
|
||||
config_tags_overrides = {
|
||||
@ -299,7 +299,7 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "pallas_error_handling_test",
|
||||
srcs = [
|
||||
"pallas_error_handling_test.py",
|
||||
@ -317,7 +317,7 @@ jax_test(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_all_gather_test",
|
||||
srcs = [
|
||||
"tpu_all_gather_test.py",
|
||||
@ -331,7 +331,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_gmm_test",
|
||||
srcs = [
|
||||
"tpu_gmm_test.py",
|
||||
@ -356,7 +356,7 @@ jax_test(
|
||||
]),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_test",
|
||||
srcs = ["tpu_pallas_test.py"],
|
||||
# The flag is necessary for ``pl.debug_print`` tests to work on TPU.
|
||||
@ -372,7 +372,7 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_ops_test",
|
||||
srcs = [
|
||||
"tpu_ops_test.py",
|
||||
@ -388,7 +388,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_distributed_test",
|
||||
srcs = ["tpu_pallas_distributed_test.py"],
|
||||
disable_backends = [
|
||||
@ -402,7 +402,7 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_pipeline_test",
|
||||
srcs = ["tpu_pallas_pipeline_test.py"],
|
||||
disable_backends = [
|
||||
@ -422,7 +422,7 @@ jax_test(
|
||||
] + py_deps("hypothesis"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_async_test",
|
||||
srcs = ["tpu_pallas_async_test.py"],
|
||||
disable_backends = [
|
||||
@ -436,7 +436,7 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_mesh_test",
|
||||
srcs = ["tpu_pallas_mesh_test.py"],
|
||||
disable_backends = [
|
||||
@ -454,7 +454,7 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_random_test",
|
||||
srcs = [
|
||||
"tpu_pallas_random_test.py",
|
||||
@ -472,7 +472,7 @@ jax_test(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_paged_attention_kernel_test",
|
||||
srcs = ["tpu_paged_attention_kernel_test.py"],
|
||||
disable_backends = [
|
||||
@ -490,7 +490,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_splash_attention_kernel_test",
|
||||
srcs = [
|
||||
"tpu_splash_attention_kernel_test.py",
|
||||
@ -510,7 +510,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_splash_attention_mask_test",
|
||||
srcs = [
|
||||
"tpu_splash_attention_mask_test.py",
|
||||
@ -523,7 +523,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "gpu_attention_test",
|
||||
srcs = [
|
||||
"gpu_attention_test.py",
|
||||
@ -556,7 +556,7 @@ jax_test(
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
jax_multiplatform_test(
|
||||
name = "gpu_ops_test",
|
||||
srcs = [
|
||||
"gpu_ops_test.py",
|
||||
|
@ -3209,8 +3209,12 @@ class EagerPmapMixin:
|
||||
self.jit_disabled = config.disable_jit.value
|
||||
config.update('jax_disable_jit', True)
|
||||
config.update('jax_eager_pmap', True)
|
||||
self.warning_ctx = jtu.ignore_warning(
|
||||
message="Some donated buffers were not usable", category=UserWarning)
|
||||
self.warning_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_ctx.__exit__(None, None, None)
|
||||
config.update('jax_eager_pmap', self.eager_pmap_enabled)
|
||||
config.update('jax_disable_jit', self.jit_disabled)
|
||||
super().tearDown()
|
||||
|
@ -108,6 +108,8 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
else:
|
||||
self.assertAllClose(np, y.cpu().numpy())
|
||||
|
||||
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
|
||||
category=DeprecationWarning)
|
||||
def testTorchToJaxInt64(self):
|
||||
# See https://github.com/jax-ml/jax/issues/11895
|
||||
x = jax.dlpack.from_dlpack(
|
||||
@ -116,6 +118,8 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
self.assertEqual(x.dtype, dtype_expected)
|
||||
|
||||
@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
|
||||
@jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
|
||||
category=DeprecationWarning)
|
||||
def testTorchToJax(self, shape, dtype):
|
||||
if not config.enable_x64.value and dtype in [
|
||||
jnp.int64,
|
||||
|
@ -973,6 +973,7 @@ class BCOOTest(sptu.SparseTestCase):
|
||||
self.assertArraysAllClose(out.todense(), expected_out)
|
||||
self.assertEqual(out.nse, expected_nse)
|
||||
|
||||
@jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available")
|
||||
def test_bcoo_spdot_general_ad_bug(self):
|
||||
# Regression test for https://github.com/jax-ml/jax/issues/10163
|
||||
A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]])
|
||||
|
Loading…
x
Reference in New Issue
Block a user