mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #6777 from paul-tqh-nguyen:numpy_poly
PiperOrigin-RevId: 374935887
This commit is contained in:
commit
5cd416a3f2
@ -286,6 +286,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
pad
|
||||
percentile
|
||||
piecewise
|
||||
poly
|
||||
polyadd
|
||||
polyder
|
||||
polyint
|
||||
@ -456,4 +457,4 @@ instantiate :class:`DeviceArray`s manually, but rather will create them via
|
||||
|
||||
.. autoclass:: jaxlib.xla_extension.DeviceArray
|
||||
:members:
|
||||
:inherited-members:
|
||||
:inherited-members:
|
||||
|
@ -3847,6 +3847,43 @@ def diagflat(v, k=0):
|
||||
res = res.reshape(adj_length,adj_length)
|
||||
return res
|
||||
|
||||
_POLY_DOC="""\
|
||||
This differs from np.poly when an integer array is given.
|
||||
np.poly returns a result with dtype float64 in this case.
|
||||
jax returns a result with an inexact type, but not necessarily
|
||||
float64.
|
||||
|
||||
This also differs from np.poly when the input array strictly
|
||||
contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
|
||||
np.poly returns an array with a real dtype in such cases.
|
||||
jax returns an array with a complex dtype in such cases.
|
||||
"""
|
||||
|
||||
@_wraps(np.poly, lax_description=_POLY_DOC)
|
||||
def poly(seq_of_zeros):
|
||||
_check_arraylike('poly', seq_of_zeros)
|
||||
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
|
||||
seq_of_zeros = atleast_1d(seq_of_zeros)
|
||||
|
||||
sh = seq_of_zeros.shape
|
||||
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
|
||||
# import at runtime to avoid circular import
|
||||
from . import linalg
|
||||
seq_of_zeros = linalg.eigvals(seq_of_zeros)
|
||||
|
||||
if seq_of_zeros.ndim != 1:
|
||||
raise ValueError("input must be 1d or non-empty square 2d array.")
|
||||
|
||||
dt = seq_of_zeros.dtype
|
||||
if len(seq_of_zeros) == 0:
|
||||
return ones((), dtype=dt)
|
||||
|
||||
a = ones((1,), dtype=dt)
|
||||
for k in range(len(seq_of_zeros)):
|
||||
a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')
|
||||
|
||||
return a
|
||||
|
||||
|
||||
@_wraps(np.polyval)
|
||||
def polyval(p, x):
|
||||
|
@ -51,7 +51,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
|
||||
negative, newaxis, nextafter, nonzero, not_equal, number,
|
||||
object_, ogrid, ones, ones_like, operator_name, outer, packbits, pad, percentile,
|
||||
pi, piecewise, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
|
||||
pi, piecewise, poly, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
|
||||
prod, product, promote_types, ptp, quantile,
|
||||
r_, rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
|
||||
result_type, right_shift, rint, roll, rollaxis, rot90, round, row_stack,
|
||||
|
@ -1752,6 +1752,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_rank{}".format(
|
||||
jtu.format_shape_dtype_string(a_shape, dtype), rank),
|
||||
"dtype": dtype, "a_shape": a_shape, "rank": rank}
|
||||
for rank in (1, 2)
|
||||
for dtype in default_dtypes
|
||||
for a_shape in one_dim_array_shapes))
|
||||
def testPoly(self, a_shape, dtype, rank):
|
||||
if dtype in (np.float16, jnp.bfloat16, np.int16):
|
||||
self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
|
||||
elif rank == 2 and jtu.device_under_test() in ("tpu", "gpu"):
|
||||
self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
tol = { np.int8: 1e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 }
|
||||
if jtu.device_under_test() == "tpu":
|
||||
tol[np.int32] = tol[np.float32] = 1e-1
|
||||
tol = jtu.tolerance(dtype, tol)
|
||||
args_maker = lambda: [rng(a_shape * rank, dtype)]
|
||||
self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol, atol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "a_shape={} , b_shape={}".format(
|
||||
jtu.format_shape_dtype_string(a_shape, dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user