Improved docs for polynomial arithmetic functions of jax.numpy

This commit is contained in:
rajasekharporeddy 2024-07-02 20:04:22 +05:30
parent 1949691daa
commit 61e1d560d8

View File

@ -31,7 +31,7 @@ from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
from jax._src.numpy.reductions import all
from jax._src.numpy import linalg
from jax._src.numpy.util import (
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where, implements)
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
from jax._src.typing import Array, ArrayLike
@ -431,9 +431,55 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)
return y
@implements(np.polyadd)
@jit
def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
r"""Returns the sum of the two polynomials.
JAX implementation of :func:`numpy.polyadd`.
Args:
a1: Array of polynomial coefficients.
a2: Array of polynomial coefficients.
Returns:
An array containing the coefficients of the sum of input polynomials.
Note:
:func:`jax.numpy.polyadd` only accepts arrays as input unlike
:func:`numpy.polyadd` which accepts scalar inputs as well.
See also:
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
>>> x1 = jnp.array([2, 3])
>>> x2 = jnp.array([5, 4, 1])
>>> jnp.polyadd(x1, x2)
Array([5, 6, 4], dtype=int32)
>>> x3 = jnp.array([[2, 3, 1]])
>>> x4 = jnp.array([[5, 7, 3],
... [8, 2, 6]])
>>> jnp.polyadd(x3, x4)
Array([[ 5, 7, 3],
[10, 5, 7]], dtype=int32)
>>> x5 = jnp.array([1, 3, 5])
>>> x6 = jnp.array([[5, 7, 9],
... [8, 6, 4]])
>>> jnp.polyadd(x5, x6) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 3) shape=(2,)
>>> x7 = jnp.array([2])
>>> jnp.polyadd(x6, x7)
Array([[ 5, 7, 9],
[10, 8, 6]], dtype=int32)
"""
check_arraylike("polyadd", a1, a2)
a1_arr, a2_arr = promote_dtypes(a1, a2)
del a1, a2
@ -561,16 +607,60 @@ def polyder(p: ArrayLike, m: int = 1) -> Array:
return p_arr[:-m] * coeff[::-1]
_LEADING_ZEROS_DOC = """\
Setting trim_leading_zeros=True makes the output match that of numpy.
But prevents the function from being able to be used in compiled code.
Due to differences in accumulation of floating point arithmetic errors, the cutoff for values to be
considered zero may lead to inconsistent results between NumPy and JAX, and even between different
JAX backends. The result may lead to inconsistent output shapes when trim_leading_zeros=True.
"""
@implements(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
r"""Returns the product of two polynomials.
JAX implementation of :func:`numpy.polymul`.
Args:
a1: 1D array of polynomial coefficients.
a2: 1D array of polynomial coefficients.
trim_leading_zeros: Default is ``False``. If ``True`` removes the leading
zeros in the return value to match the result of numpy. But prevents the
function from being able to be used in compiled code. Due to differences
in accumulation of floating point arithmetic errors, the cutoff for values
to be considered zero may lead to inconsistent results between NumPy and
JAX, and even between different JAX backends. The result may lead to
inconsistent output shapes when ``trim_leading_zeros=True``.
Returns:
An array of the coefficients of the product of the two polynomials. The dtype
of the output is always promoted to inexact.
Note:
:func:`jax.numpy.polymul` only accepts arrays as input unlike
:func:`numpy.polymul` which accepts scalar inputs as well.
See also:
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
>>> x1 = np.array([2, 1, 0])
>>> x2 = np.array([0, 5, 0, 3])
>>> np.polymul(x1, x2)
array([10, 5, 6, 3, 0])
>>> jnp.polymul(x1, x2)
Array([ 0., 10., 5., 6., 3., 0.], dtype=float32)
If ``trim_leading_zeros=True``, the result matches with ``np.polymul``'s.
>>> jnp.polymul(x1, x2, trim_leading_zeros=True)
Array([10., 5., 6., 3., 0.], dtype=float32)
For input arrays of dtype ``complex``:
>>> x3 = np.array([2., 1+2j, 1-2j])
>>> x4 = np.array([0, 5, 0, 3])
>>> np.polymul(x3, x4)
array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j])
>>> jnp.polymul(x3, x4)
Array([ 0. +0.j, 10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
>>> jnp.polymul(x3, x4, trim_leading_zeros=True)
Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
"""
check_arraylike("polymul", a1, a2)
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
del a1, a2
@ -582,8 +672,49 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
a2_arr = asarray([0], dtype=a1_arr.dtype)
return convolve(a1_arr, a2_arr, mode='full')
@implements(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
r"""Returns the quotient and remainder of polynomial division.
JAX implementation of :func:`numpy.polydiv`.
Args:
u: Array of dividend polynomial coefficients.
v: Array of divisor polynomial coefficients.
trim_leading_zeros: Default is ``False``. If ``True`` removes the leading
zeros in the return value to match the result of numpy. But prevents the
function from being able to be used in compiled code. Due to differences
in accumulation of floating point arithmetic errors, the cutoff for values
to be considered zero may lead to inconsistent results between NumPy and
JAX, and even between different JAX backends. The result may lead to
inconsistent output shapes when ``trim_leading_zeros=True``.
Returns:
A tuple of quotient and remainder arrays. The dtype of the output is always
promoted to inexact.
Note:
:func:`jax.numpy.polydiv` only accepts arrays as input unlike
:func:`numpy.polydiv` which accepts scalar inputs as well.
See also:
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
Example:
>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
(array([1.25 , 1.4375]), array([7.5625]))
>>> jnp.polydiv(x1, x2)
(Array([1.25 , 1.4375], dtype=float32), Array([0. , 0. , 7.5625], dtype=float32))
If ``trim_leading_zeros=True``, the result matches with ``np.polydiv``'s.
>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
(Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32))
"""
check_arraylike("polydiv", u, v)
u_arr, v_arr = promote_dtypes_inexact(u, v)
del u, v
@ -600,9 +731,55 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f')
return q, u_arr
@implements(np.polysub)
@jit
def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
r"""Returns the difference of two polynomials.
JAX implementation of :func:`numpy.polysub`.
Args:
a1: Array of minuend polynomial coefficients.
a2: Array of subtrahend polynomial coefficients.
Returns:
An array containing the coefficients of the difference of two polynomials.
Note:
:func:`jax.numpy.polysub` only accepts arrays as input unlike
:func:`numpy.polysub` which accepts scalar inputs as well.
See also:
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
>>> x1 = jnp.array([2, 3])
>>> x2 = jnp.array([5, 4, 1])
>>> jnp.polysub(x1, x2)
Array([-5, -2, 2], dtype=int32)
>>> x3 = jnp.array([[2, 3, 1]])
>>> x4 = jnp.array([[5, 7, 3],
... [8, 2, 6]])
>>> jnp.polysub(x3, x4)
Array([[-5, -7, -3],
[-6, 1, -5]], dtype=int32)
>>> x5 = jnp.array([1, 3, 5])
>>> x6 = jnp.array([[5, 7, 9],
... [8, 6, 4]])
>>> jnp.polysub(x5, x6) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 3) shape=(2,)
>>> x7 = jnp.array([2])
>>> jnp.polysub(x6, x7)
Array([[5, 7, 9],
[6, 4, 2]], dtype=int32)
"""
check_arraylike("polysub", a1, a2)
a1, a2 = promote_dtypes(a1, a2)
return polyadd(a1, -a2)