mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Updated jnp.ceil/floor/trunc to preserve int dtypes
Description: - Updated jnp.ceil/floor/trunc to preserve int dtypes - Updated tests - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
This commit is contained in:
parent
a4c92a454b
commit
70b4823348
@ -10,6 +10,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
* Changes
|
||||
* The minimum NumPy version is now 1.24.
|
||||
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
|
||||
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
|
||||
|
||||
## jaxlib 0.4.31
|
||||
|
||||
|
@ -376,6 +376,8 @@ def result_type(*args: Any) -> DType:
|
||||
@jit
|
||||
def trunc(x: ArrayLike) -> Array:
|
||||
util.check_arraylike('trunc', x)
|
||||
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
|
||||
return lax_internal.asarray(x)
|
||||
return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x))
|
||||
|
||||
|
||||
|
@ -92,11 +92,17 @@ def sign(x: ArrayLike, /) -> Array:
|
||||
@implements(np.floor, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def floor(x: ArrayLike, /) -> Array:
|
||||
check_arraylike('floor', x)
|
||||
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
|
||||
return lax.asarray(x)
|
||||
return lax.floor(*promote_args_inexact('floor', x))
|
||||
|
||||
@implements(np.ceil, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def ceil(x: ArrayLike, /) -> Array:
|
||||
check_arraylike('ceil', x)
|
||||
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
|
||||
return lax.asarray(x)
|
||||
return lax.ceil(*promote_args_inexact('ceil', x))
|
||||
|
||||
@implements(np.exp, module='numpy')
|
||||
|
@ -66,6 +66,7 @@ from jax.numpy import (
|
||||
broadcast_arrays as broadcast_arrays,
|
||||
broadcast_to as broadcast_to,
|
||||
can_cast as can_cast,
|
||||
ceil as ceil,
|
||||
complex128 as complex128,
|
||||
complex64 as complex64,
|
||||
concat as concat,
|
||||
@ -85,6 +86,7 @@ from jax.numpy import (
|
||||
flip as flip,
|
||||
float32 as float32,
|
||||
float64 as float64,
|
||||
floor as floor,
|
||||
floor_divide as floor_divide,
|
||||
from_dlpack as from_dlpack,
|
||||
full as full,
|
||||
@ -160,6 +162,7 @@ from jax.numpy import (
|
||||
tile as tile,
|
||||
tril as tril,
|
||||
triu as triu,
|
||||
trunc as trunc,
|
||||
uint16 as uint16,
|
||||
uint32 as uint32,
|
||||
uint64 as uint64,
|
||||
@ -192,11 +195,8 @@ from jax.experimental.array_api._data_type_functions import (
|
||||
)
|
||||
|
||||
from jax.experimental.array_api._elementwise_functions import (
|
||||
ceil as ceil,
|
||||
clip as clip,
|
||||
floor as floor,
|
||||
hypot as hypot,
|
||||
trunc as trunc,
|
||||
)
|
||||
|
||||
from jax.experimental.array_api._statistical_functions import (
|
||||
|
@ -18,15 +18,6 @@ from jax._src.dtypes import issubdtype
|
||||
from jax._src.numpy.util import promote_args
|
||||
|
||||
|
||||
# TODO(micky774): Update jnp.ceil to preserve integral dtype
|
||||
def ceil(x, /):
|
||||
"""Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i."""
|
||||
x, = promote_args("ceil", x)
|
||||
if isdtype(x.dtype, "integral"):
|
||||
return x
|
||||
return jax.numpy.ceil(x)
|
||||
|
||||
|
||||
# TODO(micky774): Remove when jnp.clip deprecation is completed
|
||||
# (began 2024-4-2) and default behavior is Array API 2023 compliant
|
||||
def clip(x, /, min=None, max=None):
|
||||
@ -43,15 +34,6 @@ def clip(x, /, min=None, max=None):
|
||||
return jax.numpy.clip(x, min=min, max=max)
|
||||
|
||||
|
||||
# TODO(micky774): Update jnp.floor to preserve integral dtype
|
||||
def floor(x, /):
|
||||
"""Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i."""
|
||||
x, = promote_args("floor", x)
|
||||
if isdtype(x.dtype, "integral"):
|
||||
return x
|
||||
return jax.numpy.floor(x)
|
||||
|
||||
|
||||
# TODO(micky774): Remove when jnp.hypot deprecation is completed
|
||||
# (began 2024-4-14) and default behavior is Array API 2023 compliant
|
||||
def hypot(x1, x2, /):
|
||||
@ -64,12 +46,3 @@ def hypot(x1, x2, /):
|
||||
"values first, such as by using jnp.real or jnp.imag to take the real "
|
||||
"or imaginary components respectively.")
|
||||
return jax.numpy.hypot(x1, x2)
|
||||
|
||||
|
||||
# TODO(micky774): Update jnp.trunc to preserve integral dtype
|
||||
def trunc(x, /):
|
||||
"""Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i."""
|
||||
x, = promote_args("trunc", x)
|
||||
if isdtype(x.dtype, "integral"):
|
||||
return x
|
||||
return jax.numpy.trunc(x)
|
||||
|
@ -243,6 +243,8 @@ class JetTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ceil(self): self.unary_check(jnp.ceil)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_trunc(self): self.unary_check(jnp.trunc)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_round(self): self.unary_check(lax.round)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_sign(self): self.unary_check(lax.sign)
|
||||
|
@ -1080,7 +1080,7 @@ class PallasOpsTest(PallasTest):
|
||||
[jnp.abs, jnp.negative],
|
||||
["int16", "int32", "int64", "float16", "float32", "float64"],
|
||||
),
|
||||
([jnp.ceil, jnp.floor], ["float32", "float64"]),
|
||||
([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]),
|
||||
(
|
||||
[jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
|
||||
["float16", "float32", "float64"],
|
||||
|
Loading…
x
Reference in New Issue
Block a user