mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #20754 from Micky774:array-api-hypot
PiperOrigin-RevId: 625035601
This commit is contained in:
commit
5f22b12576
@ -30,6 +30,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
|
||||
now leads to an error rather than a warning.
|
||||
* The minimum jaxlib version is now 0.4.23.
|
||||
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
|
||||
passing complex-valued inputs to it. This will raise an error when the
|
||||
deprecation is completed.
|
||||
|
||||
## jaxlib 0.4.27
|
||||
|
||||
|
@ -22,6 +22,7 @@ from functools import partial
|
||||
import operator
|
||||
from textwrap import dedent
|
||||
from typing import Any, Callable, overload
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -730,12 +731,22 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
@implements(np.hypot, module='numpy')
|
||||
@jit
|
||||
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("hypot", x1, x2)
|
||||
x1, x2 = promote_dtypes_inexact(x1, x2)
|
||||
x1 = lax.abs(x1)
|
||||
x2 = lax.abs(x2)
|
||||
x1, x2 = promote_args_inexact("hypot", x1, x2)
|
||||
|
||||
# TODO(micky774): Promote to ValueError when deprecation is complete
|
||||
# (began 2024-4-14).
|
||||
if dtypes.issubdtype(x1.dtype, np.complexfloating):
|
||||
warnings.warn(
|
||||
"Passing complex-valued inputs to hypot is deprecated and will raise a "
|
||||
"ValueError in the future. Please convert to real values first, such as "
|
||||
"by using jnp.real or jnp.imag to take the real or imaginary components "
|
||||
"respectively.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
x1, x2 = lax.abs(x1), lax.abs(x2)
|
||||
idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2))
|
||||
x1, x2 = maximum(x1, x2), minimum(x1, x2)
|
||||
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1)))))
|
||||
x = _where(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, _where(x1 == 0, lax._ones(x1), x1)))))
|
||||
return _where(idx_inf, _lax_const(x, np.inf), x)
|
||||
|
||||
|
||||
@implements(np.reciprocal, module='numpy')
|
||||
|
@ -125,6 +125,7 @@ from jax.experimental.array_api._elementwise_functions import (
|
||||
floor_divide as floor_divide,
|
||||
greater as greater,
|
||||
greater_equal as greater_equal,
|
||||
hypot as hypot,
|
||||
imag as imag,
|
||||
isfinite as isfinite,
|
||||
isinf as isinf,
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import jax
|
||||
from jax._src.dtypes import issubdtype
|
||||
from jax.experimental.array_api._data_type_functions import (
|
||||
result_type as _result_type,
|
||||
isdtype as _isdtype,
|
||||
@ -214,6 +215,20 @@ def greater_equal(x1, x2, /):
|
||||
return jax.numpy.greater_equal(x1, x2)
|
||||
|
||||
|
||||
def hypot(x1, x2, /):
|
||||
"""Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
|
||||
x1, x2 = _promote_dtypes("hypot", x1, x2)
|
||||
|
||||
# TODO(micky774): Remove when jnp.hypot deprecation is completed
|
||||
# (began 2024-4-14) and default behavior is Array API 2023 compliant
|
||||
if issubdtype(x1.dtype, jax.numpy.complexfloating):
|
||||
raise ValueError(
|
||||
"hypot does not support complex-valued inputs. Please convert to real "
|
||||
"values first, such as by using jnp.real or jnp.imag to take the real "
|
||||
"or imaginary components respectively.")
|
||||
return jax.numpy.hypot(x1, x2)
|
||||
|
||||
|
||||
def imag(x, /):
|
||||
"""Returns the imaginary component of a complex number for each element x_i of the input array x."""
|
||||
x, = _promote_dtypes("imag", x)
|
||||
|
@ -89,6 +89,7 @@ MAIN_NAMESPACE = {
|
||||
'full_like',
|
||||
'greater',
|
||||
'greater_equal',
|
||||
'hypot',
|
||||
'iinfo',
|
||||
'imag',
|
||||
'inf',
|
||||
|
@ -221,7 +221,7 @@ JAX_COMPOUND_OP_RECORDS = [
|
||||
op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
|
||||
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
|
||||
inexact=True),
|
||||
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
|
||||
op_record("hypot", 2, real_dtypes, all_shapes, jtu.rand_default, [],
|
||||
inexact=True),
|
||||
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
|
||||
op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),
|
||||
|
@ -914,6 +914,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp.clip(x, max=jnp.array([-1+5j]))
|
||||
|
||||
|
||||
# TODO(micky774): Check for ValueError instead of DeprecationWarning when
|
||||
# jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is
|
||||
# Array API 2023 compliant
|
||||
@jtu.sample_product(shape=all_shapes)
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
|
||||
def testHypotComplexInputDeprecation(self, shape):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype=jnp.complex64)
|
||||
msg = "Passing complex-valued inputs to hypot"
|
||||
# jit is disabled so we don't miss warnings due to caching.
|
||||
with jax.disable_jit():
|
||||
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||
jnp.hypot(x, x)
|
||||
|
||||
with self.assertWarns(DeprecationWarning, msg=msg):
|
||||
y = jnp.ones_like(x)
|
||||
jnp.hypot(x, y)
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, dtype=dtype)
|
||||
for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user