mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
ENH: Update numpy exceptions imports
This commit is contained in:
parent
a80d952680
commit
d183a2c02f
@ -70,7 +70,7 @@ from jax._src.lib.mlir.dialects import hlo
|
|||||||
from jax._src.sharding_impls import PmapSharding
|
from jax._src.sharding_impls import PmapSharding
|
||||||
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape
|
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike, Shape
|
||||||
from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
|
from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
|
||||||
split_list)
|
split_list, NumpyComplexWarning)
|
||||||
|
|
||||||
xb = xla_bridge
|
xb = xla_bridge
|
||||||
xc = xla_client
|
xc = xla_client
|
||||||
@ -537,7 +537,7 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N
|
|||||||
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
|
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
|
||||||
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
||||||
msg = "Casting complex values to real discards the imaginary part"
|
msg = "Casting complex values to real discards the imaginary part"
|
||||||
warnings.warn(msg, np.ComplexWarning, stacklevel=2)
|
warnings.warn(msg, NumpyComplexWarning, stacklevel=2)
|
||||||
|
|
||||||
# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
|
# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
|
||||||
# not be an error since the target dtype fits the value. Handle this case by
|
# not be an error since the target dtype fits the value. Handle this case by
|
||||||
|
@ -64,7 +64,8 @@ from jax._src.numpy.vectorize import vectorize
|
|||||||
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
|
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
|
||||||
from jax._src.util import (unzip2, subvals, safe_zip,
|
from jax._src.util import (unzip2, subvals, safe_zip,
|
||||||
ceil_of_ratio, partition_list,
|
ceil_of_ratio, partition_list,
|
||||||
canonicalize_axis as _canonicalize_axis)
|
canonicalize_axis as _canonicalize_axis,
|
||||||
|
NumpyComplexWarning)
|
||||||
|
|
||||||
newaxis = None
|
newaxis = None
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
@ -206,7 +207,7 @@ can_cast = dtypes.can_cast
|
|||||||
issubsctype = dtypes.issubsctype
|
issubsctype = dtypes.issubsctype
|
||||||
promote_types = dtypes.promote_types
|
promote_types = dtypes.promote_types
|
||||||
|
|
||||||
ComplexWarning = np.ComplexWarning
|
ComplexWarning = NumpyComplexWarning
|
||||||
|
|
||||||
array_str = np.array_str
|
array_str = np.array_str
|
||||||
array_repr = np.array_repr
|
array_repr = np.array_repr
|
||||||
|
@ -33,7 +33,8 @@ from jax._src.numpy.util import (
|
|||||||
from jax._src.lax import lax as lax_internal
|
from jax._src.lax import lax as lax_internal
|
||||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
|
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
|
||||||
from jax._src.util import (
|
from jax._src.util import (
|
||||||
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
|
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
|
||||||
|
NumpyComplexWarning)
|
||||||
|
|
||||||
|
|
||||||
_all = builtins.all
|
_all = builtins.all
|
||||||
@ -173,7 +174,7 @@ def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray:
|
|||||||
|
|
||||||
def _cast_to_bool(operand: ArrayLike) -> Array:
|
def _cast_to_bool(operand: ArrayLike) -> Array:
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=np.ComplexWarning)
|
warnings.filterwarnings("ignore", category=NumpyComplexWarning)
|
||||||
return lax.convert_element_type(operand, np.bool_)
|
return lax.convert_element_type(operand, np.bool_)
|
||||||
|
|
||||||
def _cast_to_numeric(operand: ArrayLike) -> Array:
|
def _cast_to_numeric(operand: ArrayLike) -> Array:
|
||||||
|
@ -584,3 +584,11 @@ else:
|
|||||||
"Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
|
"Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
|
||||||
)
|
)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# numpy 1.25.0 or newer
|
||||||
|
NumpyComplexWarning = np.exceptions.ComplexWarning
|
||||||
|
except AttributeError:
|
||||||
|
# legacy numpy
|
||||||
|
NumpyComplexWarning = np.ComplexWarning
|
||||||
|
@ -28,6 +28,7 @@ import jax
|
|||||||
from jax import dtypes
|
from jax import dtypes
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
|
from jax._src.util import NumpyComplexWarning
|
||||||
from jax.test_util import check_grads
|
from jax.test_util import check_grads
|
||||||
|
|
||||||
from jax import config
|
from jax import config
|
||||||
@ -242,7 +243,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
|||||||
jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
|
jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
|
||||||
args = (rng((2, 3), from_dtype),)
|
args = (rng((2, 3), from_dtype),)
|
||||||
convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
|
convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
|
||||||
convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)(
|
convert_element_type = jtu.ignore_warning(category=NumpyComplexWarning)(
|
||||||
convert_element_type)
|
convert_element_type)
|
||||||
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
|
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ from jax import ops
|
|||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src import util
|
from jax._src import util
|
||||||
|
from jax._src.util import NumpyComplexWarning
|
||||||
from jax._src.lax import lax as lax_internal
|
from jax._src.lax import lax as lax_internal
|
||||||
|
|
||||||
from jax import config
|
from jax import config
|
||||||
@ -1042,7 +1043,7 @@ class IndexingTest(jtu.JaxTestCase):
|
|||||||
out = x.at[0].set(y)
|
out = x.at[0].set(y)
|
||||||
self.assertEqual(x.dtype, out.dtype)
|
self.assertEqual(x.dtype, out.dtype)
|
||||||
|
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning,
|
@jtu.ignore_warning(category=NumpyComplexWarning,
|
||||||
message="Casting complex values to real")
|
message="Casting complex values to real")
|
||||||
def _check_warns(x_type, y_type, msg):
|
def _check_warns(x_type, y_type, msg):
|
||||||
with self.assertWarnsRegex(FutureWarning, msg):
|
with self.assertWarnsRegex(FutureWarning, msg):
|
||||||
|
@ -28,6 +28,7 @@ from jax import numpy as jnp
|
|||||||
|
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
|
from jax._src.util import NumpyComplexWarning
|
||||||
|
|
||||||
from jax import config
|
from jax import config
|
||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
@ -208,7 +209,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
np_op = getattr(np, name)
|
np_op = getattr(np, name)
|
||||||
jnp_op = getattr(jnp, name)
|
jnp_op = getattr(jnp, name)
|
||||||
rng = rng_factory(self.rng())
|
rng = rng_factory(self.rng())
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="mean of empty slice.*")
|
message="mean of empty slice.*")
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
@ -307,7 +308,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="Degrees of freedom <= 0 for slice.*")
|
message="Degrees of freedom <= 0 for slice.*")
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
def np_fun(x):
|
def np_fun(x):
|
||||||
x = np.asarray(x)
|
x = np.asarray(x)
|
||||||
if inexact:
|
if inexact:
|
||||||
@ -347,7 +348,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
rng_factory.__name__ == 'rand_some_nan')
|
rng_factory.__name__ == 'rand_some_nan')
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="Degrees of freedom <= 0 for slice.*")
|
message="Degrees of freedom <= 0 for slice.*")
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
def np_fun(x):
|
def np_fun(x):
|
||||||
x = np.asarray(x)
|
x = np.asarray(x)
|
||||||
if inexact:
|
if inexact:
|
||||||
@ -384,7 +385,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="Degrees of freedom <= 0 for slice.*")
|
message="Degrees of freedom <= 0 for slice.*")
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
def np_fun(x):
|
def np_fun(x):
|
||||||
x = np.asarray(x)
|
x = np.asarray(x)
|
||||||
if inexact:
|
if inexact:
|
||||||
@ -430,7 +431,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
|
where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="Degrees of freedom <= 0 for slice.*")
|
message="Degrees of freedom <= 0 for slice.*")
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
def np_fun(x):
|
def np_fun(x):
|
||||||
x = np.asarray(x)
|
x = np.asarray(x)
|
||||||
if inexact:
|
if inexact:
|
||||||
@ -473,7 +474,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
message="Mean of empty slice.*")
|
message="Mean of empty slice.*")
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="invalid value encountered.*")
|
message="invalid value encountered.*")
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
def np_fun(x):
|
def np_fun(x):
|
||||||
x = np.asarray(x)
|
x = np.asarray(x)
|
||||||
if inexact:
|
if inexact:
|
||||||
@ -551,7 +552,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="Degrees of freedom <= 0 for slice.")
|
message="Degrees of freedom <= 0 for slice.")
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
def np_fun(x):
|
def np_fun(x):
|
||||||
# Numpy fails with bfloat16 inputs
|
# Numpy fails with bfloat16 inputs
|
||||||
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
|
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
|
||||||
@ -583,7 +584,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
|||||||
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
|
||||||
@jtu.ignore_warning(category=RuntimeWarning,
|
@jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="Degrees of freedom <= 0 for slice.")
|
message="Degrees of freedom <= 0 for slice.")
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
def np_fun(x):
|
def np_fun(x):
|
||||||
# Numpy fails with bfloat16 inputs
|
# Numpy fails with bfloat16 inputs
|
||||||
out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
|
out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
|
||||||
|
@ -48,7 +48,7 @@ from jax._src import dtypes
|
|||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src.lax import lax as lax_internal
|
from jax._src.lax import lax as lax_internal
|
||||||
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
|
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
|
||||||
from jax._src.util import safe_zip
|
from jax._src.util import safe_zip, NumpyComplexWarning
|
||||||
from jax._src import array
|
from jax._src import array
|
||||||
|
|
||||||
from jax import config
|
from jax import config
|
||||||
@ -1960,7 +1960,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
np_op = getattr(np, op)
|
np_op = getattr(np, op)
|
||||||
rng = jtu.rand_default(self.rng())
|
rng = jtu.rand_default(self.rng())
|
||||||
np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype)
|
np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype)
|
||||||
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
|
np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
|
||||||
np_fun = jtu.ignore_warning(category=RuntimeWarning,
|
np_fun = jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="overflow encountered.*")(np_fun)
|
message="overflow encountered.*")(np_fun)
|
||||||
jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype)
|
jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype)
|
||||||
@ -1988,7 +1988,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
np_op = getattr(np, op)
|
np_op = getattr(np, op)
|
||||||
rng = jtu.rand_some_nan(self.rng())
|
rng = jtu.rand_some_nan(self.rng())
|
||||||
np_fun = partial(np_op, axis=axis, dtype=out_dtype)
|
np_fun = partial(np_op, axis=axis, dtype=out_dtype)
|
||||||
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
|
np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun)
|
||||||
np_fun = jtu.ignore_warning(category=RuntimeWarning,
|
np_fun = jtu.ignore_warning(category=RuntimeWarning,
|
||||||
message="overflow encountered.*")(np_fun)
|
message="overflow encountered.*")(np_fun)
|
||||||
jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype)
|
jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype)
|
||||||
|
@ -46,6 +46,7 @@ from jax._src.lax import lax as lax_internal
|
|||||||
from jax._src.lib import xla_client as xc
|
from jax._src.lib import xla_client as xc
|
||||||
from jax._src.lib import xla_extension_version
|
from jax._src.lib import xla_extension_version
|
||||||
from jax._src.internal_test_util import lax_test_util
|
from jax._src.internal_test_util import lax_test_util
|
||||||
|
from jax._src.util import NumpyComplexWarning
|
||||||
|
|
||||||
from jax import config
|
from jax import config
|
||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
@ -2770,7 +2771,7 @@ class LazyConstantTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes)
|
dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes)
|
||||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
@jtu.ignore_warning(category=NumpyComplexWarning)
|
||||||
def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out):
|
def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out):
|
||||||
x = jax.device_put(np.zeros(5, dtype_in))
|
x = jax.device_put(np.zeros(5, dtype_in))
|
||||||
self.assertEqual(x.dtype, dtype_in)
|
self.assertEqual(x.dtype, dtype_in)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user