mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix Typos
This commit is contained in:
parent
595a620804
commit
b93da3873b
@ -801,7 +801,7 @@ def ragged_dot(
|
||||
group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group.
|
||||
precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
|
||||
preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
|
||||
group_offset: Optional. (1,) shaped array that ndicates the group in group_sizes to start computing from. If not specified, defaults to [0].
|
||||
group_offset: Optional. (1,) shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0].
|
||||
|
||||
Results:
|
||||
(m, n) shaped array with preferred_element_type element type.
|
||||
@ -1444,7 +1444,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
|
||||
If not specified, the output will have the same sharding as the input,
|
||||
with a few exceptions/limitations in particular:
|
||||
1. Sharding is not available during tracing, thus this will rely on jit.
|
||||
2. If x is weakly typed or uncomitted, will use default sharding.
|
||||
2. If x is weakly typed or uncommitted, will use default sharding.
|
||||
3. Shape is not None and is different from x.shape, default will be used.
|
||||
|
||||
Returns:
|
||||
@ -5116,7 +5116,7 @@ def remaining(original, *removed_lists):
|
||||
|
||||
|
||||
def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None:
|
||||
"""Turns an API precision specification, into a pair of enumeration values.
|
||||
"""Turns an API precision specification into a pair of enumeration values.
|
||||
|
||||
The API can take the precision as a string, or int, and either as a single
|
||||
value to apply to both operands, or as a sequence of two values.
|
||||
|
@ -1540,10 +1540,10 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
|
||||
When ``False`` (the default value), the size of dimension in
|
||||
``scatter_dimension`` must match the size of axis ``axis_name`` (or the
|
||||
group size if ``axis_index_groups`` is given). After scattering the
|
||||
all-reduce result along ``scatter_dimension``, the output is sequeezed by
|
||||
all-reduce result along ``scatter_dimension``, the output is squeezed by
|
||||
removing ``scatter_dimension``, so the result has lower rank than the
|
||||
input. When ``True``, the size of dimension in ``scatter_dimension`` must
|
||||
be dividible by the size of axis ``axis_name`` (or the group size if
|
||||
be divisible by the size of axis ``axis_name`` (or the group size if
|
||||
``axis_index_groups`` is given), and the ``scatter_dimension`` axis is
|
||||
preserved (so the result has the same rank as the input).
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
"""A JIT-compatible library for QDWH-based singular value decomposition.
|
||||
|
||||
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
|
||||
iteration implemented through QR decmopositions is numerically stable and does
|
||||
iteration implemented through QR decompositions is numerically stable and does
|
||||
not require solving a linear system involving the iteration matrix or
|
||||
computing its inversion. This is desirable for multicore and heterogeneous
|
||||
computing systems.
|
||||
@ -59,7 +59,7 @@ def _svd_tall_and_square_input(
|
||||
Args:
|
||||
a: A matrix of shape `m x n` with `m >= n`.
|
||||
hermitian: True if `a` is Hermitian.
|
||||
compute_uv: Whether to compute also `u` and `v` in addition to `s`.
|
||||
compute_uv: Whether to also compute `u` and `v` in addition to `s`.
|
||||
max_iterations: The predefined maximum number of iterations of QDWH.
|
||||
|
||||
Returns:
|
||||
@ -126,11 +126,11 @@ def svd(
|
||||
full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
|
||||
respectively. If False, the shapes are `m x k` and `k x n`, respectively,
|
||||
where `k = min(m, n)`.
|
||||
compute_uv: Whether to compute also `u` and `v` in addition to `s`.
|
||||
compute_uv: Whether to also compute `u` and `v` in addition to `s`.
|
||||
hermitian: True if `a` is Hermitian.
|
||||
max_iterations: The predefined maximum number of iterations of QDWH.
|
||||
subset_by_index: Optional 2-tuple [start, end] indicating the range of
|
||||
indices of singular componenets to compute. For example, if
|
||||
indices of singular components to compute. For example, if
|
||||
``subset_by_index`` = [0,2], then ``svd`` computes the two largest
|
||||
singular values (and their singular vectors if `compute_uv` is true.
|
||||
|
||||
|
@ -795,7 +795,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
|
||||
Args:
|
||||
a: array of shape ``(..., M, M)``, containing the Hermitian (if complex)
|
||||
or symmetric (if real) matrix.
|
||||
UPLO: specifies whether the calculation isdone with the lower triangular
|
||||
UPLO: specifies whether the calculation is done with the lower triangular
|
||||
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
|
||||
symmetrize_input: if True (default) then input is symmetrized, which leads
|
||||
to better behavior under automatic differentiation.
|
||||
@ -1249,7 +1249,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
|
||||
|
||||
See also:
|
||||
- :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API
|
||||
- :func:`jax.lax.linalg.qr`: XLA-style QR decompositon API
|
||||
- :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API
|
||||
|
||||
Examples:
|
||||
Compute the QR decomposition of a matrix:
|
||||
@ -1443,7 +1443,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
|
||||
JAX implementation of :func:`numpy.linalg.cross`
|
||||
|
||||
Args:
|
||||
x1: N-dimesional array, with ``x1.shape[axis] == 3``
|
||||
x1: N-dimensional array, with ``x1.shape[axis] == 3``
|
||||
x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes
|
||||
broadcast-compatible with ``x1``.
|
||||
axis: axis along which to take the cross product (default: -1).
|
||||
|
@ -528,7 +528,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
return_index: if True, also return the indices in ``ar`` where each value occurs
|
||||
return_inverse: if True, also return the indices that can be used to reconstruct
|
||||
``ar`` from the unique values.
|
||||
return_counts: if True, also return the number of occurances of each unique value.
|
||||
return_counts: if True, also return the number of occurrences of each unique value.
|
||||
axis: if specified, compute unique values along the specified axis. If None (default),
|
||||
then flatten ``ar`` before computing the unique values.
|
||||
equal_nan: if True, consider NaN values equivalent when determining uniqueness.
|
||||
@ -546,8 +546,8 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``.
|
||||
- ``unique_index``:
|
||||
*(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains
|
||||
the indices of the first occurance of each unique value in ``ar``. For 1D inputs,
|
||||
``ar[unique_index]`` is equivlent to ``unique_values``.
|
||||
the indices of the first occurrence of each unique value in ``ar``. For 1D inputs,
|
||||
``ar[unique_index]`` is equivalent to ``unique_values``.
|
||||
- ``unique_inverse``:
|
||||
*(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis``
|
||||
is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
|
||||
@ -555,7 +555,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
``unique_values[unique_inverse]`` is equivalent to ``ar``.
|
||||
- ``unique_counts``:
|
||||
*(returned only if return_counts is True)* An array of shape ``(n_unique,)``.
|
||||
Contains the number of occurances of each unique value in ``ar``.
|
||||
Contains the number of occurrences of each unique value in ``ar``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``.
|
||||
@ -619,7 +619,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
**Returning indices**
|
||||
|
||||
If you set ``return_index=True``, then ``unique`` returns the indices of the
|
||||
first occurance of each unique value:
|
||||
first occurrence of each unique value:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> values, indices = jnp.unique(x, return_index=True)
|
||||
@ -660,7 +660,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
|
||||
**Returning counts**
|
||||
|
||||
If you set ``return_counts=True``, then ``unique`` returns the number of occurances
|
||||
If you set ``return_counts=True``, then ``unique`` returns the number of occurrences
|
||||
within the input for every unique value:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
@ -671,7 +671,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
[2 2 1]
|
||||
|
||||
For multi-dimensional arrays, this also returns a 1D array of counts
|
||||
indicating number of occurances along the specified axis:
|
||||
indicating number of occurrences along the specified axis:
|
||||
|
||||
>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
|
||||
>>> print(values)
|
||||
@ -748,13 +748,13 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None,
|
||||
- ``values``:
|
||||
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
||||
- ``indices``:
|
||||
An array of shape ``(n_unique,)``. Contains the indices of the first occurance of
|
||||
each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivlent to ``values``.
|
||||
An array of shape ``(n_unique,)``. Contains the indices of the first occurrence of
|
||||
each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivalent to ``values``.
|
||||
- ``inverse_indices``:
|
||||
An array of shape ``x.shape``. Contains the indices within ``values`` of each value
|
||||
in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.
|
||||
- ``counts``:
|
||||
An array of shape ``(n_unique,)``. Contains the number of occurances of each unique
|
||||
An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique
|
||||
value in ``x``.
|
||||
|
||||
See also:
|
||||
@ -770,7 +770,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None,
|
||||
>>> result = jnp.unique_all(x)
|
||||
|
||||
The result is a :class:`~typing.NamedTuple` with four named attributes.
|
||||
The ``values`` attribue contains the unique values from the array:
|
||||
The ``values`` attribute contains the unique values from the array:
|
||||
|
||||
>>> result.values
|
||||
Array([1, 3, 4], dtype=int32)
|
||||
@ -829,7 +829,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None,
|
||||
- ``values``:
|
||||
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
||||
- ``counts``:
|
||||
An array of shape ``(n_unique,)``. Contains the number of occurances of each unique
|
||||
An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique
|
||||
value in ``x``.
|
||||
|
||||
See also:
|
||||
@ -846,7 +846,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None,
|
||||
>>> result = jnp.unique_counts(x)
|
||||
|
||||
The result is a :class:`~typing.NamedTuple` with two named attributes.
|
||||
The ``values`` attribue contains the unique values from the array:
|
||||
The ``values`` attribute contains the unique values from the array:
|
||||
|
||||
>>> result.values
|
||||
Array([1, 3, 4], dtype=int32)
|
||||
|
@ -354,7 +354,7 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
|
||||
"""Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
|
||||
|
||||
Args:
|
||||
func : a callable that takes `nin` scalar arguments and return `nout` outputs.
|
||||
func : a callable that takes `nin` scalar arguments and returns `nout` outputs.
|
||||
nin: integer specifying the number of scalar inputs
|
||||
nout: integer specifying the number of scalar outputs
|
||||
identity: (optional) a scalar specifying the identity of the operation, if any.
|
||||
|
Loading…
x
Reference in New Issue
Block a user