Merge pull request #10226 from ljjsalt:add-polydiv

PiperOrigin-RevId: 441548874
This commit is contained in:
jax authors 2022-04-13 12:27:22 -07:00
commit 191c83816c
5 changed files with 65 additions and 1 deletions

View File

@ -305,6 +305,7 @@ namespace; they are listed below.
poly
polyadd
polyder
polydiv
polyfit
polyint
polymul

View File

@ -2528,6 +2528,17 @@ def trim_zeros(filt, trim='fb'):
return filt[start:len(filt) - end]
def trim_zeros_tol(filt, tol, trim='fb'):
filt = core.concrete_or_error(asarray, filt,
"Error arose in the `filt` argument of trim_zeros_tol()")
nz = (abs(filt) < tol)
if all(nz):
return empty(0, _dtype(filt))
start = argmin(nz) if 'f' in trim.lower() else 0
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
return filt[start:len(filt) - end]
@_wraps(np.append)
@partial(jit, static_argnames=('axis',))
def append(arr, values, axis: Optional[int] = None):

View File

@ -21,7 +21,7 @@ from jax import jit
from jax import lax
from jax._src.numpy.lax_numpy import (
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, finfo,
full, hstack, maximum, ones, outer, sqrt, trim_zeros, true_divide, vander, zeros)
full, hstack, maximum, ones, outer, sqrt, trim_zeros, trim_zeros_tol, true_divide, vander, zeros)
from jax._src.numpy import linalg
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _wraps
import numpy as np
@ -277,6 +277,9 @@ def polyder(p, m=1):
_LEADING_ZEROS_DOC = """\
Setting trim_leading_zeros=True makes the output match that of numpy.
But prevents the function from being able to be used in compiled code.
Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be
considered zero may lead to inconsistent results between NumPy and JAX, and even between different
JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True.
"""
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
@ -292,6 +295,23 @@ def polymul(a1, a2, *, trim_leading_zeros=False):
val = convolve(a1, a2, mode='full')
return val
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
def polydiv(u, v, *, trim_leading_zeros=False):
_check_arraylike("polydiv", u, v)
u, v = _promote_dtypes_inexact(u, v)
m = len(u) - 1
n = len(v) - 1
scale = 1. / v[0]
q = zeros(max(m - n + 1, 1), dtype = u.dtype) # force same dtype
for k in range(0, m-n+1):
d = scale * u[k]
q = q.at[k].set(d)
u = u.at[k:k+n+1].add(-d*v)
if trim_leading_zeros:
# use the square root of finfo(dtype) to approximate the absolute tolerance used in numpy
return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f')
else:
return q, u
@_wraps(np.polysub)
@jit

View File

@ -298,6 +298,7 @@ from jax._src.numpy.polynomial import (
poly as poly,
polyadd as polyadd,
polyder as polyder,
polydiv as polydiv,
polyfit as polyfit,
polyint as polyint,
polymul as polymul,

View File

@ -2874,6 +2874,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "a_shape={} , b_shape={}".format(
jtu.format_shape_dtype_string(a_shape, dtype),
jtu.format_shape_dtype_string(b_shape, dtype)),
"dtype": dtype, "a_shape": a_shape, "b_shape" : b_shape}
for dtype in default_dtypes
for a_shape in one_dim_array_shapes
for b_shape in one_dim_array_shapes))
def testPolyDiv(self, a_shape, b_shape, dtype):
rng = jtu.rand_default(self.rng())
def np_fun(arg1, arg2):
q, r = np.polydiv(arg1, arg2)
while r.size < max(arg1.size, arg2.size): # Pad residual to same size
r = np.pad(r, (1, 0), 'constant')
return q, r
def jnp_fun(arg1, arg2):
q, r = jnp.polydiv(arg1, arg2, trim_leading_zeros=True)
while r.size < max(arg1.size, arg2.size): # Pad residual to same size
r = jnp.pad(r, (1, 0), 'constant')
return q, r
args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)]
tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13}
jnp_compile = jnp.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),