mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Test that jax.numpy docstrings include examples
This commit is contained in:
parent
d63afd8438
commit
aa551e66c5
@ -317,7 +317,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array:
|
||||
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
|
||||
coefficients.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
|
||||
Scalar inputs:
|
||||
|
||||
@ -407,7 +407,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
|
||||
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
|
||||
coefficients.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
>>> p = jnp.array([2, 5, 1])
|
||||
>>> jnp.polyval(p, 3)
|
||||
Array(34., dtype=float32)
|
||||
@ -455,7 +455,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
|
||||
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
|
||||
division.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
>>> x1 = jnp.array([2, 3])
|
||||
>>> x2 = jnp.array([5, 4, 1])
|
||||
>>> jnp.polyadd(x1, x2)
|
||||
@ -637,7 +637,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
|
||||
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
|
||||
division.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
>>> x1 = np.array([2, 1, 0])
|
||||
>>> x2 = np.array([0, 5, 0, 3])
|
||||
>>> np.polymul(x1, x2)
|
||||
@ -702,7 +702,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
|
||||
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
|
||||
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
>>> x1 = jnp.array([5, 7, 9])
|
||||
>>> x2 = jnp.array([4, 1])
|
||||
>>> np.polydiv(x1, x2)
|
||||
@ -755,7 +755,7 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
|
||||
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
|
||||
division.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
>>> x1 = jnp.array([2, 3])
|
||||
>>> x2 = jnp.array([5, 4, 1])
|
||||
>>> jnp.polysub(x1, x2)
|
||||
|
@ -652,12 +652,26 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
||||
JAX implementation of :obj:`numpy.add`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
This function provides the implementation of the ``+`` operator for
|
||||
JAX arrays.
|
||||
|
||||
Args:
|
||||
x, y: arrays to add. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise addition.
|
||||
|
||||
Examples:
|
||||
Calling ``add`` explicitly:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.add(x, 10)
|
||||
Array([10, 11, 12, 13], dtype=int32)
|
||||
|
||||
Calling ``add`` via the ``+`` operator:
|
||||
|
||||
>>> x + 10
|
||||
Array([10, 11, 12, 13], dtype=int32)
|
||||
"""
|
||||
x, y = promote_args("add", x, y)
|
||||
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
|
||||
@ -668,12 +682,26 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
||||
JAX implementation of :obj:`numpy.multiply`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
This function provides the implementation of the ``*`` operator for
|
||||
JAX arrays.
|
||||
|
||||
Args:
|
||||
x, y: arrays to multiply. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise multiplication.
|
||||
|
||||
Examples:
|
||||
Calling ``multiply`` explicitly:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.multiply(x, 10)
|
||||
Array([ 0, 10, 20, 30], dtype=int32)
|
||||
|
||||
Calling ``multiply`` via the ``*`` operator:
|
||||
|
||||
>>> x * 10
|
||||
Array([ 0, 10, 20, 30], dtype=int32)
|
||||
"""
|
||||
x, y = promote_args("multiply", x, y)
|
||||
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
|
||||
@ -684,12 +712,26 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
||||
JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
This function provides the implementation of the ``&`` operator for
|
||||
JAX arrays.
|
||||
|
||||
Args:
|
||||
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise bitwise AND.
|
||||
|
||||
Examples:
|
||||
Calling ``bitwise_and`` explicitly:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.bitwise_and(x, 1)
|
||||
Array([0, 1, 0, 1], dtype=int32)
|
||||
|
||||
Calling ``bitwise_and`` via the ``&`` operator:
|
||||
|
||||
>>> x & 1
|
||||
Array([0, 1, 0, 1], dtype=int32)
|
||||
"""
|
||||
return lax.bitwise_and(*promote_args("bitwise_and", x, y))
|
||||
|
||||
@ -699,12 +741,26 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
||||
JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
This function provides the implementation of the ``|`` operator for
|
||||
JAX arrays.
|
||||
|
||||
Args:
|
||||
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise bitwise OR.
|
||||
|
||||
Examples:
|
||||
Calling ``bitwise_or`` explicitly:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.bitwise_or(x, 1)
|
||||
Array([1, 1, 3, 3], dtype=int32)
|
||||
|
||||
Calling ``bitwise_or`` via the ``|`` operator:
|
||||
|
||||
>>> x | 1
|
||||
Array([1, 1, 3, 3], dtype=int32)
|
||||
"""
|
||||
return lax.bitwise_or(*promote_args("bitwise_or", x, y))
|
||||
|
||||
@ -714,12 +770,26 @@ def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
||||
JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function,
|
||||
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
|
||||
This function provides the implementation of the ``^`` operator for
|
||||
JAX arrays.
|
||||
|
||||
Args:
|
||||
x, y: integer or boolean arrays. Must be broadcastable to a common shape.
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise bitwise XOR.
|
||||
|
||||
Examples:
|
||||
Calling ``bitwise_xor`` explicitly:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.bitwise_xor(x, 1)
|
||||
Array([1, 0, 3, 2], dtype=int32)
|
||||
|
||||
Calling ``bitwise_xor`` via the ``^`` operator:
|
||||
|
||||
>>> x ^ 1
|
||||
Array([1, 0, 3, 2], dtype=int32)
|
||||
"""
|
||||
return lax.bitwise_xor(*promote_args("bitwise_xor", x, y))
|
||||
|
||||
@ -958,6 +1028,11 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise logical AND.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.logical_and(x, 1)
|
||||
Array([False, True, True, True], dtype=bool)
|
||||
"""
|
||||
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))
|
||||
|
||||
@ -973,6 +1048,11 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise logical OR.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.logical_or(x, 1)
|
||||
Array([ True, True, True, True], dtype=bool)
|
||||
"""
|
||||
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))
|
||||
|
||||
@ -988,6 +1068,11 @@ def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
|
||||
|
||||
Returns:
|
||||
Array containing the result of the element-wise logical XOR.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.arange(4)
|
||||
>>> jnp.logical_xor(x, 1)
|
||||
Array([ True, False, False, False], dtype=bool)
|
||||
"""
|
||||
return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y)))
|
||||
|
||||
@ -1373,7 +1458,7 @@ def rint(x: ArrayLike, /) -> Array:
|
||||
If an element of x is exactly half way, e.g. ``0.5`` or ``1.5``, rint will round
|
||||
to the nearest even integer.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
>>> x1 = jnp.array([5, 4, 7])
|
||||
>>> jnp.rint(x1)
|
||||
Array([5., 4., 7.], dtype=float32)
|
||||
|
@ -215,6 +215,7 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
|
||||
Returns:
|
||||
Vectorized version of the given function.
|
||||
|
||||
Examples:
|
||||
Here are a few examples of how one could write vectorized linear algebra
|
||||
routines using :func:`vectorize`:
|
||||
|
||||
|
@ -6341,6 +6341,8 @@ class NumpyDocTests(jtu.JaxTestCase):
|
||||
self.assertNotEmpty(doc)
|
||||
self.assertIn("Args:", doc, msg=f"'Args:' not found in docstring of jnp.{name}")
|
||||
self.assertIn("Returns:", doc, msg=f"'Returns:' not found in docstring of jnp.{name}")
|
||||
if name not in ["frompyfunc", "isdtype", "promote_types"]:
|
||||
self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}")
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False])
|
||||
|
Loading…
x
Reference in New Issue
Block a user