Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Set PYTHONWARNINGS=error in bazel tests.

The goal of this change is to catch PRs that introduce new warnings sooner.

To make it easier to pass the environment variable, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both macros to set the environment variable.

Also add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
This commit is contained in:
parent d58a09faed
commit 70f91db853
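Background for the change: setting the PYTHONWARNINGS environment variable to "error" makes the Python interpreter raise every warning as an exception, which is what turns a newly introduced warning into a test failure. A minimal sketch of the effect, using only the standard library (illustration, not code from this commit):

import os
import subprocess
import sys

# Run a child interpreter that emits a DeprecationWarning. With
# PYTHONWARNINGS=error the warning is raised as an exception, so the
# child exits with a non-zero status instead of just printing a warning.
script = "import warnings; warnings.warn('old API', DeprecationWarning)"
env = dict(os.environ, PYTHONWARNINGS="error")
result = subprocess.run([sys.executable, "-c", script], env=env)
print(result.returncode != 0)  # True: the warning became an error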
@@ -15,7 +15,7 @@
 load(
     "//jaxlib:jax.bzl",
     "jax_generate_backend_suites",
-    "jax_test",
+    "jax_multiplatform_test",
     "py_deps",
 )
 
@@ -42,7 +42,7 @@ DISABLED_CONFIGS = [
     "gpu_pjrt_c_api",
 ]
 
-jax_test(
+jax_multiplatform_test(
     name = "matmul_bench",
     srcs = ["matmul_bench.py"],
     disable_backends = DISABLED_BACKENDS,
@@ -16,7 +16,7 @@ load(
     "//jaxlib:jax.bzl",
     "cuda_library",
     "jax_generate_backend_suites",
-    "jax_test",
+    "jax_multiplatform_test",
 )
 
 licenses(["notice"])
@@ -28,7 +28,7 @@ package(
 
 jax_generate_backend_suites()
 
-jax_test(
+jax_multiplatform_test(
     name = "cuda_custom_call_test",
     srcs = ["cuda_custom_call_test.py"],
     data = [":foo"],
@@ -18,13 +18,16 @@ Includes the flags from saved_model_main.py.
 See README.md.
 """
 import logging
+import warnings
 from absl import app
 from absl import flags
 from jax.experimental.jax2tf.examples import mnist_lib
 from jax.experimental.jax2tf.examples import saved_model_main
 import tensorflow as tf
 import tensorflow_datasets as tfds  # type: ignore
-import tensorflow_hub as hub  # type: ignore
+with warnings.catch_warnings():
+  warnings.simplefilter("ignore")
+  import tensorflow_hub as hub  # type: ignore
 
 
 FLAGS = flags.FLAGS
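A note on the pattern in the hunk above: warnings emitted while a module is being imported can only be suppressed by wrapping the import statement itself in warnings.catch_warnings(). A self-contained sketch of the mechanism (standard library only; the warning text is invented so the sketch runs on its own):

import warnings

with warnings.catch_warnings():
  warnings.simplefilter("ignore")  # suppress all warnings in this block
  # In the real code this line is `import tensorflow_hub as hub`; here we
  # trigger a warning directly to keep the example self-contained.
  warnings.warn("emitted at import time", DeprecationWarning)

# The previous warning filters are restored once the block exits.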
@@ -41,6 +41,7 @@ class KerasReuseMainTest(tf_test_util.JaxToTfTestCase):
   @parameterized.named_parameters(
       dict(testcase_name=f"_{model}", model=model)
       for model in ["mnist_pure_jax", "mnist_flax"])
+  @jtu.ignore_warning(message="the imp module is deprecated")
   def test_keras_reuse(self, model="mnist_pure_jax"):
     FLAGS.model = model
     keras_reuse_main.main(None)
@@ -27,6 +27,7 @@ import logging
 import re
 import time
 from typing import Any
+import warnings
 from absl import flags
 
 import flax
@@ -70,7 +71,9 @@ def load_mnist(split: tfds.Split, batch_size: int):
   if _MOCK_DATA.value:
     with tfds.testing.mock_data(num_examples=batch_size):
       try:
-        ds = tfds.load("mnist", split=split)
+        with warnings.catch_warnings():
+          warnings.simplefilter("ignore")
+          ds = tfds.load("mnist", split=split)
       except Exception as e:
         m = re.search(r'metadata files were not found in (.+/)mnist/', str(e))
         if m:
@@ -88,6 +88,17 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
     # bug in TensorFlow.
     _ = tf.add(1, 1)
     super().setUp()
+    self.warning_ctx = jtu.ignore_warning(
+        message=(
+            "(jax2tf.convert with native_serialization=False is deprecated"
+            "|Calling from_dlpack with a DLPack tensor is deprecated)"
+        )
+    )
+    self.warning_ctx.__enter__()
+
+  def tearDown(self):
+    self.warning_ctx.__exit__(None, None, None)
+    super().tearDown()
 
   @_parameterized_jit
   def test_eval_scalar_arg(self, with_jit=True):
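The setUp/tearDown additions above keep a warning filter active for the whole lifetime of each test by entering the context manager by hand in setUp and exiting it in tearDown. A minimal unittest sketch of the same pattern using only the standard library (the test class and message are invented for illustration):

import unittest
import warnings

class ExampleTest(unittest.TestCase):

  def setUp(self):
    super().setUp()
    # No `with` statement can span setUp and tearDown, so hold the context
    # manager on `self`, enter it here, and exit it in tearDown.
    self.warning_ctx = warnings.catch_warnings()
    self.warning_ctx.__enter__()
    warnings.filterwarnings("ignore", message="some deprecated API")

  def tearDown(self):
    self.warning_ctx.__exit__(None, None, None)
    super().tearDown()

  def test_quiet(self):
    warnings.warn("some deprecated API is going away")  # suppressed

On Python 3.11+ the same effect can be had with self.enterContext(warnings.catch_warnings()), which registers the exit as a cleanup automatically.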
@@ -862,6 +873,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
 
 class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
   "Reloading output of jax2tf into JAX with call_tf"
 
   def setUp(self):
     if tf is None:
       raise unittest.SkipTest("Test requires tensorflow")
@@ -869,6 +881,17 @@ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
     # bug in TensorFlow.
     _ = tf.add(1, 1)
     super().setUp()
+    self.warning_ctx = jtu.ignore_warning(
+        message=(
+            "(jax2tf.convert with native_serialization=False is deprecated"
+            "|Calling from_dlpack with a DLPack tensor is deprecated)"
+        )
+    )
+    self.warning_ctx.__enter__()
+
+  def tearDown(self):
+    self.warning_ctx.__exit__(None, None, None)
+    super().tearDown()
 
   def test_simple(self):
     f_jax = jnp.sin
@@ -1157,6 +1180,17 @@ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
     # bug in TensorFlow.
     _ = tf.add(1, 1)
     super().setUp()
+    self.warning_ctx = jtu.ignore_warning(
+        message=(
+            "(jax2tf.convert with native_serialization=False is deprecated"
+            "|Calling from_dlpack with a DLPack tensor is deprecated)"
+        )
+    )
+    self.warning_ctx.__enter__()
+
+  def tearDown(self):
+    self.warning_ctx.__exit__(None, None, None)
+    super().tearDown()
 
   def test_alternate(self):
     # Alternate sin/cos with sin in TF and cos in JAX
@@ -76,6 +76,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
 
     super().setUpClass()
 
+  def setUp(self):
+    super().setUp()
+    self.warning_ctx = jtu.ignore_warning(
+        message="jax2tf.convert with native_serialization=False is deprecated"
+    )
+    self.warning_ctx.__enter__()
+
+  def tearDown(self):
+    self.warning_ctx.__exit__(None, None, None)
+    super().tearDown()
+
   def test_empty(self):
     f_jax = lambda x, y: x
     self.ConvertAndCompare(f_jax, 0.7, 1)
@@ -1621,6 +1632,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     res = jax2tf.convert(f_jax, native_serialization=True)(*many_args)
     self.assertAllClose(f_jax(*many_args), res)
 
+  @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
+                      category=DeprecationWarning)
   def test_nested_convert(self):
     # Test call sequence: convert -> call_tf -> convert.
 
@@ -1677,6 +1690,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
 
 @jtu.with_config(jax_enable_custom_prng=True)
 class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase):
+  def setUp(self):
+    super().setUp()
+    self.warning_ctx = jtu.ignore_warning(
+        message="jax2tf.convert with native_serialization=False is deprecated"
+    )
+    self.warning_ctx.__enter__()
+
+  def tearDown(self):
+    self.warning_ctx.__exit__(None, None, None)
+    super().tearDown()
+
   def test_key_argument(self):
     func = lambda key: jax.random.uniform(key, ())
     key = jax.random.PRNGKey(0)
@@ -1709,6 +1733,9 @@ class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
     self.use_max_serialization_version = False
     super().setUp()
 
+  @jtu.ignore_warning(
+      message="jax2tf.convert with native_serialization=False is deprecated"
+  )
   def test_simple(self):
     self.ConvertAndCompare(jnp.sin, 0.7)
 
@@ -30,6 +30,17 @@ jax.config.parse_flags_with_absl()
 
 class SavedModelTest(tf_test_util.JaxToTfTestCase):
 
+  def setUp(self):
+    super().setUp()
+    self.warning_ctx = jtu.ignore_warning(
+        message="jax2tf.convert with native_serialization=False is deprecated"
+    )
+    self.warning_ctx.__enter__()
+
+  def tearDown(self):
+    self.warning_ctx.__exit__(None, None, None)
+    super().tearDown()
+
   def test_eval(self):
     f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
     model = tf.Module()
@@ -334,7 +334,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     check_shape_poly(self, f_jax, arg_descriptors=[x],
                      polymorphic_shapes=["b"])
 
-
   @jtu.parameterized_filterable(
       kwargs=[
           dict(testcase_name=f"expr={name}", expr=expr)
@@ -941,7 +940,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     xi_yf = (xi, yf)
     zb = np.array([True, False], dtype=np.bool_)
     def f_jax(xi_yf, zb):  # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2]
      # results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4]
      xi, yf = xi_yf
      # Return a tuple:
      # (1) float constant, with 0 tangent;
@@ -1032,6 +1031,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
         f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
     self.assertAllClose(f_jax(x), restored_f(x))
 
+  @jtu.ignore_warning(
+      message="jax2tf.convert with native_serialization=False is deprecated"
+  )
   def test_readme_examples(self):
     """Some of the examples from the README."""
 
@@ -61,7 +61,8 @@ def setUpModule():
 
   global topology
   if jtu.test_device_matches(["tpu"]):
-    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+    with jtu.ignore_warning(message="the imp module is deprecated"):
+      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
     tf.config.experimental_connect_to_cluster(resolver)
     # Do TPU init at beginning since it will wipe out all HBMs.
     topology = tf.tpu.experimental.initialize_tpu_system(resolver)
@@ -84,6 +85,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
       raise unittest.SkipTest("Test requires at least 2 local devices")
     self.devices = np.array(jax.devices()[:2])  # use 2 devices
 
+    self.warning_ctx = jtu.ignore_warning(
+        message="jax2tf.convert with native_serialization=False is deprecated"
+    )
+    self.warning_ctx.__enter__()
+
+  def tearDown(self):
+    self.warning_ctx.__exit__(None, None, None)
+    super().tearDown()
+
   def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
                   num_replicas=1, num_partitions=2):
     """Log the HLO generated from JAX before and after optimizations"""
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-load("@rules_python//python:defs.bzl", "py_library", "py_test")
-load("//jaxlib:jax.bzl", "py_deps")
+load("@rules_python//python:defs.bzl", "py_library")
+load("//jaxlib:jax.bzl", "jax_py_test", "py_deps")
 
 licenses(["notice"])
 
@@ -48,7 +48,7 @@ py_library(
     ],
 )
 
-py_test(
+jax_py_test(
     name = "run_matmul",
     srcs = ["matmul.py"],
     main = "matmul.py",
@@ -19,6 +19,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library",
 load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
 load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
 load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library")
+load("@rules_python//python:defs.bzl", "py_test")
 load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
 load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource")
 
@@ -222,7 +223,7 @@ def if_building_jaxlib(
     })
 
 # buildifier: disable=function-docstring
-def jax_test(
+def jax_multiplatform_test(
         name,
         srcs,
         args = [],
@@ -300,3 +301,12 @@ jax_test_file_visibility = []
 
 def xla_py_proto_library(*args, **kw):  # buildifier: disable=unused-variable
     pass
+
+def jax_py_test(
+        name,
+        env = {},
+        **kwargs):
+    env = dict(env)
+    if "PYTHONWARNINGS" not in env:
+        env["PYTHONWARNINGS"] = "error"
+    py_test(name = name, env = env, **kwargs)
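The jax_py_test macro above copies the caller's env dict and fills in PYTHONWARNINGS only when the target did not set it, so individual targets can still override the default. A quick Python model of that dict logic (the Starlark semantics match for this snippet; the function name is invented):

def with_default_warnings(env):
  # Copy first: never mutate the dict the caller passed in.
  env = dict(env)
  if "PYTHONWARNINGS" not in env:
    env["PYTHONWARNINGS"] = "error"
  return env

print(with_default_warnings({}))
# {'PYTHONWARNINGS': 'error'}
print(with_default_warnings({"PYTHONWARNINGS": "ignore"}))
# {'PYTHONWARNINGS': 'ignore'}  -- an explicit setting wins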
@@ -16,7 +16,7 @@
 
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
-load("//jaxlib:jax.bzl", "if_windows")
+load("//jaxlib:jax.bzl", "if_windows", "jax_py_test")
 
 licenses(["notice"])  # Apache 2
 
@@ -52,7 +52,7 @@ py_binary(
     ],
 )
 
-py_test(
+jax_py_test(
     name = "build_wheel_test",
     srcs = ["build_wheel_test.py"],
     data = [":build_wheel"],
@@ -60,6 +60,11 @@ filterwarnings = [
     "default:Special cases found for .* but none were parsed.*:UserWarning",
     "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning",
     "default:The .* method is good for exploring strategies.*",
+
+    # NOTE: this is probably not where you want to add code to suppress a
+    # warning. Only pytest tests look at this list, whereas Bazel tests also
+    # check for warnings and do not check this list. Most likely, you should
+    # add a @jtu.ignore_warning decorator to your test instead.
 ]
 doctest_optionflags = [
     "NUMBER",
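The NOTE added above points people at @jtu.ignore_warning because pytest's filterwarnings list is invisible to the Bazel test runners. The decorator's behavior can be approximated with the standard library; this is an illustrative sketch, not JAX's actual implementation:

import contextlib
import warnings

@contextlib.contextmanager
def ignore_warning(*, message="", category=Warning):
  # Suppress warnings whose text matches `message` (a regex) and whose
  # class is `category`, restoring the previous filters on exit. Because
  # contextmanager objects double as decorators, this works both as
  # `with ignore_warning(...):` and as `@ignore_warning(...)`.
  with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=message, category=category)
    yield

@ignore_warning(message="the imp module is deprecated")
def run_quietly():
  warnings.warn("the imp module is deprecated in favour of importlib")

run_quietly()  # no warning escapes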
tests/BUILD: 268 lines changed; file diff suppressed because it is too large.
@@ -75,6 +75,8 @@ class DLPackTest(jtu.JaxTestCase):
       use_stream=[False, True],
   )
   @jtu.run_on_devices("gpu")
+  @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
+                      category=DeprecationWarning)
   def testJaxRoundTrip(self, shape, dtype, copy, use_stream):
     rng = jtu.rand_default(self.rng())
     np = rng(shape, dtype)
@@ -142,6 +144,8 @@ class DLPackTest(jtu.JaxTestCase):
       dtype=dlpack_dtypes,
   )
   @unittest.skipIf(not tf, "Test requires TensorFlow")
+  @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
+                      category=DeprecationWarning)
   def testTensorFlowToJax(self, shape, dtype):
     if (not config.enable_x64.value and
         dtype in [jnp.int64, jnp.uint64, jnp.float64]):
@@ -184,6 +188,8 @@ class DLPackTest(jtu.JaxTestCase):
     self.assertAllClose(np, y.numpy())
 
   @unittest.skipIf(not tf, "Test requires TensorFlow")
+  @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
+                      category=DeprecationWarning)
   def testTensorFlowToJaxInt64(self):
     # See https://github.com/jax-ml/jax/issues/11895
     x = jax.dlpack.from_dlpack(
@@ -176,6 +176,8 @@ class CallToTFTest(jtu.JaxTestCase):
           testcase_name=f"_{ad=}",
           ad=ad)
       for ad in CALL_TF_IMPLEMENTATIONS.keys())
+  @jtu.ignore_warning(message="The host_callback APIs are deprecated",
+                      category=DeprecationWarning)
   def test_impl(self, ad="simple"):
     self.supported_only_in_legacy_mode()
     call_tf = CALL_TF_IMPLEMENTATIONS[ad]
@@ -197,6 +199,8 @@ class CallToTFTest(jtu.JaxTestCase):
           ad=ad)
       for ad in CALL_TF_IMPLEMENTATIONS.keys()
       if ad != "none")
+  @jtu.ignore_warning(message="The host_callback APIs are deprecated",
+                      category=DeprecationWarning)
   def test_grad(self, ad="simple"):
     self.supported_only_in_legacy_mode()
     call_tf = CALL_TF_IMPLEMENTATIONS[ad]
@@ -217,6 +221,8 @@ class CallToTFTest(jtu.JaxTestCase):
     self.assertAllClose(jax.grad(f_jax)(x), grad_f,
                         check_dtypes=False)
 
+  @jtu.ignore_warning(message="The host_callback APIs are deprecated",
+                      category=DeprecationWarning)
   def test_grad_pytree(self):
     self.supported_only_in_legacy_mode()
     call_tf = call_tf_full_ad
@@ -246,6 +252,8 @@ class CallToTFTest(jtu.JaxTestCase):
           testcase_name=f"_degree=_{degree}",
           degree=degree)
       for degree in [1, 2, 3, 4])
+  @jtu.ignore_warning(message="The host_callback APIs are deprecated",
+                      category=DeprecationWarning)
   def test_higher_order_grad(self, degree=4):
     self.supported_only_in_legacy_mode()
     call_tf = call_tf_full_ad
@@ -2681,6 +2681,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
       shape=all_shapes,
       dtype=default_dtypes,
   )
+  @jtu.ignore_warning(category=RuntimeWarning, message="overflow")
   def testFrexp(self, shape, dtype, rng_factory):
     # integer types are converted to float64 in numpy's implementation
     if (dtype not in [jnp.bfloat16, np.float16, np.float32]
@@ -6270,7 +6271,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
   for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin):
     args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes)
     try:
-      with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"):
+      with jtu.ignore_warning(
+          category=RuntimeWarning, message="(divide by zero|invalid value)"):
         _ = func(*args)
     except TypeError:
       pass
@@ -6292,7 +6294,7 @@ class NumpyUfuncTests(jtu.JaxTestCase):
     jnp_op = getattr(jnp, name)
     np_op = getattr(np, name)
     np_op = jtu.ignore_warning(category=RuntimeWarning,
-                               message="divide by zero.*")(np_op)
+                               message="(divide by zero|invalid value)")(np_op)
     args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes)
 
     with jtu.strict_promotion_if_dtypes_match(arg_dtypes):
@@ -110,6 +110,7 @@ class LaxTest(jtu.JaxTestCase):
           for shape_group in lax_test_util.compatible_shapes),
       dtype=rec.dtypes)
       for rec in lax_test_util.lax_ops()))
+  @jtu.ignore_warning(message="invalid value", category=RuntimeWarning)
   def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
     if (not config.enable_x64.value and op_name == "nextafter"
         and dtype == np.float64):
@@ -15,7 +15,7 @@
 load(
     "//jaxlib:jax.bzl",
     "jax_generate_backend_suites",
-    "jax_test",
+    "jax_multiplatform_test",
    "py_deps",
 )
 
@@ -43,7 +43,7 @@ DISABLED_CONFIGS = [
     "gpu",
 ]
 
-jax_test(
+jax_multiplatform_test(
     name = "gpu_test",
     srcs = ["gpu_test.py"],
     config_tags_overrides = {
@@ -63,7 +63,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "matmul_test",
     srcs = ["matmul_test.py"],
     disable_backends = DISABLED_BACKENDS,
@@ -75,7 +75,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "flash_attention",
     srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
     disable_backends = DISABLED_BACKENDS,
@@ -87,7 +87,7 @@ jax_test(
     ] + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "flash_attention_test",
     srcs = ["flash_attention_test.py"],
     disable_backends = DISABLED_BACKENDS,
|
@ -1200,6 +1200,8 @@ class FragmentedArrayTest(TestCase):
|
|||||||
m=(64, 128),
|
m=(64, 128),
|
||||||
n=(8, 16, 32, 64, 80, 128, 256),
|
n=(8, 16, 32, 64, 80, 128, 256),
|
||||||
)
|
)
|
||||||
|
@jtu.ignore_warning(message="(invalid value|divide by zero)",
|
||||||
|
category=RuntimeWarning)
|
||||||
def test_binary(self, op, dtype, m=64, n=32):
|
def test_binary(self, op, dtype, m=64, n=32):
|
||||||
if isinstance(op, tuple):
|
if isinstance(op, tuple):
|
||||||
op, np_op = op
|
op, np_op = op
|
||||||
@ -1294,6 +1296,7 @@ class FragmentedArrayTest(TestCase):
|
|||||||
],
|
],
|
||||||
approx=[False, True],
|
approx=[False, True],
|
||||||
)
|
)
|
||||||
|
@jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning)
|
||||||
def test_math(self, ops, approx, m=64, n=32):
|
def test_math(self, ops, approx, m=64, n=32):
|
||||||
op, np_op = ops
|
op, np_op = ops
|
||||||
def kernel(ctx, dst, _):
|
def kernel(ctx, dst, _):
|
||||||
|
@@ -15,7 +15,7 @@
 load(
     "//jaxlib:jax.bzl",
     "jax_generate_backend_suites",
-    "jax_test",
+    "jax_multiplatform_test",
     "py_deps",
 )
 
@@ -28,7 +28,7 @@ package(
 
 jax_generate_backend_suites()
 
-jax_test(
+jax_multiplatform_test(
     name = "pallas_test",
     srcs = [
         "pallas_test.py",
@@ -62,7 +62,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "pallas_jumble_test",
     srcs = [
         "pallas_jumble_test.py",
@@ -85,7 +85,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "ops_test",
     srcs = [
         "ops_test.py",
@@ -125,7 +125,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "indexing_test",
     srcs = [
         "indexing_test.py",
@@ -144,7 +144,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "pallas_vmap_test",
     srcs = [
         "pallas_vmap_test.py",
@@ -176,7 +176,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "mosaic_gpu_test",
     srcs = [
         "mosaic_gpu_test.py",
@@ -213,7 +213,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "export_back_compat_pallas_test",
     srcs = ["export_back_compat_pallas_test.py"],
     config_tags_overrides = {
@@ -244,7 +244,7 @@ jax_test(
     ],
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "export_pallas_test",
     srcs = ["export_pallas_test.py"],
     config_tags_overrides = {
@@ -272,7 +272,7 @@ jax_test(
     ],
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "pallas_shape_poly_test",
     srcs = ["pallas_shape_poly_test.py"],
     config_tags_overrides = {
@@ -299,7 +299,7 @@ jax_test(
     ],
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "pallas_error_handling_test",
     srcs = [
         "pallas_error_handling_test.py",
@@ -317,7 +317,7 @@ jax_test(
     ] + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_all_gather_test",
     srcs = [
         "tpu_all_gather_test.py",
@@ -331,7 +331,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_gmm_test",
     srcs = [
         "tpu_gmm_test.py",
@@ -356,7 +356,7 @@ jax_test(
     ]),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_pallas_test",
     srcs = ["tpu_pallas_test.py"],
     # The flag is necessary for ``pl.debug_print`` tests to work on TPU.
@@ -372,7 +372,7 @@ jax_test(
     ],
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_ops_test",
     srcs = [
         "tpu_ops_test.py",
@@ -388,7 +388,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_pallas_distributed_test",
     srcs = ["tpu_pallas_distributed_test.py"],
     disable_backends = [
@@ -402,7 +402,7 @@ jax_test(
     ],
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_pallas_pipeline_test",
     srcs = ["tpu_pallas_pipeline_test.py"],
     disable_backends = [
@@ -422,7 +422,7 @@ jax_test(
     ] + py_deps("hypothesis"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_pallas_async_test",
     srcs = ["tpu_pallas_async_test.py"],
     disable_backends = [
@@ -436,7 +436,7 @@ jax_test(
     ],
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_pallas_mesh_test",
     srcs = ["tpu_pallas_mesh_test.py"],
     disable_backends = [
@@ -454,7 +454,7 @@ jax_test(
     ],
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_pallas_random_test",
     srcs = [
         "tpu_pallas_random_test.py",
@@ -472,7 +472,7 @@ jax_test(
     ] + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_paged_attention_kernel_test",
     srcs = ["tpu_paged_attention_kernel_test.py"],
     disable_backends = [
@@ -490,7 +490,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_splash_attention_kernel_test",
     srcs = [
         "tpu_splash_attention_kernel_test.py",
@@ -510,7 +510,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "tpu_splash_attention_mask_test",
     srcs = [
         "tpu_splash_attention_mask_test.py",
@@ -523,7 +523,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "gpu_attention_test",
     srcs = [
         "gpu_attention_test.py",
@@ -556,7 +556,7 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
+jax_multiplatform_test(
     name = "gpu_ops_test",
     srcs = [
         "gpu_ops_test.py",
@@ -3209,8 +3209,12 @@ class EagerPmapMixin:
     self.jit_disabled = config.disable_jit.value
     config.update('jax_disable_jit', True)
     config.update('jax_eager_pmap', True)
+    self.warning_ctx = jtu.ignore_warning(
+        message="Some donated buffers were not usable", category=UserWarning)
+    self.warning_ctx.__enter__()
 
   def tearDown(self):
+    self.warning_ctx.__exit__(None, None, None)
     config.update('jax_eager_pmap', self.eager_pmap_enabled)
     config.update('jax_disable_jit', self.jit_disabled)
     super().tearDown()
@@ -108,6 +108,8 @@ class DLPackTest(jtu.JaxTestCase):
     else:
       self.assertAllClose(np, y.cpu().numpy())
 
+  @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
+                      category=DeprecationWarning)
   def testTorchToJaxInt64(self):
     # See https://github.com/jax-ml/jax/issues/11895
     x = jax.dlpack.from_dlpack(
@@ -116,6 +118,8 @@ class DLPackTest(jtu.JaxTestCase):
     self.assertEqual(x.dtype, dtype_expected)
 
   @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
+  @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor",
+                      category=DeprecationWarning)
   def testTorchToJax(self, shape, dtype):
     if not config.enable_x64.value and dtype in [
         jnp.int64,
@@ -973,6 +973,7 @@ class BCOOTest(sptu.SparseTestCase):
     self.assertArraysAllClose(out.todense(), expected_out)
     self.assertEqual(out.nse, expected_nse)
 
+  @jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available")
   def test_bcoo_spdot_general_ad_bug(self):
     # Regression test for https://github.com/jax-ml/jax/issues/10163
     A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]])