Fixed hypot bug on nan/inf pairings, began deprecation of non-real values

This commit is contained in:
Meekail Zain 2024-04-15 17:56:16 +00:00
parent 51352fa05c
commit 2899213efb
7 changed files with 57 additions and 6 deletions

View File

@ -30,6 +30,9 @@ Remember to align the itemized text with the first line of an item within a list
* In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames` * In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
now leads to an error rather than a warning. now leads to an error rather than a warning.
* The minimum jaxlib version is now 0.4.23. * The minimum jaxlib version is now 0.4.23.
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
passing complex-valued inputs to it. This will raise an error when the
deprecation is completed.
## jaxlib 0.4.27 ## jaxlib 0.4.27

View File

@ -22,6 +22,7 @@ from functools import partial
import operator import operator
from textwrap import dedent from textwrap import dedent
from typing import Any, Callable, overload from typing import Any, Callable, overload
import warnings
import numpy as np import numpy as np
@ -730,12 +731,22 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@implements(np.hypot, module='numpy') @implements(np.hypot, module='numpy')
@jit @jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("hypot", x1, x2) x1, x2 = promote_args_inexact("hypot", x1, x2)
x1, x2 = promote_dtypes_inexact(x1, x2)
x1 = lax.abs(x1) # TODO(micky774): Promote to ValueError when deprecation is complete
x2 = lax.abs(x2) # (began 2024-4-14).
if dtypes.issubdtype(x1.dtype, np.complexfloating):
warnings.warn(
"Passing complex-valued inputs to hypot is deprecated and will raise a "
"ValueError in the future. Please convert to real values first, such as "
"by using jnp.real or jnp.imag to take the real or imaginary components "
"respectively.",
DeprecationWarning, stacklevel=2)
x1, x2 = lax.abs(x1), lax.abs(x2)
idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2))
x1, x2 = maximum(x1, x2), minimum(x1, x2) x1, x2 = maximum(x1, x2), minimum(x1, x2)
return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1))))) x = _where(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, _where(x1 == 0, lax._ones(x1), x1)))))
return _where(idx_inf, _lax_const(x, np.inf), x)
@implements(np.reciprocal, module='numpy') @implements(np.reciprocal, module='numpy')

View File

@ -124,6 +124,7 @@ from jax.experimental.array_api._elementwise_functions import (
floor_divide as floor_divide, floor_divide as floor_divide,
greater as greater, greater as greater,
greater_equal as greater_equal, greater_equal as greater_equal,
hypot as hypot,
imag as imag, imag as imag,
isfinite as isfinite, isfinite as isfinite,
isinf as isinf, isinf as isinf,

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import jax import jax
from jax._src.dtypes import issubdtype
from jax.experimental.array_api._data_type_functions import ( from jax.experimental.array_api._data_type_functions import (
result_type as _result_type, result_type as _result_type,
isdtype as _isdtype, isdtype as _isdtype,
@ -210,6 +211,20 @@ def greater_equal(x1, x2, /):
return jax.numpy.greater_equal(x1, x2) return jax.numpy.greater_equal(x1, x2)
def hypot(x1, x2, /):
"""Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("hypot", x1, x2)
# TODO(micky774): Remove when jnp.hypot deprecation is completed
# (began 2024-4-14) and default behavior is Array API 2023 compliant
if issubdtype(x1.dtype, jax.numpy.complexfloating):
raise ValueError(
"hypot does not support complex-valued inputs. Please convert to real "
"values first, such as by using jnp.real or jnp.imag to take the real "
"or imaginary components respectively.")
return jax.numpy.hypot(x1, x2)
def imag(x, /): def imag(x, /):
"""Returns the imaginary component of a complex number for each element x_i of the input array x.""" """Returns the imaginary component of a complex number for each element x_i of the input array x."""
x, = _promote_dtypes("imag", x) x, = _promote_dtypes("imag", x)

View File

@ -88,6 +88,7 @@ MAIN_NAMESPACE = {
'full_like', 'full_like',
'greater', 'greater',
'greater_equal', 'greater_equal',
'hypot',
'iinfo', 'iinfo',
'imag', 'imag',
'inf', 'inf',

View File

@ -221,7 +221,7 @@ JAX_COMPOUND_OP_RECORDS = [
op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []), op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []),
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
inexact=True), inexact=True),
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], op_record("hypot", 2, real_dtypes, all_shapes, jtu.rand_default, [],
inexact=True), inexact=True),
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []), op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []),
op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []),

View File

@ -914,6 +914,26 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.clip(x, max=jnp.array([-1+5j])) jnp.clip(x, max=jnp.array([-1+5j]))
# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testHypotComplexInputDeprecation(self, shape):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Passing complex-valued inputs to hypot"
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.hypot(x, x)
with self.assertWarns(DeprecationWarning, msg=msg):
y = jnp.ones_like(x)
jnp.hypot(x, y)
@jtu.sample_product( @jtu.sample_product(
[dict(shape=shape, dtype=dtype) [dict(shape=shape, dtype=dtype)
for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)], for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],