From d183a2c02f6254a8e8f7e926bf97c09fdfb8f9a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Mon, 7 Aug 2023 19:08:41 +0200 Subject: [PATCH] ENH: Update numpy exceptions imports --- jax/_src/lax/lax.py | 4 ++-- jax/_src/numpy/lax_numpy.py | 5 +++-- jax/_src/numpy/reductions.py | 5 +++-- jax/_src/util.py | 8 ++++++++ tests/lax_autodiff_test.py | 3 ++- tests/lax_numpy_indexing_test.py | 3 ++- tests/lax_numpy_reducers_test.py | 17 +++++++++-------- tests/lax_numpy_test.py | 6 +++--- tests/lax_test.py | 3 ++- 9 files changed, 34 insertions(+), 20 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1d429ec39..d5e4e3089 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -70,7 +70,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.sharding_impls import PmapSharding from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis, - split_list) + split_list, NumpyComplexWarning) xb = xla_bridge xc = xla_client @@ -537,7 +537,7 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N if (dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): msg = "Casting complex values to real discards the imaginary part" - warnings.warn(msg, np.ComplexWarning, stacklevel=2) + warnings.warn(msg, NumpyComplexWarning, stacklevel=2) # Python has big integers, but convert_element_type(2 ** 100, np.float32) need # not be an error since the target dtype fits the value. Handle this case by diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 73e716ddd..7e4e055d7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -64,7 +64,8 @@ from jax._src.numpy.vectorize import vectorize from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape from jax._src.util import (unzip2, subvals, safe_zip, ceil_of_ratio, partition_list, - canonicalize_axis as _canonicalize_axis) + canonicalize_axis as _canonicalize_axis, + NumpyComplexWarning) newaxis = None T = TypeVar('T') @@ -206,7 +207,7 @@ can_cast = dtypes.can_cast issubsctype = dtypes.issubsctype promote_types = dtypes.promote_types -ComplexWarning = np.ComplexWarning +ComplexWarning = NumpyComplexWarning array_str = np.array_str array_repr = np.array_repr diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 469f6fe1a..8521a651a 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -33,7 +33,8 @@ from jax._src.numpy.util import ( from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.util import ( - canonicalize_axis as _canonicalize_axis, maybe_named_axis) + canonicalize_axis as _canonicalize_axis, maybe_named_axis, + NumpyComplexWarning) _all = builtins.all @@ -173,7 +174,7 @@ def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray: def _cast_to_bool(operand: ArrayLike) -> Array: with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=np.ComplexWarning) + warnings.filterwarnings("ignore", category=NumpyComplexWarning) return lax.convert_element_type(operand, np.bool_) def _cast_to_numeric(operand: ArrayLike) -> Array: diff --git a/jax/_src/util.py b/jax/_src/util.py index fd54afe8b..4603933d4 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -584,3 +584,11 @@ else: "Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)" ) return decorator + + +try: + # numpy 1.25.0 or newer + NumpyComplexWarning = np.exceptions.ComplexWarning +except AttributeError: + # legacy numpy + NumpyComplexWarning = np.ComplexWarning diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 569e106d7..feaae69bc 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -28,6 +28,7 @@ import jax from jax import dtypes from jax import lax from jax._src import test_util as jtu +from jax._src.util import NumpyComplexWarning from jax.test_util import check_grads from jax import config @@ -242,7 +243,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) args = (rng((2, 3), from_dtype),) convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) - convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)( + convert_element_type = jtu.ignore_warning(category=NumpyComplexWarning)( convert_element_type) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index c76f90a21..91f66958c 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -33,6 +33,7 @@ from jax import ops from jax._src import dtypes from jax._src import test_util as jtu from jax._src import util +from jax._src.util import NumpyComplexWarning from jax._src.lax import lax as lax_internal from jax import config @@ -1042,7 +1043,7 @@ class IndexingTest(jtu.JaxTestCase): out = x.at[0].set(y) self.assertEqual(x.dtype, out.dtype) - @jtu.ignore_warning(category=np.ComplexWarning, + @jtu.ignore_warning(category=NumpyComplexWarning, message="Casting complex values to real") def _check_warns(x_type, y_type, msg): with self.assertWarnsRegex(FutureWarning, msg): diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 42564aa8d..5466dace2 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -28,6 +28,7 @@ from jax import numpy as jnp from jax._src import dtypes from jax._src import test_util as jtu +from jax._src.util import NumpyComplexWarning from jax import config config.parse_flags_with_absl() @@ -208,7 +209,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): np_op = getattr(np, name) jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) @jtu.ignore_warning(category=RuntimeWarning, message="mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, @@ -307,7 +308,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -347,7 +348,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): rng_factory.__name__ == 'rand_some_nan') @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -384,7 +385,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -430,7 +431,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): where = jtu.rand_bool(self.rng())(whereshape, np.bool_) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -473,7 +474,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): message="Mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered.*") - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -551,7 +552,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): # Numpy fails with bfloat16 inputs out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), @@ -583,7 +584,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): # Numpy fails with bfloat16 inputs out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 5e05f028a..838bbfcc2 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -48,7 +48,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps -from jax._src.util import safe_zip +from jax._src.util import safe_zip, NumpyComplexWarning from jax._src import array from jax import config @@ -1960,7 +1960,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_op = getattr(np, op) rng = jtu.rand_default(self.rng()) np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) @@ -1988,7 +1988,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): np_op = getattr(np, op) rng = jtu.rand_some_nan(self.rng()) np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) diff --git a/tests/lax_test.py b/tests/lax_test.py index 9c7c485ef..8f1aefdea 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -46,6 +46,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version from jax._src.internal_test_util import lax_test_util +from jax._src.util import NumpyComplexWarning from jax import config config.parse_flags_with_absl() @@ -2770,7 +2771,7 @@ class LazyConstantTest(jtu.JaxTestCase): @jtu.sample_product( dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes) - @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=NumpyComplexWarning) def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out): x = jax.device_put(np.zeros(5, dtype_in)) self.assertEqual(x.dtype, dtype_in)