mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fixed hypot bug on nan/inf pairings, began deprecation of non-real values
This commit is contained in:
parent
51352fa05c
commit
2899213efb
@ -30,6 +30,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
|
* In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
|
||||||
now leads to an error rather than a warning.
|
now leads to an error rather than a warning.
|
||||||
* The minimum jaxlib version is now 0.4.23.
|
* The minimum jaxlib version is now 0.4.23.
|
||||||
|
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
|
||||||
|
passing complex-valued inputs to it. This will raise an error when the
|
||||||
|
deprecation is completed.
|
||||||
|
|
||||||
## jaxlib 0.4.27
|
## jaxlib 0.4.27
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ from functools import partial
|
|||||||
import operator
|
import operator
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any, Callable, overload
|
from typing import Any, Callable, overload
|
||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -730,12 +731,22 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
|||||||
@implements(np.hypot, module='numpy')
|
@implements(np.hypot, module='numpy')
|
||||||
@jit
|
@jit
|
||||||
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||||
check_arraylike("hypot", x1, x2)
|
x1, x2 = promote_args_inexact("hypot", x1, x2)
|
||||||
x1, x2 = promote_dtypes_inexact(x1, x2)
|
|
||||||
x1 = lax.abs(x1)
|
# TODO(micky774): Promote to ValueError when deprecation is complete
|
||||||
x2 = lax.abs(x2)
|
# (began 2024-4-14).
|
||||||
|
if dtypes.issubdtype(x1.dtype, np.complexfloating):
|
||||||
|
warnings.warn(
|
||||||
|
"Passing complex-valued inputs to hypot is deprecated and will raise a "
|
||||||
|
"ValueError in the future. Please convert to real values first, such as "
|
||||||
|
"by using jnp.real or jnp.imag to take the real or imaginary components "
|
||||||
|
"respectively.",
|
||||||
|
DeprecationWarning, stacklevel=2)
|
||||||
|
x1, x2 = lax.abs(x1), lax.abs(x2)
|
||||||
|
idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2))
|
||||||
x1, x2 = maximum(x1, x2), minimum(x1, x2)
|
x1, x2 = maximum(x1, x2), minimum(x1, x2)
|
||||||
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1)))))
|
x = _where(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, _where(x1 == 0, lax._ones(x1), x1)))))
|
||||||
|
return _where(idx_inf, _lax_const(x, np.inf), x)
|
||||||
|
|
||||||
|
|
||||||
@implements(np.reciprocal, module='numpy')
|
@implements(np.reciprocal, module='numpy')
|
||||||
|
@ -124,6 +124,7 @@ from jax.experimental.array_api._elementwise_functions import (
|
|||||||
floor_divide as floor_divide,
|
floor_divide as floor_divide,
|
||||||
greater as greater,
|
greater as greater,
|
||||||
greater_equal as greater_equal,
|
greater_equal as greater_equal,
|
||||||
|
hypot as hypot,
|
||||||
imag as imag,
|
imag as imag,
|
||||||
isfinite as isfinite,
|
isfinite as isfinite,
|
||||||
isinf as isinf,
|
isinf as isinf,
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
from jax._src.dtypes import issubdtype
|
||||||
from jax.experimental.array_api._data_type_functions import (
|
from jax.experimental.array_api._data_type_functions import (
|
||||||
result_type as _result_type,
|
result_type as _result_type,
|
||||||
isdtype as _isdtype,
|
isdtype as _isdtype,
|
||||||
@ -210,6 +211,20 @@ def greater_equal(x1, x2, /):
|
|||||||
return jax.numpy.greater_equal(x1, x2)
|
return jax.numpy.greater_equal(x1, x2)
|
||||||
|
|
||||||
|
|
||||||
|
def hypot(x1, x2, /):
|
||||||
|
"""Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
|
||||||
|
x1, x2 = _promote_dtypes("hypot", x1, x2)
|
||||||
|
|
||||||
|
# TODO(micky774): Remove when jnp.hypot deprecation is completed
|
||||||
|
# (began 2024-4-14) and default behavior is Array API 2023 compliant
|
||||||
|
if issubdtype(x1.dtype, jax.numpy.complexfloating):
|
||||||
|
raise ValueError(
|
||||||
|
"hypot does not support complex-valued inputs. Please convert to real "
|
||||||
|
"values first, such as by using jnp.real or jnp.imag to take the real "
|
||||||
|
"or imaginary components respectively.")
|
||||||
|
return jax.numpy.hypot(x1, x2)
|
||||||
|
|
||||||
|
|
||||||
def imag(x, /):
|
def imag(x, /):
|
||||||
"""Returns the imaginary component of a complex number for each element x_i of the input array x."""
|
"""Returns the imaginary component of a complex number for each element x_i of the input array x."""
|
||||||
x, = _promote_dtypes("imag", x)
|
x, = _promote_dtypes("imag", x)
|
||||||
|
@ -88,6 +88,7 @@ MAIN_NAMESPACE = {
|
|||||||
'full_like',
|
'full_like',
|
||||||
'greater',
|
'greater',
|
||||||
'greater_equal',
|
'greater_equal',
|
||||||
|
'hypot',
|
||||||
'iinfo',
|
'iinfo',
|
||||||
'imag',
|
'imag',
|
||||||
'inf',
|
'inf',
|
||||||
|
@ -221,7 +221,7 @@ JAX_COMPOUND_OP_RECORDS = [
|
|||||||
op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
|
op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
|
||||||
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
|
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
|
||||||
inexact=True),
|
inexact=True),
|
||||||
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
|
op_record("hypot", 2, real_dtypes, all_shapes, jtu.rand_default, [],
|
||||||
inexact=True),
|
inexact=True),
|
||||||
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
|
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
|
||||||
op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
|
op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
|
||||||
|
@ -914,6 +914,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
jnp.clip(x, max=jnp.array([-1+5j]))
|
jnp.clip(x, max=jnp.array([-1+5j]))
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(micky774): Check for ValueError instead of DeprecationWarning when
|
||||||
|
# jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is
|
||||||
|
# Array API 2023 compliant
|
||||||
|
@jtu.sample_product(shape=all_shapes)
|
||||||
|
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||||
|
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
|
||||||
|
def testHypotComplexInputDeprecation(self, shape):
|
||||||
|
rng = jtu.rand_default(self.rng())
|
||||||
|
x = rng(shape, dtype=jnp.complex64)
|
||||||
|
msg = "Passing complex-valued inputs to hypot"
|
||||||
|
# jit is disabled so we don't miss warnings due to caching.
|
||||||
|
with jax.disable_jit():
|
||||||
|
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||||
|
jnp.hypot(x, x)
|
||||||
|
|
||||||
|
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||||
|
y = jnp.ones_like(x)
|
||||||
|
jnp.hypot(x, y)
|
||||||
|
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
[dict(shape=shape, dtype=dtype)
|
[dict(shape=shape, dtype=dtype)
|
||||||
for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],
|
for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user