mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix some typos in lax_numpy.py
This commit is contained in:
parent
3d389a7fb4
commit
0a8f46a6ca
@ -1375,9 +1375,9 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
|
||||
input is returned as is.
|
||||
axis: int, optional, default=-1. Specifies the axis along which the difference
|
||||
is computed. The difference is computed along ``axis -1`` by default.
|
||||
prepend: scalar or array, optional, defualt=None. Specifies the values to be
|
||||
prepend: scalar or array, optional, default=None. Specifies the values to be
|
||||
prepended along ``axis`` before computing the difference.
|
||||
append: scalar or array, optional, defualt=None. Specifies the values to be
|
||||
append: scalar or array, optional, default=None. Specifies the values to be
|
||||
appended along ``axis`` before computing the difference.
|
||||
|
||||
Returns:
|
||||
@ -3677,7 +3677,7 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
|
||||
f"and larger or equal than the padding length (= {padding}). "
|
||||
f"Error while handling {'left' if before else 'right'} padding on axis {i}.")
|
||||
try:
|
||||
# We check that we can determine all comparisions.
|
||||
# We check that we can determine all comparisons.
|
||||
offset = 1 if (mode == "reflect" and axis_size > 1) else 0
|
||||
has_poly_dim = not core.is_constant_shape((axis_size, padding))
|
||||
# For shape polymorphism, ensure the loop below ends after 1 iteration
|
||||
@ -5683,7 +5683,7 @@ def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float,
|
||||
a length that is an integer multiple of the dtype element size, or
|
||||
it must be an object exporting the `Python buffer interface`_.
|
||||
dtype: optional. Desired data type for the array. Default is ``float64``.
|
||||
This specifes the dtype used to parse the buffer, but note that after parsing,
|
||||
This specifies the dtype used to parse the buffer, but note that after parsing,
|
||||
64-bit values will be cast to 32-bit JAX arrays if the ``jax_enable_x64``
|
||||
flag is set to ``False``.
|
||||
count: optional integer specifying the number of items to read from the buffer.
|
||||
@ -6512,7 +6512,7 @@ def i0(x: ArrayLike) -> Array:
|
||||
are not supported.
|
||||
|
||||
Returns:
|
||||
An array containing the corresponding vlaues of the modified Bessel function
|
||||
An array containing the corresponding values of the modified Bessel function
|
||||
of ``x``.
|
||||
|
||||
See also:
|
||||
@ -7708,7 +7708,7 @@ def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array:
|
||||
- ``fb`` - trims both leading and trailing zeros.
|
||||
|
||||
Returns:
|
||||
An array containig the trimmed input with same dtype as ``filt``.
|
||||
An array containing the trimmed input with same dtype as ``filt``.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
|
||||
|
Loading…
x
Reference in New Issue
Block a user