mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Improved docs for polynomial arithmetic functions of jax.numpy
This commit is contained in:
parent
1949691daa
commit
61e1d560d8
@ -31,7 +31,7 @@ from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
|
||||
from jax._src.numpy.reductions import all
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, implements)
|
||||
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@ -431,9 +431,55 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
|
||||
y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)
|
||||
return y
|
||||
|
||||
@implements(np.polyadd)
|
||||
|
||||
@jit
|
||||
def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
|
||||
r"""Returns the sum of the two polynomials.
|
||||
|
||||
JAX implementation of :func:`numpy.polyadd`.
|
||||
|
||||
Args:
|
||||
a1: Array of polynomial coefficients.
|
||||
a2: Array of polynomial coefficients.
|
||||
|
||||
Returns:
|
||||
An array containing the coefficients of the sum of input polynomials.
|
||||
|
||||
Note:
|
||||
:func:`jax.numpy.polyadd` only accepts arrays as input unlike
|
||||
:func:`numpy.polyadd` which accepts scalar inputs as well.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
|
||||
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
|
||||
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
|
||||
division.
|
||||
|
||||
Example:
|
||||
>>> x1 = jnp.array([2, 3])
|
||||
>>> x2 = jnp.array([5, 4, 1])
|
||||
>>> jnp.polyadd(x1, x2)
|
||||
Array([5, 6, 4], dtype=int32)
|
||||
|
||||
>>> x3 = jnp.array([[2, 3, 1]])
|
||||
>>> x4 = jnp.array([[5, 7, 3],
|
||||
... [8, 2, 6]])
|
||||
>>> jnp.polyadd(x3, x4)
|
||||
Array([[ 5, 7, 3],
|
||||
[10, 5, 7]], dtype=int32)
|
||||
|
||||
>>> x5 = jnp.array([1, 3, 5])
|
||||
>>> x6 = jnp.array([[5, 7, 9],
|
||||
... [8, 6, 4]])
|
||||
>>> jnp.polyadd(x5, x6) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 3) shape=(2,)
|
||||
>>> x7 = jnp.array([2])
|
||||
>>> jnp.polyadd(x6, x7)
|
||||
Array([[ 5, 7, 9],
|
||||
[10, 8, 6]], dtype=int32)
|
||||
"""
|
||||
check_arraylike("polyadd", a1, a2)
|
||||
a1_arr, a2_arr = promote_dtypes(a1, a2)
|
||||
del a1, a2
|
||||
@ -561,16 +607,60 @@ def polyder(p: ArrayLike, m: int = 1) -> Array:
|
||||
return p_arr[:-m] * coeff[::-1]
|
||||
|
||||
|
||||
_LEADING_ZEROS_DOC = """\
|
||||
Setting trim_leading_zeros=True makes the output match that of numpy.
|
||||
But prevents the function from being able to be used in compiled code.
|
||||
Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be
|
||||
considered zero may lead to inconsistent results between NumPy and JAX, and even between different
|
||||
JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True.
|
||||
"""
|
||||
|
||||
@implements(np.polymul, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
|
||||
r"""Returns the product of two polynomials.
|
||||
|
||||
JAX implementation of :func:`numpy.polymul`.
|
||||
|
||||
Args:
|
||||
a1: 1D array of polynomial coefficients.
|
||||
a2: 1D array of polynomial coefficients.
|
||||
trim_leading_zeros: Default is ``False``. If ``True`` removes the leading
|
||||
zeros in the return value to match the result of numpy. But prevents the
|
||||
function from being able to be used in compiled code. Due to differences
|
||||
in accumulation of floating point arithmetic errors, the cutoff for values
|
||||
to be considered zero may lead to inconsistent results between NumPy and
|
||||
JAX, and even between different JAX backends. The result may lead to
|
||||
inconsistent output shapes when ``trim_leading_zeros=True``.
|
||||
|
||||
Returns:
|
||||
An array of the coefficients of the product of the two polynomials. The dtype
|
||||
of the output is always promoted to inexact.
|
||||
|
||||
Note:
|
||||
:func:`jax.numpy.polymul` only accepts arrays as input unlike
|
||||
:func:`numpy.polymul` which accepts scalar inputs as well.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
|
||||
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
|
||||
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
|
||||
division.
|
||||
|
||||
Example:
|
||||
>>> x1 = np.array([2, 1, 0])
|
||||
>>> x2 = np.array([0, 5, 0, 3])
|
||||
>>> np.polymul(x1, x2)
|
||||
array([10, 5, 6, 3, 0])
|
||||
>>> jnp.polymul(x1, x2)
|
||||
Array([ 0., 10., 5., 6., 3., 0.], dtype=float32)
|
||||
|
||||
If ``trim_leading_zeros=True``, the result matches with ``np.polymul``'s.
|
||||
|
||||
>>> jnp.polymul(x1, x2, trim_leading_zeros=True)
|
||||
Array([10., 5., 6., 3., 0.], dtype=float32)
|
||||
|
||||
For input arrays of dtype ``complex``:
|
||||
|
||||
>>> x3 = np.array([2., 1+2j, 1-2j])
|
||||
>>> x4 = np.array([0, 5, 0, 3])
|
||||
>>> np.polymul(x3, x4)
|
||||
array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j])
|
||||
>>> jnp.polymul(x3, x4)
|
||||
Array([ 0. +0.j, 10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
|
||||
>>> jnp.polymul(x3, x4, trim_leading_zeros=True)
|
||||
Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
|
||||
"""
|
||||
check_arraylike("polymul", a1, a2)
|
||||
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
|
||||
del a1, a2
|
||||
@ -582,8 +672,49 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
|
||||
a2_arr = asarray([0], dtype=a1_arr.dtype)
|
||||
return convolve(a1_arr, a2_arr, mode='full')
|
||||
|
||||
@implements(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
|
||||
|
||||
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
|
||||
r"""Returns the quotient and remainder of polynomial division.
|
||||
|
||||
JAX implementation of :func:`numpy.polydiv`.
|
||||
|
||||
Args:
|
||||
u: Array of dividend polynomial coefficients.
|
||||
v: Array of divisor polynomial coefficients.
|
||||
trim_leading_zeros: Default is ``False``. If ``True`` removes the leading
|
||||
zeros in the return value to match the result of numpy. But prevents the
|
||||
function from being able to be used in compiled code. Due to differences
|
||||
in accumulation of floating point arithmetic errors, the cutoff for values
|
||||
to be considered zero may lead to inconsistent results between NumPy and
|
||||
JAX, and even between different JAX backends. The result may lead to
|
||||
inconsistent output shapes when ``trim_leading_zeros=True``.
|
||||
|
||||
Returns:
|
||||
A tuple of quotient and remainder arrays. The dtype of the output is always
|
||||
promoted to inexact.
|
||||
|
||||
Note:
|
||||
:func:`jax.numpy.polydiv` only accepts arrays as input unlike
|
||||
:func:`numpy.polydiv` which accepts scalar inputs as well.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
|
||||
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
|
||||
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
|
||||
|
||||
Example:
|
||||
>>> x1 = jnp.array([5, 7, 9])
|
||||
>>> x2 = jnp.array([4, 1])
|
||||
>>> np.polydiv(x1, x2)
|
||||
(array([1.25 , 1.4375]), array([7.5625]))
|
||||
>>> jnp.polydiv(x1, x2)
|
||||
(Array([1.25 , 1.4375], dtype=float32), Array([0. , 0. , 7.5625], dtype=float32))
|
||||
|
||||
If ``trim_leading_zeros=True``, the result matches with ``np.polydiv``'s.
|
||||
|
||||
>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
|
||||
(Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32))
|
||||
"""
|
||||
check_arraylike("polydiv", u, v)
|
||||
u_arr, v_arr = promote_dtypes_inexact(u, v)
|
||||
del u, v
|
||||
@ -600,9 +731,55 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
|
||||
u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f')
|
||||
return q, u_arr
|
||||
|
||||
@implements(np.polysub)
|
||||
|
||||
@jit
|
||||
def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
|
||||
r"""Returns the difference of two polynomials.
|
||||
|
||||
JAX implementation of :func:`numpy.polysub`.
|
||||
|
||||
Args:
|
||||
a1: Array of minuend polynomial coefficients.
|
||||
a2: Array of subtrahend polynomial coefficients.
|
||||
|
||||
Returns:
|
||||
An array containing the coefficients of the difference of two polynomials.
|
||||
|
||||
Note:
|
||||
:func:`jax.numpy.polysub` only accepts arrays as input unlike
|
||||
:func:`numpy.polysub` which accepts scalar inputs as well.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
|
||||
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
|
||||
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
|
||||
division.
|
||||
|
||||
Example:
|
||||
>>> x1 = jnp.array([2, 3])
|
||||
>>> x2 = jnp.array([5, 4, 1])
|
||||
>>> jnp.polysub(x1, x2)
|
||||
Array([-5, -2, 2], dtype=int32)
|
||||
|
||||
>>> x3 = jnp.array([[2, 3, 1]])
|
||||
>>> x4 = jnp.array([[5, 7, 3],
|
||||
... [8, 2, 6]])
|
||||
>>> jnp.polysub(x3, x4)
|
||||
Array([[-5, -7, -3],
|
||||
[-6, 1, -5]], dtype=int32)
|
||||
|
||||
>>> x5 = jnp.array([1, 3, 5])
|
||||
>>> x6 = jnp.array([[5, 7, 9],
|
||||
... [8, 6, 4]])
|
||||
>>> jnp.polysub(x5, x6) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 3) shape=(2,)
|
||||
>>> x7 = jnp.array([2])
|
||||
>>> jnp.polysub(x6, x7)
|
||||
Array([[5, 7, 9],
|
||||
[6, 4, 2]], dtype=int32)
|
||||
"""
|
||||
check_arraylike("polysub", a1, a2)
|
||||
a1, a2 = promote_dtypes(a1, a2)
|
||||
return polyadd(a1, -a2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user