Implement jax.numpy.nextafter. (#1845)

This commit is contained in:
Peter Hawkins 2019-12-11 16:41:24 -05:00 committed by GitHub
parent c63bfca200
commit 3a07c69d0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 18 additions and 2 deletions

View File

@ -90,6 +90,7 @@ Operators
mul
ne
neg
nextafter
pad
pow
real

View File

@ -178,6 +178,7 @@ Not every function in NumPy is implemented; contributions are welcome!
nanprod
nansum
negative
nextafter
not_equal
ones
ones_like

View File

@ -118,6 +118,10 @@ def sign(x):
"""
return sign_p.bind(x)
def nextafter(x1, x2):
r"""Returns the next representable value after `x1` in the direction of `x2`."""
return nextafter_p.bind(_brcast(x1, x2), _brcast(x2, x1))
def floor(x):
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
return floor_p.bind(x)
@ -184,13 +188,13 @@ def digamma(x):
def bessel_i0e(x):
r"""Exponentially scaled modified Bessel function of order 0:
:math:`\mathrm{i0e}(x) = e^{-\mathrm{abs}(x)} \mathrm{i0}(x)`
:math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)`
"""
return bessel_i0e_p.bind(x)
def bessel_i1e(x):
r"""Exponentially scaled modified Bessel function of order 1:
:math:`\mathrm{i1e}(x) = e^{-\mathrm{abs}(x)} \mathrm{i1}(x)`
:math:`\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)`
"""
return bessel_i1e_p.bind(x)
@ -1622,6 +1626,10 @@ ad.deflinear(neg_p, lambda t: [neg(t)])
sign_p = standard_unop(_num, 'sign')
ad.defjvp_zero(sign_p)
nextafter_p = standard_binop(
[_float, _float], 'nextafter',
translation_rule=lambda c, x1, x2: c.NextAfter(x1, x2))
floor_p = standard_unop(_float, 'floor')
ad.defjvp_zero(floor_p)

View File

@ -37,6 +37,7 @@ sign = onp.sign
floor = onp.floor
ceil = onp.ceil
round = onp.round
nextafter = onp.nextafter
is_finite = onp.isfinite

View File

@ -427,6 +427,7 @@ arctan2 = _one_to_one_binop(onp.arctan2, lax.atan2, True)
minimum = _one_to_one_binop(onp.minimum, lax.min)
maximum = _one_to_one_binop(onp.maximum, lax.max)
float_power = _one_to_one_binop(onp.float_power, lax.pow, True)
nextafter = _one_to_one_binop(onp.nextafter, lax.nextafter, True)
def _comparison_op(numpy_fn, lax_fn):

View File

@ -123,6 +123,8 @@ JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf, []),
op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"]),
op_record("nextafter", 2, [f for f in float_dtypes if f != lnp.bfloat16],
all_shapes, jtu.rand_default, ["rev"], inexact=True, tolerance=0),
op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("array_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal, ["rev"]),
op_record("reciprocal", 1, inexact_dtypes, all_shapes, jtu.rand_default, []),

View File

@ -81,6 +81,8 @@ LAX_OPS = [
op_record("floor", 1, float_dtypes, jtu.rand_small),
op_record("ceil", 1, float_dtypes, jtu.rand_small),
op_record("round", 1, float_dtypes, jtu.rand_default),
op_record("nextafter", 2, [f for f in float_dtypes if f != dtypes.bfloat16],
jtu.rand_default, tol=0),
op_record("is_finite", 1, float_dtypes, jtu.rand_small),