From 70b48233489fc3275e56c372e85ba632f9e21951 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 27 May 2024 10:51:13 +0000 Subject: [PATCH] Updated jnp.ceil/floor/trunc to preserve int dtypes Description: - Updated jnp.ceil/floor/trunc to preserve int dtypes - Updated tests - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case --- CHANGELOG.md | 2 ++ jax/_src/numpy/lax_numpy.py | 2 ++ jax/_src/numpy/ufuncs.py | 6 +++++ jax/experimental/array_api/__init__.py | 6 ++--- .../array_api/_elementwise_functions.py | 27 ------------------- tests/jet_test.py | 2 ++ tests/pallas/pallas_test.py | 2 +- 7 files changed, 16 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa15035fe..5f8aa043a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ Remember to align the itemized text with the first line of an item within a list * Changes * The minimum NumPy version is now 1.24. + * {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output + of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point. ## jaxlib 0.4.31 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6b6671449..eeccd5863 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -376,6 +376,8 @@ def result_type(*args: Any) -> DType: @jit def trunc(x: ArrayLike) -> Array: util.check_arraylike('trunc', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax_internal.asarray(x) return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index ccc448e56..25467be2a 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -92,11 +92,17 @@ def sign(x: ArrayLike, /) -> Array: @implements(np.floor, module='numpy') @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: + check_arraylike('floor', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax.asarray(x) return lax.floor(*promote_args_inexact('floor', x)) @implements(np.ceil, module='numpy') @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: + check_arraylike('ceil', x) + if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): + return lax.asarray(x) return lax.ceil(*promote_args_inexact('ceil', x)) @implements(np.exp, module='numpy') diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index f7375a80f..4240b9d51 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -66,6 +66,7 @@ from jax.numpy import ( broadcast_arrays as broadcast_arrays, broadcast_to as broadcast_to, can_cast as can_cast, + ceil as ceil, complex128 as complex128, complex64 as complex64, concat as concat, @@ -85,6 +86,7 @@ from jax.numpy import ( flip as flip, float32 as float32, float64 as float64, + floor as floor, floor_divide as floor_divide, from_dlpack as from_dlpack, full as full, @@ -160,6 +162,7 @@ from jax.numpy import ( tile as tile, tril as tril, triu as triu, + trunc as trunc, uint16 as uint16, uint32 as uint32, uint64 as uint64, @@ -192,11 +195,8 @@ from jax.experimental.array_api._data_type_functions import ( ) from jax.experimental.array_api._elementwise_functions import ( - ceil as ceil, clip as clip, - floor as floor, hypot as hypot, - trunc as trunc, ) from jax.experimental.array_api._statistical_functions import ( diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index 5587b9a60..103f8ab7d 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -18,15 +18,6 @@ from jax._src.dtypes import issubdtype from jax._src.numpy.util import promote_args -# TODO(micky774): Update jnp.ceil to preserve integral dtype -def ceil(x, /): - """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i.""" - x, = promote_args("ceil", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.ceil(x) - - # TODO(micky774): Remove when jnp.clip deprecation is completed # (began 2024-4-2) and default behavior is Array API 2023 compliant def clip(x, /, min=None, max=None): @@ -43,15 +34,6 @@ def clip(x, /, min=None, max=None): return jax.numpy.clip(x, min=min, max=max) -# TODO(micky774): Update jnp.floor to preserve integral dtype -def floor(x, /): - """Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i.""" - x, = promote_args("floor", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.floor(x) - - # TODO(micky774): Remove when jnp.hypot deprecation is completed # (began 2024-4-14) and default behavior is Array API 2023 compliant def hypot(x1, x2, /): @@ -64,12 +46,3 @@ def hypot(x1, x2, /): "values first, such as by using jnp.real or jnp.imag to take the real " "or imaginary components respectively.") return jax.numpy.hypot(x1, x2) - - -# TODO(micky774): Update jnp.trunc to preserve integral dtype -def trunc(x, /): - """Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i.""" - x, = promote_args("trunc", x) - if isdtype(x.dtype, "integral"): - return x - return jax.numpy.trunc(x) diff --git a/tests/jet_test.py b/tests/jet_test.py index 79132174e..b1e2ef3f8 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -243,6 +243,8 @@ class JetTest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") def test_ceil(self): self.unary_check(jnp.ceil) @jtu.skip_on_devices("tpu") + def test_trunc(self): self.unary_check(jnp.trunc) + @jtu.skip_on_devices("tpu") def test_round(self): self.unary_check(lax.round) @jtu.skip_on_devices("tpu") def test_sign(self): self.unary_check(lax.sign) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 39d638caf..5c659ee08 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1080,7 +1080,7 @@ class PallasOpsTest(PallasTest): [jnp.abs, jnp.negative], ["int16", "int32", "int64", "float16", "float32", "float64"], ), - ([jnp.ceil, jnp.floor], ["float32", "float64"]), + ([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]), ( [jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt], ["float16", "float32", "float64"],