From 30b687c4863f86e7afce87333162a70e0ce66166 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 31 May 2022 14:22:49 -0700 Subject: [PATCH] [x64] make jnp.poly* functions work under strict dtype promotion --- jax/_src/numpy/polynomial.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index cf8c1803a..71ad4b01c 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -256,7 +256,9 @@ def polyint(p, m=1, k=None): if m == 0: return p else: - coeff = maximum(1, arange(len(p) + m, 0, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0) + grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis] + - arange(m, dtype=p.dtype)[:, np.newaxis]) + coeff = maximum(1, grid).prod(0)[::-1] return true_divide(concatenate((p, k)), coeff) @@ -270,8 +272,9 @@ def polyder(p, m=1): raise ValueError("Order of derivative must be positive") if m == 0: return p - coeff = (arange(len(p), m, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0) - return p[:-m] * coeff + coeff = (arange(m, len(p), dtype=p.dtype)[np.newaxis] + - arange(m, dtype=p.dtype)[:, np.newaxis]).prod(0) + return p[:-m] * coeff[::-1] _LEADING_ZEROS_DOC = """\ @@ -289,11 +292,10 @@ def polymul(a1, a2, *, trim_leading_zeros=False): if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1): a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f') if len(a1) == 0: - a1 = asarray([0.]) + a1 = asarray([0], dtype=a2.dtype) if len(a2) == 0: - a2 = asarray([0.]) - val = convolve(a1, a2, mode='full') - return val + a2 = asarray([0], dtype=a1.dtype) + return convolve(a1, a2, mode='full') @_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC) def polydiv(u, v, *, trim_leading_zeros=False):