Merge pull request #6777 from paul-tqh-nguyen:numpy_poly

PiperOrigin-RevId: 374935887
This commit is contained in:
jax authors 2021-05-20 13:13:09 -07:00
commit 5cd416a3f2
4 changed files with 61 additions and 2 deletions

View File

@ -286,6 +286,7 @@ Not every function in NumPy is implemented; contributions are welcome!
pad
percentile
piecewise
poly
polyadd
polyder
polyint
@ -456,4 +457,4 @@ instantiate :class:`DeviceArray`s manually, but rather will create them via
.. autoclass:: jaxlib.xla_extension.DeviceArray
:members:
:inherited-members:
:inherited-members:

View File

@ -3847,6 +3847,43 @@ def diagflat(v, k=0):
res = res.reshape(adj_length,adj_length)
return res
_POLY_DOC="""\
This differs from np.poly when an integer array is given.
np.poly returns a result with dtype float64 in this case.
jax returns a result with an inexact type, but not necessarily
float64.
This also differs from np.poly when the input array strictly
contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
np.poly returns an array with a real dtype in such cases.
jax returns an array with a complex dtype in such cases.
"""
@_wraps(np.poly, lax_description=_POLY_DOC)
def poly(seq_of_zeros):
_check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros = atleast_1d(seq_of_zeros)
sh = seq_of_zeros.shape
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
# import at runtime to avoid circular import
from . import linalg
seq_of_zeros = linalg.eigvals(seq_of_zeros)
if seq_of_zeros.ndim != 1:
raise ValueError("input must be 1d or non-empty square 2d array.")
dt = seq_of_zeros.dtype
if len(seq_of_zeros) == 0:
return ones((), dtype=dt)
a = ones((1,), dtype=dt)
for k in range(len(seq_of_zeros)):
a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')
return a
@_wraps(np.polyval)
def polyval(p, x):

View File

@ -51,7 +51,7 @@ from jax._src.numpy.lax_numpy import (
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
negative, newaxis, nextafter, nonzero, not_equal, number,
object_, ogrid, ones, ones_like, operator_name, outer, packbits, pad, percentile,
pi, piecewise, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
pi, piecewise, poly, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
prod, product, promote_types, ptp, quantile,
r_, rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
result_type, right_shift, rint, roll, rollaxis, rot90, round, row_stack,

View File

@ -1752,6 +1752,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_rank{}".format(
jtu.format_shape_dtype_string(a_shape, dtype), rank),
"dtype": dtype, "a_shape": a_shape, "rank": rank}
for rank in (1, 2)
for dtype in default_dtypes
for a_shape in one_dim_array_shapes))
def testPoly(self, a_shape, dtype, rank):
if dtype in (np.float16, jnp.bfloat16, np.int16):
self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
elif rank == 2 and jtu.device_under_test() in ("tpu", "gpu"):
self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.")
rng = jtu.rand_default(self.rng())
tol = { np.int8: 1e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 }
if jtu.device_under_test() == "tpu":
tol[np.int32] = tol[np.float32] = 1e-1
tol = jtu.tolerance(dtype, tol)
args_maker = lambda: [rng(a_shape * rank, dtype)]
self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol, atol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "a_shape={} , b_shape={}".format(
jtu.format_shape_dtype_string(a_shape, dtype),