mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #5731 from terhorst:master
PiperOrigin-RevId: 358951861
This commit is contained in:
commit
9df3454ee0
@ -283,6 +283,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
piecewise
|
||||
polyadd
|
||||
polyder
|
||||
polyint
|
||||
polymul
|
||||
polysub
|
||||
polyval
|
||||
|
@ -3484,6 +3484,26 @@ def polyadd(a1, a2):
|
||||
return a2.at[-a1.shape[0]:].add(a1)
|
||||
|
||||
|
||||
@_wraps(np.polyint)
|
||||
def polyint(p, m=1, k=None):
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
|
||||
p = asarray(p)
|
||||
if m < 0:
|
||||
raise ValueError("Order of integral must be positive (see polyder)")
|
||||
if k is None:
|
||||
k = zeros(m)
|
||||
k = atleast_1d(k)
|
||||
if len(k) == 1:
|
||||
k = full((m,), k[0])
|
||||
if len(k) != m or k.ndim > 1:
|
||||
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
|
||||
if m == 0:
|
||||
return p
|
||||
else:
|
||||
coeff = maximum(1, arange(len(p) + m, 0, -1) - 1 - arange(m)[:, newaxis]).prod(0)
|
||||
return true_divide(concatenate((p, k)), coeff)
|
||||
|
||||
|
||||
@_wraps(np.polyder)
|
||||
def polyder(p, m=1):
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
|
||||
|
@ -51,7 +51,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
|
||||
negative, newaxis, nextafter, nonzero, not_equal, number, numpy_version,
|
||||
object_, ones, ones_like, operator_name, outer, packbits, pad, percentile,
|
||||
pi, piecewise, polyadd, polyder, polymul, polysub, polyval, positive, power,
|
||||
pi, piecewise, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
|
||||
prod, product, promote_types, ptp, quantile,
|
||||
rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
|
||||
result_type, right_shift, rint, roll, rollaxis, rot90, round, row_stack,
|
||||
|
@ -1674,6 +1674,25 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_order={}_k={}".format(
|
||||
jtu.format_shape_dtype_string(a_shape, dtype),
|
||||
order, k),
|
||||
"dtype": dtype, "a_shape": a_shape, "order" : order, "k": k}
|
||||
for dtype in default_dtypes
|
||||
for a_shape in one_dim_array_shapes
|
||||
for order in range(5)
|
||||
for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]
|
||||
))
|
||||
def testPolyInt(self, a_shape, order, k, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np_fun = lambda arg1: np.polyint(arg1, m=order, k=k)
|
||||
jnp_fun = lambda arg1: jnp.polyint(arg1, m=order, k=k)
|
||||
args_maker = lambda: [rng(a_shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_order={}".format(
|
||||
jtu.format_shape_dtype_string(a_shape, dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user