[x64] make jnp.poly* functions work under strict dtype promotion

This commit is contained in:
Jake VanderPlas 2022-05-31 14:22:49 -07:00
parent bab8520d0c
commit 30b687c486

View File

@ -256,7 +256,9 @@ def polyint(p, m=1, k=None):
if m == 0:
return p
else:
coeff = maximum(1, arange(len(p) + m, 0, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis]
- arange(m, dtype=p.dtype)[:, np.newaxis])
coeff = maximum(1, grid).prod(0)[::-1]
return true_divide(concatenate((p, k)), coeff)
@ -270,8 +272,9 @@ def polyder(p, m=1):
raise ValueError("Order of derivative must be positive")
if m == 0:
return p
coeff = (arange(len(p), m, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
return p[:-m] * coeff
coeff = (arange(m, len(p), dtype=p.dtype)[np.newaxis]
- arange(m, dtype=p.dtype)[:, np.newaxis]).prod(0)
return p[:-m] * coeff[::-1]
_LEADING_ZEROS_DOC = """\
@ -289,11 +292,10 @@ def polymul(a1, a2, *, trim_leading_zeros=False):
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
if len(a1) == 0:
a1 = asarray([0.])
a1 = asarray([0], dtype=a2.dtype)
if len(a2) == 0:
a2 = asarray([0.])
val = convolve(a1, a2, mode='full')
return val
a2 = asarray([0], dtype=a1.dtype)
return convolve(a1, a2, mode='full')
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
def polydiv(u, v, *, trim_leading_zeros=False):