ENH: Update numpy exceptions imports

This commit is contained in:
Mateusz Sokół 2023-08-07 19:08:41 +02:00
parent a80d952680
commit d183a2c02f
9 changed files with 34 additions and 20 deletions

View File

@ -70,7 +70,7 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding_impls import PmapSharding from jax._src.sharding_impls import PmapSharding
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape
from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis, from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
split_list) split_list, NumpyComplexWarning)
xb = xla_bridge xb = xla_bridge
xc = xla_client xc = xla_client
@ -537,7 +537,7 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N
if (dtypes.issubdtype(old_dtype, np.complexfloating) and if (dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)): not dtypes.issubdtype(new_dtype, np.complexfloating)):
msg = "Casting complex values to real discards the imaginary part" msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, np.ComplexWarning, stacklevel=2) warnings.warn(msg, NumpyComplexWarning, stacklevel=2)
# Python has big integers, but convert_element_type(2 ** 100, np.float32) need # Python has big integers, but convert_element_type(2 ** 100, np.float32) need
# not be an error since the target dtype fits the value. Handle this case by # not be an error since the target dtype fits the value. Handle this case by

View File

@ -64,7 +64,8 @@ from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
from jax._src.util import (unzip2, subvals, safe_zip, from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list, ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis) canonicalize_axis as _canonicalize_axis,
NumpyComplexWarning)
newaxis = None newaxis = None
T = TypeVar('T') T = TypeVar('T')
@ -206,7 +207,7 @@ can_cast = dtypes.can_cast
issubsctype = dtypes.issubsctype issubsctype = dtypes.issubsctype
promote_types = dtypes.promote_types promote_types = dtypes.promote_types
ComplexWarning = np.ComplexWarning ComplexWarning = NumpyComplexWarning
array_str = np.array_str array_str = np.array_str
array_repr = np.array_repr array_repr = np.array_repr

View File

@ -33,7 +33,8 @@ from jax._src.numpy.util import (
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import ( from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis) canonicalize_axis as _canonicalize_axis, maybe_named_axis,
NumpyComplexWarning)
_all = builtins.all _all = builtins.all
@ -173,7 +174,7 @@ def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
def _cast_to_bool(operand: ArrayLike) -> Array: def _cast_to_bool(operand: ArrayLike) -> Array:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=np.ComplexWarning) warnings.filterwarnings("ignore", category=NumpyComplexWarning)
return lax.convert_element_type(operand, np.bool_) return lax.convert_element_type(operand, np.bool_)
def _cast_to_numeric(operand: ArrayLike) -> Array: def _cast_to_numeric(operand: ArrayLike) -> Array:

View File

@ -584,3 +584,11 @@ else:
"Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)" "Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
) )
return decorator return decorator
try:
# numpy 1.25.0 or newer
NumpyComplexWarning = np.exceptions.ComplexWarning
except AttributeError:
# legacy numpy
NumpyComplexWarning = np.ComplexWarning

View File

@ -28,6 +28,7 @@ import jax
from jax import dtypes from jax import dtypes
from jax import lax from jax import lax
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning
from jax.test_util import check_grads from jax.test_util import check_grads
from jax import config from jax import config
@ -242,7 +243,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
args = (rng((2, 3), from_dtype),) args = (rng((2, 3), from_dtype),)
convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)( convert_element_type = jtu.ignore_warning(category=NumpyComplexWarning)(
convert_element_type) convert_element_type)
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

View File

@ -33,6 +33,7 @@ from jax import ops
from jax._src import dtypes from jax._src import dtypes
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import util from jax._src import util
from jax._src.util import NumpyComplexWarning
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax import config from jax import config
@ -1042,7 +1043,7 @@ class IndexingTest(jtu.JaxTestCase):
out = x.at[0].set(y) out = x.at[0].set(y)
self.assertEqual(x.dtype, out.dtype) self.assertEqual(x.dtype, out.dtype)
@jtu.ignore_warning(category=np.ComplexWarning, @jtu.ignore_warning(category=NumpyComplexWarning,
message="Casting complex values to real") message="Casting complex values to real")
def _check_warns(x_type, y_type, msg): def _check_warns(x_type, y_type, msg):
with self.assertWarnsRegex(FutureWarning, msg): with self.assertWarnsRegex(FutureWarning, msg):

View File

@ -28,6 +28,7 @@ from jax import numpy as jnp
from jax._src import dtypes from jax._src import dtypes
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning
from jax import config from jax import config
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -208,7 +209,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
np_op = getattr(np, name) np_op = getattr(np, name)
jnp_op = getattr(jnp, name) jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng()) rng = rng_factory(self.rng())
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
message="mean of empty slice.*") message="mean of empty slice.*")
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
@ -307,7 +308,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*") message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x): def np_fun(x):
x = np.asarray(x) x = np.asarray(x)
if inexact: if inexact:
@ -347,7 +348,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
rng_factory.__name__ == 'rand_some_nan') rng_factory.__name__ == 'rand_some_nan')
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*") message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x): def np_fun(x):
x = np.asarray(x) x = np.asarray(x)
if inexact: if inexact:
@ -384,7 +385,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*") message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x): def np_fun(x):
x = np.asarray(x) x = np.asarray(x)
if inexact: if inexact:
@ -430,7 +431,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
where = jtu.rand_bool(self.rng())(whereshape, np.bool_) where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*") message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x): def np_fun(x):
x = np.asarray(x) x = np.asarray(x)
if inexact: if inexact:
@ -473,7 +474,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
message="Mean of empty slice.*") message="Mean of empty slice.*")
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
message="invalid value encountered.*") message="invalid value encountered.*")
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x): def np_fun(x):
x = np.asarray(x) x = np.asarray(x)
if inexact: if inexact:
@ -551,7 +552,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng, [shape], [dtype]) args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.") message="Degrees of freedom <= 0 for slice.")
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x): def np_fun(x):
# Numpy fails with bfloat16 inputs # Numpy fails with bfloat16 inputs
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
@ -583,7 +584,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
args_maker = self._GetArgsMaker(rng, [shape], [dtype]) args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@jtu.ignore_warning(category=RuntimeWarning, @jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.") message="Degrees of freedom <= 0 for slice.")
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x): def np_fun(x):
# Numpy fails with bfloat16 inputs # Numpy fails with bfloat16 inputs
out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),

View File

@ -48,7 +48,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import safe_zip from jax._src.util import safe_zip, NumpyComplexWarning
from jax._src import array from jax._src import array
from jax import config from jax import config
@ -1960,7 +1960,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_op = getattr(np, op) np_op = getattr(np, op)
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype)
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
np_fun = jtu.ignore_warning(category=RuntimeWarning, np_fun = jtu.ignore_warning(category=RuntimeWarning,
message="overflow encountered.*")(np_fun) message="overflow encountered.*")(np_fun)
jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype)
@ -1988,7 +1988,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
np_op = getattr(np, op) np_op = getattr(np, op)
rng = jtu.rand_some_nan(self.rng()) rng = jtu.rand_some_nan(self.rng())
np_fun = partial(np_op, axis=axis, dtype=out_dtype) np_fun = partial(np_op, axis=axis, dtype=out_dtype)
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
np_fun = jtu.ignore_warning(category=RuntimeWarning, np_fun = jtu.ignore_warning(category=RuntimeWarning,
message="overflow encountered.*")(np_fun) message="overflow encountered.*")(np_fun)
jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype)

View File

@ -46,6 +46,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version from jax._src.lib import xla_extension_version
from jax._src.internal_test_util import lax_test_util from jax._src.internal_test_util import lax_test_util
from jax._src.util import NumpyComplexWarning
from jax import config from jax import config
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -2770,7 +2771,7 @@ class LazyConstantTest(jtu.JaxTestCase):
@jtu.sample_product( @jtu.sample_product(
dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes) dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes)
@jtu.ignore_warning(category=np.ComplexWarning) @jtu.ignore_warning(category=NumpyComplexWarning)
def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out): def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out):
x = jax.device_put(np.zeros(5, dtype_in)) x = jax.device_put(np.zeros(5, dtype_in))
self.assertEqual(x.dtype, dtype_in) self.assertEqual(x.dtype, dtype_in)