mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Implement jax.numpy.spacing
Somehow we've missed this numpy API up until now.
This commit is contained in:
parent
5800070c36
commit
635e29a0b9
@ -376,6 +376,7 @@ namespace; they are listed below.
|
|||||||
size
|
size
|
||||||
sort
|
sort
|
||||||
sort_complex
|
sort_complex
|
||||||
|
spacing
|
||||||
split
|
split
|
||||||
sqrt
|
sqrt
|
||||||
square
|
square
|
||||||
|
@ -1451,6 +1451,50 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array:
|
|||||||
"""
|
"""
|
||||||
return lax.nextafter(*promote_args_inexact("nextafter", x, y))
|
return lax.nextafter(*promote_args_inexact("nextafter", x, y))
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, inline=True)
|
||||||
|
def spacing(x: ArrayLike, /) -> Array:
|
||||||
|
"""Return the spacing between ``x`` and the next adjacent number.
|
||||||
|
|
||||||
|
JAX implementation of :func:`numpy.spacing`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: real-valued array. Integer or boolean types will be cast to float.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array of same shape as ``x`` containing spacing between each entry of
|
||||||
|
``x`` and its closest adjacent value.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.numpy.nextafter`: find the next representable value.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0], dtype='float32')
|
||||||
|
>>> jnp.spacing(x)
|
||||||
|
Array([1.4012985e-45, 2.9802322e-08, 5.9604645e-08, 5.9604645e-08,
|
||||||
|
1.1920929e-07], dtype=float32)
|
||||||
|
|
||||||
|
For ``x = 1``, the spacing is equal to the ``eps`` value given by
|
||||||
|
:class:`jax.numpy.finfo`:
|
||||||
|
|
||||||
|
>>> x = jnp.float32(1)
|
||||||
|
>>> jnp.spacing(x) == jnp.finfo(x.dtype).eps
|
||||||
|
Array(True, dtype=bool)
|
||||||
|
"""
|
||||||
|
arr, = promote_args_inexact("spacing", x)
|
||||||
|
if dtypes.isdtype(arr.dtype, "complex floating"):
|
||||||
|
raise ValueError("jnp.spacing is not defined for complex inputs.")
|
||||||
|
inf = _lax_const(arr, np.inf)
|
||||||
|
smallest_subnormal = dtypes.finfo(arr.dtype).smallest_subnormal
|
||||||
|
|
||||||
|
# Numpy's behavior seems to depend on dtype
|
||||||
|
if arr.dtype == 'float16':
|
||||||
|
return lax.nextafter(arr, inf) - arr
|
||||||
|
else:
|
||||||
|
result = lax.nextafter(arr, copysign(inf, arr)) - arr
|
||||||
|
return _where(result == 0, copysign(smallest_subnormal, arr), result)
|
||||||
|
|
||||||
|
|
||||||
# Logical ops
|
# Logical ops
|
||||||
@partial(jit, inline=True)
|
@partial(jit, inline=True)
|
||||||
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||||
|
@ -443,6 +443,7 @@ from jax._src.numpy.ufuncs import (
|
|||||||
sin as sin,
|
sin as sin,
|
||||||
sinc as sinc,
|
sinc as sinc,
|
||||||
sinh as sinh,
|
sinh as sinh,
|
||||||
|
spacing as spacing,
|
||||||
sqrt as sqrt,
|
sqrt as sqrt,
|
||||||
square as square,
|
square as square,
|
||||||
subtract as subtract,
|
subtract as subtract,
|
||||||
|
@ -808,6 +808,7 @@ def sort(
|
|||||||
order: None = ...,
|
order: None = ...,
|
||||||
) -> Array: ...
|
) -> Array: ...
|
||||||
def sort_complex(a: ArrayLike) -> Array: ...
|
def sort_complex(a: ArrayLike) -> Array: ...
|
||||||
|
def spacing(x: ArrayLike, /) -> Array: ...
|
||||||
def split(
|
def split(
|
||||||
ary: ArrayLike,
|
ary: ArrayLike,
|
||||||
indices_or_sections: int | Sequence[int] | ArrayLike,
|
indices_or_sections: int | Sequence[int] | ArrayLike,
|
||||||
|
@ -126,6 +126,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
|||||||
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
|
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
|
||||||
op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
|
op_record("nextafter", 2, [f for f in float_dtypes if f != jnp.bfloat16],
|
||||||
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
|
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
|
||||||
|
op_record("spacing", 1, float_dtypes, all_shapes, jtu.rand_default, ["rev"],
|
||||||
|
inexact=True, tolerance=0),
|
||||||
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
op_record("not_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||||
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||||
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
op_record("array_equiv", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||||
@ -701,6 +703,34 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
|
|||||||
dx = jax.grad(jax.numpy.i0)(0.0)
|
dx = jax.grad(jax.numpy.i0)(0.0)
|
||||||
self.assertArraysEqual(dx, 0.0)
|
self.assertArraysEqual(dx, 0.0)
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
shape=all_shapes,
|
||||||
|
dtype=default_dtypes,
|
||||||
|
)
|
||||||
|
def testSpacingIntegerInputs(self, shape, dtype):
|
||||||
|
rng = jtu.rand_int(self.rng(), low=-64, high=64)
|
||||||
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
|
computation_dtype = jnp.spacing(rng(shape, dtype)).dtype
|
||||||
|
np_func = lambda x: np.spacing(np.array(x).astype(computation_dtype))
|
||||||
|
self._CheckAgainstNumpy(np_func, jnp.spacing, args_maker, check_dtypes=True, tol=0)
|
||||||
|
self._CompileAndCheck(jnp.spacing, args_maker, tol=0)
|
||||||
|
|
||||||
|
@jtu.sample_product(dtype = float_dtypes)
|
||||||
|
@jtu.skip_on_devices("tpu")
|
||||||
|
def testSpacingSubnormals(self, dtype):
|
||||||
|
zero = np.array(0, dtype=dtype)
|
||||||
|
inf = np.array(np.inf, dtype=dtype)
|
||||||
|
x = [zero]
|
||||||
|
for i in range(5):
|
||||||
|
x.append(np.nextafter(x[-1], -inf)) # negative denormals
|
||||||
|
x = x[::-1]
|
||||||
|
for i in range(5):
|
||||||
|
x.append(np.nextafter(x[-1], inf)) # positive denormals
|
||||||
|
x = np.array(x, dtype=dtype)
|
||||||
|
args_maker = lambda: [x]
|
||||||
|
self._CheckAgainstNumpy(np.spacing, jnp.spacing, args_maker, check_dtypes=True, tol=0)
|
||||||
|
self._CompileAndCheck(jnp.spacing, args_maker, tol=0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user