Updated jnp.ceil/floor/trunc to preserve int dtypes

Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
  - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
This commit is contained in:
vfdev-5 2024-05-27 10:51:13 +00:00
parent a4c92a454b
commit 70b4823348
7 changed files with 16 additions and 31 deletions

View File

@ -10,6 +10,8 @@ Remember to align the itemized text with the first line of an item within a list
* Changes
* The minimum NumPy version is now 1.24.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
## jaxlib 0.4.31

View File

@ -376,6 +376,8 @@ def result_type(*args: Any) -> DType:
@jit
def trunc(x: ArrayLike) -> Array:
util.check_arraylike('trunc', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax_internal.asarray(x)
return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x))

View File

@ -92,11 +92,17 @@ def sign(x: ArrayLike, /) -> Array:
@implements(np.floor, module='numpy')
@partial(jit, inline=True)
def floor(x: ArrayLike, /) -> Array:
check_arraylike('floor', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax.asarray(x)
return lax.floor(*promote_args_inexact('floor', x))
@implements(np.ceil, module='numpy')
@partial(jit, inline=True)
def ceil(x: ArrayLike, /) -> Array:
check_arraylike('ceil', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax.asarray(x)
return lax.ceil(*promote_args_inexact('ceil', x))
@implements(np.exp, module='numpy')

View File

@ -66,6 +66,7 @@ from jax.numpy import (
broadcast_arrays as broadcast_arrays,
broadcast_to as broadcast_to,
can_cast as can_cast,
ceil as ceil,
complex128 as complex128,
complex64 as complex64,
concat as concat,
@ -85,6 +86,7 @@ from jax.numpy import (
flip as flip,
float32 as float32,
float64 as float64,
floor as floor,
floor_divide as floor_divide,
from_dlpack as from_dlpack,
full as full,
@ -160,6 +162,7 @@ from jax.numpy import (
tile as tile,
tril as tril,
triu as triu,
trunc as trunc,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
@ -192,11 +195,8 @@ from jax.experimental.array_api._data_type_functions import (
)
from jax.experimental.array_api._elementwise_functions import (
ceil as ceil,
clip as clip,
floor as floor,
hypot as hypot,
trunc as trunc,
)
from jax.experimental.array_api._statistical_functions import (

View File

@ -18,15 +18,6 @@ from jax._src.dtypes import issubdtype
from jax._src.numpy.util import promote_args
# TODO(micky774): Update jnp.ceil to preserve integral dtype
def ceil(x, /):
"""Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i."""
x, = promote_args("ceil", x)
if isdtype(x.dtype, "integral"):
return x
return jax.numpy.ceil(x)
# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
def clip(x, /, min=None, max=None):
@ -43,15 +34,6 @@ def clip(x, /, min=None, max=None):
return jax.numpy.clip(x, min=min, max=max)
# TODO(micky774): Update jnp.floor to preserve integral dtype
def floor(x, /):
"""Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i."""
x, = promote_args("floor", x)
if isdtype(x.dtype, "integral"):
return x
return jax.numpy.floor(x)
# TODO(micky774): Remove when jnp.hypot deprecation is completed
# (began 2024-4-14) and default behavior is Array API 2023 compliant
def hypot(x1, x2, /):
@ -64,12 +46,3 @@ def hypot(x1, x2, /):
"values first, such as by using jnp.real or jnp.imag to take the real "
"or imaginary components respectively.")
return jax.numpy.hypot(x1, x2)
# TODO(micky774): Update jnp.trunc to preserve integral dtype
def trunc(x, /):
"""Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i."""
x, = promote_args("trunc", x)
if isdtype(x.dtype, "integral"):
return x
return jax.numpy.trunc(x)

View File

@ -243,6 +243,8 @@ class JetTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def test_ceil(self): self.unary_check(jnp.ceil)
@jtu.skip_on_devices("tpu")
def test_trunc(self): self.unary_check(jnp.trunc)
@jtu.skip_on_devices("tpu")
def test_round(self): self.unary_check(lax.round)
@jtu.skip_on_devices("tpu")
def test_sign(self): self.unary_check(lax.sign)

View File

@ -1080,7 +1080,7 @@ class PallasOpsTest(PallasTest):
[jnp.abs, jnp.negative],
["int16", "int32", "int64", "float16", "float32", "float64"],
),
([jnp.ceil, jnp.floor], ["float32", "float64"]),
([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]),
(
[jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
["float16", "float32", "float64"],