mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
ENH: Update numpy exceptions imports
This commit is contained in:
parent
a80d952680
commit
d183a2c02f
@ -70,7 +70,7 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding_impls import PmapSharding
|
||||
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape
|
||||
from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
|
||||
split_list)
|
||||
split_list, NumpyComplexWarning)
|
||||
|
||||
xb = xla_bridge
|
||||
xc = xla_client
|
||||
@ -537,7 +537,7 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N
|
||||
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
|
||||
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
||||
msg = "Casting complex values to real discards the imaginary part"
|
||||
warnings.warn(msg, np.ComplexWarning, stacklevel=2)
|
||||
warnings.warn(msg, NumpyComplexWarning, stacklevel=2)
|
||||
|
||||
# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
|
||||
# not be an error since the target dtype fits the value. Handle this case by
|
||||
|
@ -64,7 +64,8 @@ from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
|
||||
from jax._src.util import (unzip2, subvals, safe_zip,
|
||||
ceil_of_ratio, partition_list,
|
||||
canonicalize_axis as _canonicalize_axis)
|
||||
canonicalize_axis as _canonicalize_axis,
|
||||
NumpyComplexWarning)
|
||||
|
||||
newaxis = None
|
||||
T = TypeVar('T')
|
||||
@ -206,7 +207,7 @@ can_cast = dtypes.can_cast
|
||||
issubsctype = dtypes.issubsctype
|
||||
promote_types = dtypes.promote_types
|
||||
|
||||
ComplexWarning = np.ComplexWarning
|
||||
ComplexWarning = NumpyComplexWarning
|
||||
|
||||
array_str = np.array_str
|
||||
array_repr = np.array_repr
|
||||
|
@ -33,7 +33,8 @@ from jax._src.numpy.util import (
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
|
||||
from jax._src.util import (
|
||||
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
|
||||
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
|
||||
NumpyComplexWarning)
|
||||
|
||||
|
||||
_all = builtins.all
|
||||
@ -173,7 +174,7 @@ def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
|
||||
|
||||
def _cast_to_bool(operand: ArrayLike) -> Array:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=np.ComplexWarning)
|
||||
warnings.filterwarnings("ignore", category=NumpyComplexWarning)
|
||||
return lax.convert_element_type(operand, np.bool_)
|
||||
|
||||
def _cast_to_numeric(operand: ArrayLike) -> Array:
|
||||
|
@ -584,3 +584,11 @@ else:
|
||||
"Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
try:
|
||||
# numpy 1.25.0 or newer
|
||||
NumpyComplexWarning = np.exceptions.ComplexWarning
|
||||
except AttributeError:
|
||||
# legacy numpy
|
||||
NumpyComplexWarning = np.ComplexWarning
|
||||
|
@ -28,6 +28,7 @@ import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import NumpyComplexWarning
|
||||
from jax.test_util import check_grads
|
||||
|
||||
from jax import config
|
||||
@ -242,7 +243,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
|
||||
args = (rng((2, 3), from_dtype),)
|
||||
convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
|
||||
convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)(
|
||||
convert_element_type = jtu.ignore_warning(category=NumpyComplexWarning)(
|
||||
convert_element_type)
|
||||
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
|
||||
|
||||
|
@ -33,6 +33,7 @@ from jax import ops
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src.util import NumpyComplexWarning
|
||||
from jax._src.lax import lax as lax_internal
|
||||
|
||||
from jax import config
|
||||
@ -1042,7 +1043,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
out = x.at[0].set(y)
|
||||
self.assertEqual(x.dtype, out.dtype)
|
||||
|
||||
@jtu.ignore_warning(category=np.ComplexWarning,
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning,
|
||||
message="Casting complex values to real")
|
||||
def _check_warns(x_type, y_type, msg):
|
||||
with self.assertWarnsRegex(FutureWarning, msg):
|
||||
|
@ -28,6 +28,7 @@ from jax import numpy as jnp
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import NumpyComplexWarning
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -208,7 +209,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
np_op = getattr(np, name)
|
||||
jnp_op = getattr(jnp, name)
|
||||
rng = rng_factory(self.rng())
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="mean of empty slice.*")
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
@ -307,7 +308,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
@ -347,7 +348,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
rng_factory.__name__ == 'rand_some_nan')
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
@ -384,7 +385,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
@ -430,7 +431,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
@ -473,7 +474,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
message="Mean of empty slice.*")
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="invalid value encountered.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def np_fun(x):
|
||||
x = np.asarray(x)
|
||||
if inexact:
|
||||
@ -551,7 +552,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def np_fun(x):
|
||||
# Numpy fails with bfloat16 inputs
|
||||
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
|
||||
@ -583,7 +584,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def np_fun(x):
|
||||
# Numpy fails with bfloat16 inputs
|
||||
out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
|
||||
|
@ -48,7 +48,7 @@ from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.util import safe_zip, NumpyComplexWarning
|
||||
from jax._src import array
|
||||
|
||||
from jax import config
|
||||
@ -1960,7 +1960,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_op = getattr(np, op)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype)
|
||||
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
|
||||
np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
|
||||
np_fun = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="overflow encountered.*")(np_fun)
|
||||
jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype)
|
||||
@ -1988,7 +1988,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_op = getattr(np, op)
|
||||
rng = jtu.rand_some_nan(self.rng())
|
||||
np_fun = partial(np_op, axis=axis, dtype=out_dtype)
|
||||
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
|
||||
np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
|
||||
np_fun = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="overflow encountered.*")(np_fun)
|
||||
jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype)
|
||||
|
@ -46,6 +46,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.internal_test_util import lax_test_util
|
||||
from jax._src.util import NumpyComplexWarning
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -2770,7 +2771,7 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes)
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||
def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out):
|
||||
x = jax.device_put(np.zeros(5, dtype_in))
|
||||
self.assertEqual(x.dtype, dtype_in)
|
||||
|
Loading…
x
Reference in New Issue
Block a user