Implement jax.numpy.spacing

Somehow we've missed this numpy API up until now.
This commit is contained in:
Jake VanderPlas 2024-10-03 10:40:39 -07:00
parent 5800070c36
commit 635e29a0b9
5 changed files with 77 additions and 0 deletions

View File

@ -376,6 +376,7 @@ namespace; they are listed below.
size
sort
sort_complex
spacing
split
sqrt
square

View File

@ -1451,6 +1451,50 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
"""
return lax.nextafter(*promote_args_inexact("nextafter", x, y))
@partial(jit, inline=True)
def spacing(x: ArrayLike, /) -> Array:
"""Return the spacing between ``x`` and the next adjacent number.
JAX implementation of :func:`numpy.spacing`.
Args:
x: real-valued array. Integer or boolean types will be cast to float.
Returns:
Array of same shape as ``x`` containing spacing between each entry of
``x`` and its closest adjacent value.
See also:
- :func:`jax.numpy.nextafter`: find the next representable value.
Examples:
>>> x = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0], dtype='float32')
>>> jnp.spacing(x)
Array([1.4012985e-45, 2.9802322e-08, 5.9604645e-08, 5.9604645e-08,
1.1920929e-07], dtype=float32)
For ``x = 1``, the spacing is equal to the ``eps`` value given by
:class:`jax.numpy.finfo`:
>>> x = jnp.float32(1)
>>> jnp.spacing(x) == jnp.finfo(x.dtype).eps
Array(True, dtype=bool)
"""
arr, = promote_args_inexact("spacing", x)
if dtypes.isdtype(arr.dtype, "complex floating"):
raise ValueError("jnp.spacing is not defined for complex inputs.")
inf = _lax_const(arr, np.inf)
smallest_subnormal = dtypes.finfo(arr.dtype).smallest_subnormal
# Numpy's behavior seems to depend on dtype
if arr.dtype == 'float16':
return lax.nextafter(arr, inf) - arr
else:
result = lax.nextafter(arr, copysign(inf, arr)) - arr
return _where(result == 0, copysign(smallest_subnormal, arr), result)
# Logical ops
@partial(jit, inline=True)
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:

View File

@ -443,6 +443,7 @@ from jax._src.numpy.ufuncs import (
sin as sin,
sinc as sinc,
sinh as sinh,
spacing as spacing,
sqrt as sqrt,
square as square,
subtract as subtract,

View File

@ -808,6 +808,7 @@ def sort(
order: None = ...,
) -> Array: ...
def sort_complex(a: ArrayLike) -> Array: ...
def spacing(x: ArrayLike, /) -> Array: ...
def split(
ary: ArrayLike,
indices_or_sections: int | Sequence[int] | ArrayLike,

View File

@ -126,6 +126,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
op_record("spacing", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True, tolerance=0),
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
@ -701,6 +703,34 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
dx = jax.grad(jax.numpy.i0)(0.0)
self.assertArraysEqual(dx, 0.0)
@jtu.sample_product(
shape=all_shapes,
dtype=default_dtypes,
)
def testSpacingIntegerInputs(self, shape, dtype):
rng = jtu.rand_int(self.rng(), low=-64, high=64)
args_maker = lambda: [rng(shape, dtype)]
computation_dtype = jnp.spacing(rng(shape, dtype)).dtype
np_func = lambda x: np.spacing(np.array(x).astype(computation_dtype))
self._CheckAgainstNumpy(np_func, jnp.spacing, args_maker, check_dtypes=True, tol=0)
self._CompileAndCheck(jnp.spacing, args_maker, tol=0)
@jtu.sample_product(dtype = float_dtypes)
@jtu.skip_on_devices("tpu")
def testSpacingSubnormals(self, dtype):
zero = np.array(0, dtype=dtype)
inf = np.array(np.inf, dtype=dtype)
x = [zero]
for i in range(5):
x.append(np.nextafter(x[-1], -inf)) # negative denormals
x = x[::-1]
for i in range(5):
x.append(np.nextafter(x[-1], inf)) # positive denormals
x = np.array(x, dtype=dtype)
args_maker = lambda: [x]
self._CheckAgainstNumpy(np.spacing, jnp.spacing, args_maker, check_dtypes=True, tol=0)
self._CompileAndCheck(jnp.spacing, args_maker, tol=0)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())