mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #10226 from ljjsalt:add-polydiv
PiperOrigin-RevId: 441548874
This commit is contained in:
commit
191c83816c
@ -305,6 +305,7 @@ namespace; they are listed below.
|
||||
poly
|
||||
polyadd
|
||||
polyder
|
||||
polydiv
|
||||
polyfit
|
||||
polyint
|
||||
polymul
|
||||
|
@ -2528,6 +2528,17 @@ def trim_zeros(filt, trim='fb'):
|
||||
return filt[start:len(filt) - end]
|
||||
|
||||
|
||||
def trim_zeros_tol(filt, tol, trim='fb'):
|
||||
filt = core.concrete_or_error(asarray, filt,
|
||||
"Error arose in the `filt` argument of trim_zeros_tol()")
|
||||
nz = (abs(filt) < tol)
|
||||
if all(nz):
|
||||
return empty(0, _dtype(filt))
|
||||
start = argmin(nz) if 'f' in trim.lower() else 0
|
||||
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
|
||||
return filt[start:len(filt) - end]
|
||||
|
||||
|
||||
@_wraps(np.append)
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def append(arr, values, axis: Optional[int] = None):
|
||||
|
@ -21,7 +21,7 @@ from jax import jit
|
||||
from jax import lax
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, finfo,
|
||||
full, hstack, maximum, ones, outer, sqrt, trim_zeros, true_divide, vander, zeros)
|
||||
full, hstack, maximum, ones, outer, sqrt, trim_zeros, trim_zeros_tol, true_divide, vander, zeros)
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _wraps
|
||||
import numpy as np
|
||||
@ -277,6 +277,9 @@ def polyder(p, m=1):
|
||||
_LEADING_ZEROS_DOC = """\
|
||||
Setting trim_leading_zeros=True makes the output match that of numpy.
|
||||
But prevents the function from being able to be used in compiled code.
|
||||
Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be
|
||||
considered zero may lead to inconsistent results between NumPy and JAX, and even between different
|
||||
JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True.
|
||||
"""
|
||||
|
||||
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
|
||||
@ -292,6 +295,23 @@ def polymul(a1, a2, *, trim_leading_zeros=False):
|
||||
val = convolve(a1, a2, mode='full')
|
||||
return val
|
||||
|
||||
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polydiv(u, v, *, trim_leading_zeros=False):
|
||||
_check_arraylike("polydiv", u, v)
|
||||
u, v = _promote_dtypes_inexact(u, v)
|
||||
m = len(u) - 1
|
||||
n = len(v) - 1
|
||||
scale = 1. / v[0]
|
||||
q = zeros(max(m - n + 1, 1), dtype = u.dtype) # force same dtype
|
||||
for k in range(0, m-n+1):
|
||||
d = scale * u[k]
|
||||
q = q.at[k].set(d)
|
||||
u = u.at[k:k+n+1].add(-d*v)
|
||||
if trim_leading_zeros:
|
||||
# use the square root of finfo(dtype) to approximate the absolute tolerance used in numpy
|
||||
return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f')
|
||||
else:
|
||||
return q, u
|
||||
|
||||
@_wraps(np.polysub)
|
||||
@jit
|
||||
|
@ -298,6 +298,7 @@ from jax._src.numpy.polynomial import (
|
||||
poly as poly,
|
||||
polyadd as polyadd,
|
||||
polyder as polyder,
|
||||
polydiv as polydiv,
|
||||
polyfit as polyfit,
|
||||
polyint as polyint,
|
||||
polymul as polymul,
|
||||
|
@ -2874,6 +2874,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "a_shape={} , b_shape={}".format(
|
||||
jtu.format_shape_dtype_string(a_shape, dtype),
|
||||
jtu.format_shape_dtype_string(b_shape, dtype)),
|
||||
"dtype": dtype, "a_shape": a_shape, "b_shape" : b_shape}
|
||||
for dtype in default_dtypes
|
||||
for a_shape in one_dim_array_shapes
|
||||
for b_shape in one_dim_array_shapes))
|
||||
def testPolyDiv(self, a_shape, b_shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
||||
def np_fun(arg1, arg2):
|
||||
q, r = np.polydiv(arg1, arg2)
|
||||
while r.size < max(arg1.size, arg2.size): # Pad residual to same size
|
||||
r = np.pad(r, (1, 0), 'constant')
|
||||
return q, r
|
||||
|
||||
def jnp_fun(arg1, arg2):
|
||||
q, r = jnp.polydiv(arg1, arg2, trim_leading_zeros=True)
|
||||
while r.size < max(arg1.size, arg2.size): # Pad residual to same size
|
||||
r = jnp.pad(r, (1, 0), 'constant')
|
||||
return q, r
|
||||
|
||||
args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)]
|
||||
tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13}
|
||||
|
||||
jnp_compile = jnp.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
|
||||
self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),
|
||||
|
Loading…
x
Reference in New Issue
Block a user