mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Implement jax.numpy.spacing
Somehow we've missed this numpy API up until now.
This commit is contained in:
parent
5800070c36
commit
635e29a0b9
@ -376,6 +376,7 @@ namespace; they are listed below.
|
||||
size
|
||||
sort
|
||||
sort_complex
|
||||
spacing
|
||||
split
|
||||
sqrt
|
||||
square
|
||||
|
@ -1451,6 +1451,50 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
"""
|
||||
return lax.nextafter(*promote_args_inexact("nextafter", x, y))
|
||||
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def spacing(x: ArrayLike, /) -> Array:
|
||||
"""Return the spacing between ``x`` and the next adjacent number.
|
||||
|
||||
JAX implementation of :func:`numpy.spacing`.
|
||||
|
||||
Args:
|
||||
x: real-valued array. Integer or boolean types will be cast to float.
|
||||
|
||||
Returns:
|
||||
Array of same shape as ``x`` containing spacing between each entry of
|
||||
``x`` and its closest adjacent value.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.nextafter`: find the next representable value.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0], dtype='float32')
|
||||
>>> jnp.spacing(x)
|
||||
Array([1.4012985e-45, 2.9802322e-08, 5.9604645e-08, 5.9604645e-08,
|
||||
1.1920929e-07], dtype=float32)
|
||||
|
||||
For ``x = 1``, the spacing is equal to the ``eps`` value given by
|
||||
:class:`jax.numpy.finfo`:
|
||||
|
||||
>>> x = jnp.float32(1)
|
||||
>>> jnp.spacing(x) == jnp.finfo(x.dtype).eps
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
arr, = promote_args_inexact("spacing", x)
|
||||
if dtypes.isdtype(arr.dtype, "complex floating"):
|
||||
raise ValueError("jnp.spacing is not defined for complex inputs.")
|
||||
inf = _lax_const(arr, np.inf)
|
||||
smallest_subnormal = dtypes.finfo(arr.dtype).smallest_subnormal
|
||||
|
||||
# Numpy's behavior seems to depend on dtype
|
||||
if arr.dtype == 'float16':
|
||||
return lax.nextafter(arr, inf) - arr
|
||||
else:
|
||||
result = lax.nextafter(arr, copysign(inf, arr)) - arr
|
||||
return _where(result == 0, copysign(smallest_subnormal, arr), result)
|
||||
|
||||
|
||||
# Logical ops
|
||||
@partial(jit, inline=True)
|
||||
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
@ -443,6 +443,7 @@ from jax._src.numpy.ufuncs import (
|
||||
sin as sin,
|
||||
sinc as sinc,
|
||||
sinh as sinh,
|
||||
spacing as spacing,
|
||||
sqrt as sqrt,
|
||||
square as square,
|
||||
subtract as subtract,
|
||||
|
@ -808,6 +808,7 @@ def sort(
|
||||
order: None = ...,
|
||||
) -> Array: ...
|
||||
def sort_complex(a: ArrayLike) -> Array: ...
|
||||
def spacing(x: ArrayLike, /) -> Array: ...
|
||||
def split(
|
||||
ary: ArrayLike,
|
||||
indices_or_sections: int | Sequence[int] | ArrayLike,
|
||||
|
@ -126,6 +126,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
||||
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
|
||||
op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
|
||||
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
|
||||
op_record("spacing", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"],
|
||||
inexact=True, tolerance=0),
|
||||
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||
@ -701,6 +703,34 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
|
||||
dx = jax.grad(jax.numpy.i0)(0.0)
|
||||
self.assertArraysEqual(dx, 0.0)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
def testSpacingIntegerInputs(self, shape, dtype):
|
||||
rng = jtu.rand_int(self.rng(), low=-64, high=64)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
computation_dtype = jnp.spacing(rng(shape, dtype)).dtype
|
||||
np_func = lambda x: np.spacing(np.array(x).astype(computation_dtype))
|
||||
self._CheckAgainstNumpy(np_func, jnp.spacing, args_maker, check_dtypes=True, tol=0)
|
||||
self._CompileAndCheck(jnp.spacing, args_maker, tol=0)
|
||||
|
||||
@jtu.sample_product(dtype = float_dtypes)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def testSpacingSubnormals(self, dtype):
|
||||
zero = np.array(0, dtype=dtype)
|
||||
inf = np.array(np.inf, dtype=dtype)
|
||||
x = [zero]
|
||||
for i in range(5):
|
||||
x.append(np.nextafter(x[-1], -inf)) # negative denormals
|
||||
x = x[::-1]
|
||||
for i in range(5):
|
||||
x.append(np.nextafter(x[-1], inf)) # positive denormals
|
||||
x = np.array(x, dtype=dtype)
|
||||
args_maker = lambda: [x]
|
||||
self._CheckAgainstNumpy(np.spacing, jnp.spacing, args_maker, check_dtypes=True, tol=0)
|
||||
self._CompileAndCheck(jnp.spacing, args_maker, tol=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user