mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[x64] make jnp.poly* functions work under strict dtype promotion
This commit is contained in:
parent
bab8520d0c
commit
30b687c486
@ -256,7 +256,9 @@ def polyint(p, m=1, k=None):
|
||||
if m == 0:
|
||||
return p
|
||||
else:
|
||||
coeff = maximum(1, arange(len(p) + m, 0, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
|
||||
grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis]
|
||||
- arange(m, dtype=p.dtype)[:, np.newaxis])
|
||||
coeff = maximum(1, grid).prod(0)[::-1]
|
||||
return true_divide(concatenate((p, k)), coeff)
|
||||
|
||||
|
||||
@ -270,8 +272,9 @@ def polyder(p, m=1):
|
||||
raise ValueError("Order of derivative must be positive")
|
||||
if m == 0:
|
||||
return p
|
||||
coeff = (arange(len(p), m, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
|
||||
return p[:-m] * coeff
|
||||
coeff = (arange(m, len(p), dtype=p.dtype)[np.newaxis]
|
||||
- arange(m, dtype=p.dtype)[:, np.newaxis]).prod(0)
|
||||
return p[:-m] * coeff[::-1]
|
||||
|
||||
|
||||
_LEADING_ZEROS_DOC = """\
|
||||
@ -289,11 +292,10 @@ def polymul(a1, a2, *, trim_leading_zeros=False):
|
||||
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
|
||||
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
|
||||
if len(a1) == 0:
|
||||
a1 = asarray([0.])
|
||||
a1 = asarray([0], dtype=a2.dtype)
|
||||
if len(a2) == 0:
|
||||
a2 = asarray([0.])
|
||||
val = convolve(a1, a2, mode='full')
|
||||
return val
|
||||
a2 = asarray([0], dtype=a1.dtype)
|
||||
return convolve(a1, a2, mode='full')
|
||||
|
||||
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polydiv(u, v, *, trim_leading_zeros=False):
|
||||
|
Loading…
x
Reference in New Issue
Block a user