Test that jax.numpy docstrings include examples

This commit is contained in:
Jake VanderPlas 2024-09-21 07:39:17 -07:00
parent d63afd8438
commit aa551e66c5
4 changed files with 130 additions and 42 deletions

View File

@ -317,7 +317,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array:
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Example:
Examples:
Scalar inputs:
@ -407,7 +407,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Example:
Examples:
>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)
@ -455,7 +455,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
Examples:
>>> x1 = jnp.array([2, 3])
>>> x2 = jnp.array([5, 4, 1])
>>> jnp.polyadd(x1, x2)
@ -637,7 +637,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
Examples:
>>> x1 = np.array([2, 1, 0])
>>> x2 = np.array([0, 5, 0, 3])
>>> np.polymul(x1, x2)
@ -702,7 +702,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
Example:
Examples:
>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
@ -755,7 +755,7 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
Examples:
>>> x1 = jnp.array([2, 3])
>>> x2 = jnp.array([5, 4, 1])
>>> jnp.polysub(x1, x2)

View File

@ -652,12 +652,26 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.add`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``+`` operator for
JAX arrays.
Args:
x, y: arrays to add. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise addition.
Examples:
Calling ``add`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
Calling ``add`` via the ``+`` operator:
>>> x + 10
Array([10, 11, 12, 13], dtype=int32)
"""
x, y = promote_args("add", x, y)
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
@ -668,12 +682,26 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.multiply`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``*`` operator for
JAX arrays.
Args:
x, y: arrays to multiply. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise multiplication.
Examples:
Calling ``multiply`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.multiply(x, 10)
Array([ 0, 10, 20, 30], dtype=int32)
Calling ``multiply`` via the ``*`` operator:
>>> x * 10
Array([ 0, 10, 20, 30], dtype=int32)
"""
x, y = promote_args("multiply", x, y)
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
@ -684,12 +712,26 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``&`` operator for
JAX arrays.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise AND.
Examples:
Calling ``bitwise_and`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.bitwise_and(x, 1)
Array([0, 1, 0, 1], dtype=int32)
Calling ``bitwise_and`` via the ``&`` operator:
>>> x & 1
Array([0, 1, 0, 1], dtype=int32)
"""
return lax.bitwise_and(*promote_args("bitwise_and", x, y))
@ -699,12 +741,26 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``|`` operator for
JAX arrays.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise OR.
Examples:
Calling ``bitwise_or`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.bitwise_or(x, 1)
Array([1, 1, 3, 3], dtype=int32)
Calling ``bitwise_or`` via the ``|`` operator:
>>> x | 1
Array([1, 1, 3, 3], dtype=int32)
"""
return lax.bitwise_or(*promote_args("bitwise_or", x, y))
@ -714,12 +770,26 @@ def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function,
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
This function provides the implementation of the ``^`` operator for
JAX arrays.
Args:
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
Returns:
Array containing the result of the element-wise bitwise XOR.
Examples:
Calling ``bitwise_xor`` explicitly:
>>> x = jnp.arange(4)
>>> jnp.bitwise_xor(x, 1)
Array([1, 0, 3, 2], dtype=int32)
Calling ``bitwise_xor`` via the ``^`` operator:
>>> x ^ 1
Array([1, 0, 3, 2], dtype=int32)
"""
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))
@ -958,6 +1028,11 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
Returns:
Array containing the result of the element-wise logical AND.
Examples:
>>> x = jnp.arange(4)
>>> jnp.logical_and(x, 1)
Array([False, True, True, True], dtype=bool)
"""
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))
@ -973,6 +1048,11 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
Returns:
Array containing the result of the element-wise logical OR.
Examples:
>>> x = jnp.arange(4)
>>> jnp.logical_or(x, 1)
Array([ True, True, True, True], dtype=bool)
"""
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))
@ -988,6 +1068,11 @@ def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
Returns:
Array containing the result of the element-wise logical XOR.
Examples:
>>> x = jnp.arange(4)
>>> jnp.logical_xor(x, 1)
Array([ True, False, False, False], dtype=bool)
"""
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))
@ -1373,7 +1458,7 @@ def rint(x: ArrayLike, /) -> Array:
If an element of x is exactly half way, e.g. ``0.5`` or ``1.5``, rint will round
to the nearest even integer.
Example:
Examples:
>>> x1 = jnp.array([5, 4, 7])
>>> jnp.rint(x1)
Array([5., 4., 7.], dtype=float32)

View File

@ -215,6 +215,7 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
Returns:
Vectorized version of the given function.
Examples:
Here are a few examples of how one could write vectorized linear algebra
routines using :func:`vectorize`:

View File

@ -6341,6 +6341,8 @@ class NumpyDocTests(jtu.JaxTestCase):
self.assertNotEmpty(doc)
self.assertIn("Args:", doc, msg=f"'Args:' not found in docstring of jnp.{name}")
self.assertIn("Returns:", doc, msg=f"'Returns:' not found in docstring of jnp.{name}")
if name not in ["frompyfunc", "isdtype", "promote_types"]:
self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}")
@parameterized.named_parameters(
{"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False])