Fixed hypot bug on nan/inf pairings, began deprecation of non-real values

This commit is contained in:
Meekail Zain 2024-04-15 17:56:16 +00:00
parent 51352fa05c
commit 2899213efb
7 changed files with 57 additions and 6 deletions

View File

@ -30,6 +30,9 @@ Remember to align the itemized text with the first line of an item within a list
* In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
now leads to an error rather than a warning.
* The minimum jaxlib version is now 0.4.23.
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
passing complex-valued inputs to it. This will raise an error when the
deprecation is completed.
## jaxlib 0.4.27

View File

@ -22,6 +22,7 @@ from functools import partial
import operator
from textwrap import dedent
from typing import Any, Callable, overload
import warnings
import numpy as np
@ -730,12 +731,22 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@implements(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("hypot", x1, x2)
x1, x2 = promote_dtypes_inexact(x1, x2)
x1 = lax.abs(x1)
x2 = lax.abs(x2)
x1, x2 = promote_args_inexact("hypot", x1, x2)
# TODO(micky774): Promote to ValueError when deprecation is complete
# (began 2024-4-14).
if dtypes.issubdtype(x1.dtype, np.complexfloating):
warnings.warn(
"Passing complex-valued inputs to hypot is deprecated and will raise a "
"ValueError in the future. Please convert to real values first, such as "
"by using jnp.real or jnp.imag to take the real or imaginary components "
"respectively.",
DeprecationWarning, stacklevel=2)
x1, x2 = lax.abs(x1), lax.abs(x2)
idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2))
x1, x2 = maximum(x1, x2), minimum(x1, x2)
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1)))))
x = _where(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, _where(x1 == 0, lax._ones(x1), x1)))))
return _where(idx_inf, _lax_const(x, np.inf), x)
@implements(np.reciprocal, module='numpy')

View File

@ -124,6 +124,7 @@ from jax.experimental.array_api._elementwise_functions import (
floor_divide as floor_divide,
greater as greater,
greater_equal as greater_equal,
hypot as hypot,
imag as imag,
isfinite as isfinite,
isinf as isinf,

View File

@ -13,6 +13,7 @@
# limitations under the License.
import jax
from jax._src.dtypes import issubdtype
from jax.experimental.array_api._data_type_functions import (
result_type as _result_type,
isdtype as _isdtype,
@ -210,6 +211,20 @@ def greater_equal(x1, x2, /):
return jax.numpy.greater_equal(x1, x2)
def hypot(x1, x2, /):
"""Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("hypot", x1, x2)
# TODO(micky774): Remove when jnp.hypot deprecation is completed
# (began 2024-4-14) and default behavior is Array API 2023 compliant
if issubdtype(x1.dtype, jax.numpy.complexfloating):
raise ValueError(
"hypot does not support complex-valued inputs. Please convert to real "
"values first, such as by using jnp.real or jnp.imag to take the real "
"or imaginary components respectively.")
return jax.numpy.hypot(x1, x2)
def imag(x, /):
"""Returns the imaginary component of a complex number for each element x_i of the input array x."""
x, = _promote_dtypes("imag", x)

View File

@ -88,6 +88,7 @@ MAIN_NAMESPACE = {
'full_like',
'greater',
'greater_equal',
'hypot',
'iinfo',
'imag',
'inf',

View File

@ -221,7 +221,7 @@ JAX_COMPOUND_OP_RECORDS = [
op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
inexact=True),
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
op_record("hypot", 2, real_dtypes, all_shapes, jtu.rand_default, [],
inexact=True),
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),

View File

@ -914,6 +914,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.clip(x, max=jnp.array([-1+5j]))
# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testHypotComplexInputDeprecation(self, shape):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Passing complex-valued inputs to hypot"
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.hypot(x, x)
with self.assertWarns(DeprecationWarning, msg=msg):
y = jnp.ones_like(x)
jnp.hypot(x, y)
@jtu.sample_product(
[dict(shape=shape, dtype=dtype)
for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],