Implement jax.numpy.spacing

Somehow we've missed this numpy API up until now.
This commit is contained in:
Jake VanderPlas 2024-10-03 10:40:39 -07:00
parent 5800070c36
commit 635e29a0b9
5 changed files with 77 additions and 0 deletions

View File

@ -376,6 +376,7 @@ namespace; they are listed below.
size size
sort sort
sort_complex sort_complex
spacing
split split
sqrt sqrt
square square

View File

@ -1451,6 +1451,50 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
""" """
return lax.nextafter(*promote_args_inexact("nextafter", x, y)) return lax.nextafter(*promote_args_inexact("nextafter", x, y))
@partial(jit, inline=True)
def spacing(x: ArrayLike, /) -> Array:
"""Return the spacing between ``x`` and the next adjacent number.
JAX implementation of :func:`numpy.spacing`.
Args:
x: real-valued array. Integer or boolean types will be cast to float.
Returns:
Array of same shape as ``x`` containing spacing between each entry of
``x`` and its closest adjacent value.
See also:
- :func:`jax.numpy.nextafter`: find the next representable value.
Examples:
>>> x = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0], dtype='float32')
>>> jnp.spacing(x)
Array([1.4012985e-45, 2.9802322e-08, 5.9604645e-08, 5.9604645e-08,
1.1920929e-07], dtype=float32)
For ``x = 1``, the spacing is equal to the ``eps`` value given by
:class:`jax.numpy.finfo`:
>>> x = jnp.float32(1)
>>> jnp.spacing(x) == jnp.finfo(x.dtype).eps
Array(True, dtype=bool)
"""
arr, = promote_args_inexact("spacing", x)
if dtypes.isdtype(arr.dtype, "complex floating"):
raise ValueError("jnp.spacing is not defined for complex inputs.")
inf = _lax_const(arr, np.inf)
smallest_subnormal = dtypes.finfo(arr.dtype).smallest_subnormal
# Numpy's behavior seems to depend on dtype
if arr.dtype == 'float16':
return lax.nextafter(arr, inf) - arr
else:
result = lax.nextafter(arr, copysign(inf, arr)) - arr
return _where(result == 0, copysign(smallest_subnormal, arr), result)
# Logical ops # Logical ops
@partial(jit, inline=True) @partial(jit, inline=True)
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:

View File

@ -443,6 +443,7 @@ from jax._src.numpy.ufuncs import (
sin as sin, sin as sin,
sinc as sinc, sinc as sinc,
sinh as sinh, sinh as sinh,
spacing as spacing,
sqrt as sqrt, sqrt as sqrt,
square as square, square as square,
subtract as subtract, subtract as subtract,

View File

@ -808,6 +808,7 @@ def sort(
order: None = ..., order: None = ...,
) -> Array: ... ) -> Array: ...
def sort_complex(a: ArrayLike) -> Array: ... def sort_complex(a: ArrayLike) -> Array: ...
def spacing(x: ArrayLike, /) -> Array: ...
def split( def split(
ary: ArrayLike, ary: ArrayLike,
indices_or_sections: int | Sequence[int] | ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike,

View File

@ -126,6 +126,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16], op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0), all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
op_record("spacing", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True, tolerance=0),
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
@ -701,6 +703,34 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
dx = jax.grad(jax.numpy.i0)(0.0) dx = jax.grad(jax.numpy.i0)(0.0)
self.assertArraysEqual(dx, 0.0) self.assertArraysEqual(dx, 0.0)
@jtu.sample_product(
shape=all_shapes,
dtype=default_dtypes,
)
def testSpacingIntegerInputs(self, shape, dtype):
rng = jtu.rand_int(self.rng(), low=-64, high=64)
args_maker = lambda: [rng(shape, dtype)]
computation_dtype = jnp.spacing(rng(shape, dtype)).dtype
np_func = lambda x: np.spacing(np.array(x).astype(computation_dtype))
self._CheckAgainstNumpy(np_func, jnp.spacing, args_maker, check_dtypes=True, tol=0)
self._CompileAndCheck(jnp.spacing, args_maker, tol=0)
@jtu.sample_product(dtype = float_dtypes)
@jtu.skip_on_devices("tpu")
def testSpacingSubnormals(self, dtype):
zero = np.array(0, dtype=dtype)
inf = np.array(np.inf, dtype=dtype)
x = [zero]
for i in range(5):
x.append(np.nextafter(x[-1], -inf)) # negative denormals
x = x[::-1]
for i in range(5):
x.append(np.nextafter(x[-1], inf)) # positive denormals
x = np.array(x, dtype=dtype)
args_maker = lambda: [x]
self._CheckAgainstNumpy(np.spacing, jnp.spacing, args_maker, check_dtypes=True, tol=0)
self._CompileAndCheck(jnp.spacing, args_maker, tol=0)
if __name__ == "__main__": if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader()) absltest.main(testLoader=jtu.JaxTestLoader())