diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index d6b7d74bd..9eb518464 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -376,6 +376,7 @@ namespace; they are listed below. size sort sort_complex + spacing split sqrt square diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index abeb7b775..d7afab1e7 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -1451,6 +1451,50 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.nextafter(*promote_args_inexact("nextafter", x, y)) + +@partial(jit, inline=True) +def spacing(x: ArrayLike, /) -> Array: + """Return the spacing between ``x`` and the next adjacent number. + + JAX implementation of :func:`numpy.spacing`. + + Args: + x: real-valued array. Integer or boolean types will be cast to float. + + Returns: + Array of same shape as ``x`` containing spacing between each entry of + ``x`` and its closest adjacent value. + + See also: + - :func:`jax.numpy.nextafter`: find the next representable value. + + Examples: + >>> x = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0], dtype='float32') + >>> jnp.spacing(x) + Array([1.4012985e-45, 2.9802322e-08, 5.9604645e-08, 5.9604645e-08, + 1.1920929e-07], dtype=float32) + + For ``x = 1``, the spacing is equal to the ``eps`` value given by + :class:`jax.numpy.finfo`: + + >>> x = jnp.float32(1) + >>> jnp.spacing(x) == jnp.finfo(x.dtype).eps + Array(True, dtype=bool) + """ + arr, = promote_args_inexact("spacing", x) + if dtypes.isdtype(arr.dtype, "complex floating"): + raise ValueError("jnp.spacing is not defined for complex inputs.") + inf = _lax_const(arr, np.inf) + smallest_subnormal = dtypes.finfo(arr.dtype).smallest_subnormal + + # Numpy's behavior seems to depend on dtype + if arr.dtype == 'float16': + return lax.nextafter(arr, inf) - arr + else: + result = lax.nextafter(arr, copysign(inf, arr)) - arr + return _where(result == 0, copysign(smallest_subnormal, arr), result) + + # Logical ops @partial(jit, inline=True) def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 20c37c559..bd8068729 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -443,6 +443,7 @@ from jax._src.numpy.ufuncs import ( sin as sin, sinc as sinc, sinh as sinh, + spacing as spacing, sqrt as sqrt, square as square, subtract as subtract, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index fd2524bc3..0ea9b5ee7 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -808,6 +808,7 @@ def sort( order: None = ..., ) -> Array: ... def sort_complex(a: ArrayLike) -> Array: ... +def spacing(x: ArrayLike, /) -> Array: ... def split( ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 45a780c9f..744a99fb7 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -126,6 +126,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [ op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]), op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16], all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0), + op_record("spacing", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"], + inexact=True, tolerance=0), op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]), @@ -701,6 +703,34 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase): dx = jax.grad(jax.numpy.i0)(0.0) self.assertArraysEqual(dx, 0.0) + @jtu.sample_product( + shape=all_shapes, + dtype=default_dtypes, + ) + def testSpacingIntegerInputs(self, shape, dtype): + rng = jtu.rand_int(self.rng(), low=-64, high=64) + args_maker = lambda: [rng(shape, dtype)] + computation_dtype = jnp.spacing(rng(shape, dtype)).dtype + np_func = lambda x: np.spacing(np.array(x).astype(computation_dtype)) + self._CheckAgainstNumpy(np_func, jnp.spacing, args_maker, check_dtypes=True, tol=0) + self._CompileAndCheck(jnp.spacing, args_maker, tol=0) + + @jtu.sample_product(dtype = float_dtypes) + @jtu.skip_on_devices("tpu") + def testSpacingSubnormals(self, dtype): + zero = np.array(0, dtype=dtype) + inf = np.array(np.inf, dtype=dtype) + x = [zero] + for i in range(5): + x.append(np.nextafter(x[-1], -inf)) # negative denormals + x = x[::-1] + for i in range(5): + x.append(np.nextafter(x[-1], inf)) # positive denormals + x = np.array(x, dtype=dtype) + args_maker = lambda: [x] + self._CheckAgainstNumpy(np.spacing, jnp.spacing, args_maker, check_dtypes=True, tol=0) + self._CompileAndCheck(jnp.spacing, args_maker, tol=0) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())