diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index d233641ee..34c1cf85d 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -305,6 +305,7 @@ namespace; they are listed below. poly polyadd polyder + polydiv polyfit polyint polymul diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6d906ec71..f428671aa 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2543,6 +2543,17 @@ def trim_zeros(filt, trim='fb'): return filt[start:len(filt) - end] +def trim_zeros_tol(filt, tol, trim='fb'): + filt = core.concrete_or_error(asarray, filt, + "Error arose in the `filt` argument of trim_zeros_tol()") + nz = (abs(filt) < tol) + if all(nz): + return empty(0, _dtype(filt)) + start = argmin(nz) if 'f' in trim.lower() else 0 + end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 + return filt[start:len(filt) - end] + + @_wraps(np.append) @partial(jit, static_argnames=('axis',)) def append(arr, values, axis: Optional[int] = None): diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 552ca002b..cf8c1803a 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -21,7 +21,7 @@ from jax import jit from jax import lax from jax._src.numpy.lax_numpy import ( all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, finfo, - full, hstack, maximum, ones, outer, sqrt, trim_zeros, true_divide, vander, zeros) + full, hstack, maximum, ones, outer, sqrt, trim_zeros, trim_zeros_tol, true_divide, vander, zeros) from jax._src.numpy import linalg from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _wraps import numpy as np @@ -277,6 +277,9 @@ def polyder(p, m=1): _LEADING_ZEROS_DOC = """\ Setting trim_leading_zeros=True makes the output match that of numpy. But prevents the function from being able to be used in compiled code. +Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be +considered zero may lead to inconsistent results between NumPy and JAX, and even between different +JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True. """ @_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC) @@ -292,6 +295,23 @@ def polymul(a1, a2, *, trim_leading_zeros=False): val = convolve(a1, a2, mode='full') return val +@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC) +def polydiv(u, v, *, trim_leading_zeros=False): + _check_arraylike("polydiv", u, v) + u, v = _promote_dtypes_inexact(u, v) + m = len(u) - 1 + n = len(v) - 1 + scale = 1. / v[0] + q = zeros(max(m - n + 1, 1), dtype = u.dtype) # force same dtype + for k in range(0, m-n+1): + d = scale * u[k] + q = q.at[k].set(d) + u = u.at[k:k+n+1].add(-d*v) + if trim_leading_zeros: + # use the square root of finfo(dtype) to approximate the absolute tolerance used in numpy + return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f') + else: + return q, u @_wraps(np.polysub) @jit diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 0f3fd8968..a2fad4020 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -298,6 +298,7 @@ from jax._src.numpy.polynomial import ( poly as poly, polyadd as polyadd, polyder as polyder, + polydiv as polydiv, polyfit as polyfit, polyint as polyint, polymul as polymul, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index e0e046bd5..a4d2a0c01 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2874,6 +2874,37 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "a_shape={} , b_shape={}".format( + jtu.format_shape_dtype_string(a_shape, dtype), + jtu.format_shape_dtype_string(b_shape, dtype)), + "dtype": dtype, "a_shape": a_shape, "b_shape" : b_shape} + for dtype in default_dtypes + for a_shape in one_dim_array_shapes + for b_shape in one_dim_array_shapes)) + def testPolyDiv(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + + def np_fun(arg1, arg2): + q, r = np.polydiv(arg1, arg2) + while r.size < max(arg1.size, arg2.size): # Pad residual to same size + r = np.pad(r, (1, 0), 'constant') + return q, r + + def jnp_fun(arg1, arg2): + q, r = jnp.polydiv(arg1, arg2, trim_leading_zeros=True) + while r.size < max(arg1.size, arg2.size): # Pad residual to same size + r = jnp.pad(r, (1, 0), 'constant') + return q, r + + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} + + jnp_compile = jnp.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),