mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Implement jax.numpy.nextafter. (#1845)
This commit is contained in:
parent
c63bfca200
commit
3a07c69d0c
@ -90,6 +90,7 @@ Operators
|
||||
mul
|
||||
ne
|
||||
neg
|
||||
nextafter
|
||||
pad
|
||||
pow
|
||||
real
|
||||
|
@ -178,6 +178,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
nanprod
|
||||
nansum
|
||||
negative
|
||||
nextafter
|
||||
not_equal
|
||||
ones
|
||||
ones_like
|
||||
|
@ -118,6 +118,10 @@ def sign(x):
|
||||
"""
|
||||
return sign_p.bind(x)
|
||||
|
||||
def nextafter(x1, x2):
|
||||
r"""Returns the next representable value after `x1` in the direction of `x2`."""
|
||||
return nextafter_p.bind(_brcast(x1, x2), _brcast(x2, x1))
|
||||
|
||||
def floor(x):
|
||||
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
|
||||
return floor_p.bind(x)
|
||||
@ -184,13 +188,13 @@ def digamma(x):
|
||||
|
||||
def bessel_i0e(x):
|
||||
r"""Exponentially scaled modified Bessel function of order 0:
|
||||
:math:`\mathrm{i0e}(x) = e^{-\mathrm{abs}(x)} \mathrm{i0}(x)`
|
||||
:math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)`
|
||||
"""
|
||||
return bessel_i0e_p.bind(x)
|
||||
|
||||
def bessel_i1e(x):
|
||||
r"""Exponentially scaled modified Bessel function of order 1:
|
||||
:math:`\mathrm{i1e}(x) = e^{-\mathrm{abs}(x)} \mathrm{i1}(x)`
|
||||
:math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)`
|
||||
"""
|
||||
return bessel_i1e_p.bind(x)
|
||||
|
||||
@ -1622,6 +1626,10 @@ ad.deflinear(neg_p, lambda t: [neg(t)])
|
||||
sign_p = standard_unop(_num, 'sign')
|
||||
ad.defjvp_zero(sign_p)
|
||||
|
||||
nextafter_p = standard_binop(
|
||||
[_float, _float], 'nextafter',
|
||||
translation_rule=lambda c, x1, x2: c.NextAfter(x1, x2))
|
||||
|
||||
floor_p = standard_unop(_float, 'floor')
|
||||
ad.defjvp_zero(floor_p)
|
||||
|
||||
|
@ -37,6 +37,7 @@ sign = onp.sign
|
||||
floor = onp.floor
|
||||
ceil = onp.ceil
|
||||
round = onp.round
|
||||
nextafter = onp.nextafter
|
||||
|
||||
is_finite = onp.isfinite
|
||||
|
||||
|
@ -427,6 +427,7 @@ arctan2 = _one_to_one_binop(onp.arctan2, lax.atan2, True)
|
||||
minimum = _one_to_one_binop(onp.minimum, lax.min)
|
||||
maximum = _one_to_one_binop(onp.maximum, lax.max)
|
||||
float_power = _one_to_one_binop(onp.float_power, lax.pow, True)
|
||||
nextafter = _one_to_one_binop(onp.nextafter, lax.nextafter, True)
|
||||
|
||||
|
||||
def _comparison_op(numpy_fn, lax_fn):
|
||||
|
@ -123,6 +123,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
|
||||
op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf, []),
|
||||
op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
|
||||
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
|
||||
op_record("nextafter", 2, [f for f in float_dtypes if f != lnp.bfloat16],
|
||||
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
|
||||
op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
|
||||
op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),
|
||||
|
@ -81,6 +81,8 @@ LAX_OPS = [
|
||||
op_record("floor", 1, float_dtypes, jtu.rand_small),
|
||||
op_record("ceil", 1, float_dtypes, jtu.rand_small),
|
||||
op_record("round", 1, float_dtypes, jtu.rand_default),
|
||||
op_record("nextafter", 2, [f for f in float_dtypes if f != dtypes.bfloat16],
|
||||
jtu.rand_default, tol=0),
|
||||
|
||||
op_record("is_finite", 1, float_dtypes, jtu.rand_small),
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user