Merge pull request #5731 from terhorst:master

PiperOrigin-RevId: 358951861
This commit is contained in:
jax authors 2021-02-22 18:53:38 -08:00
commit 9df3454ee0
4 changed files with 41 additions and 1 deletions

View File

@ -283,6 +283,7 @@ Not every function in NumPy is implemented; contributions are welcome!
piecewise
polyadd
polyder
polyint
polymul
polysub
polyval

View File

@ -3484,6 +3484,26 @@ def polyadd(a1, a2):
return a2.at[-a1.shape[0]:].add(a1)
@_wraps(np.polyint)
def polyint(p, m=1, k=None):
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
p = asarray(p)
if m < 0:
raise ValueError("Order of integral must be positive (see polyder)")
if k is None:
k = zeros(m)
k = atleast_1d(k)
if len(k) == 1:
k = full((m,), k[0])
if len(k) != m or k.ndim > 1:
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
if m == 0:
return p
else:
coeff = maximum(1, arange(len(p) + m, 0, -1) - 1 - arange(m)[:, newaxis]).prod(0)
return true_divide(concatenate((p, k)), coeff)
@_wraps(np.polyder)
def polyder(p, m=1):
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")

View File

@ -51,7 +51,7 @@ from jax._src.numpy.lax_numpy import (
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
negative, newaxis, nextafter, nonzero, not_equal, number, numpy_version,
object_, ones, ones_like, operator_name, outer, packbits, pad, percentile,
pi, piecewise, polyadd, polyder, polymul, polysub, polyval, positive, power,
pi, piecewise, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
prod, product, promote_types, ptp, quantile,
rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
result_type, right_shift, rint, roll, rollaxis, rot90, round, row_stack,

View File

@ -1674,6 +1674,25 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_order={}_k={}".format(
jtu.format_shape_dtype_string(a_shape, dtype),
order, k),
"dtype": dtype, "a_shape": a_shape, "order" : order, "k": k}
for dtype in default_dtypes
for a_shape in one_dim_array_shapes
for order in range(5)
for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]
))
def testPolyInt(self, a_shape, order, k, dtype):
rng = jtu.rand_default(self.rng())
np_fun = lambda arg1: np.polyint(arg1, m=order, k=k)
jnp_fun = lambda arg1: jnp.polyint(arg1, m=order, k=k)
args_maker = lambda: [rng(a_shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_order={}".format(
jtu.format_shape_dtype_string(a_shape, dtype),