mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix Typos
This commit is contained in:
parent
595a620804
commit
b93da3873b
@ -801,7 +801,7 @@ def ragged_dot(
|
|||||||
group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group.
|
group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group.
|
||||||
precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
|
precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
|
||||||
preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
|
preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
|
||||||
group_offset: Optional. (1,) shaped array that ndicates the group in group_sizes to start computing from. If not specified, defaults to [0].
|
group_offset: Optional. (1,) shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0].
|
||||||
|
|
||||||
Results:
|
Results:
|
||||||
(m, n) shaped array with preferred_element_type element type.
|
(m, n) shaped array with preferred_element_type element type.
|
||||||
@ -1444,7 +1444,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
|
|||||||
If not specified, the output will have the same sharding as the input,
|
If not specified, the output will have the same sharding as the input,
|
||||||
with a few exceptions/limitations in particular:
|
with a few exceptions/limitations in particular:
|
||||||
1. Sharding is not available during tracing, thus this will rely on jit.
|
1. Sharding is not available during tracing, thus this will rely on jit.
|
||||||
2. If x is weakly typed or uncomitted, will use default sharding.
|
2. If x is weakly typed or uncommitted, will use default sharding.
|
||||||
3. Shape is not None and is different from x.shape, default will be used.
|
3. Shape is not None and is different from x.shape, default will be used.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -5116,7 +5116,7 @@ def remaining(original, *removed_lists):
|
|||||||
|
|
||||||
|
|
||||||
def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None:
|
def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None:
|
||||||
"""Turns an API precision specification, into a pair of enumeration values.
|
"""Turns an API precision specification into a pair of enumeration values.
|
||||||
|
|
||||||
The API can take the precision as a string, or int, and either as a single
|
The API can take the precision as a string, or int, and either as a single
|
||||||
value to apply to both operands, or as a sequence of two values.
|
value to apply to both operands, or as a sequence of two values.
|
||||||
|
@ -1540,10 +1540,10 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
|
|||||||
When ``False`` (the default value), the size of dimension in
|
When ``False`` (the default value), the size of dimension in
|
||||||
``scatter_dimension`` must match the size of axis ``axis_name`` (or the
|
``scatter_dimension`` must match the size of axis ``axis_name`` (or the
|
||||||
group size if ``axis_index_groups`` is given). After scattering the
|
group size if ``axis_index_groups`` is given). After scattering the
|
||||||
all-reduce result along ``scatter_dimension``, the output is sequeezed by
|
all-reduce result along ``scatter_dimension``, the output is squeezed by
|
||||||
removing ``scatter_dimension``, so the result has lower rank than the
|
removing ``scatter_dimension``, so the result has lower rank than the
|
||||||
input. When ``True``, the size of dimension in ``scatter_dimension`` must
|
input. When ``True``, the size of dimension in ``scatter_dimension`` must
|
||||||
be dividible by the size of axis ``axis_name`` (or the group size if
|
be divisible by the size of axis ``axis_name`` (or the group size if
|
||||||
``axis_index_groups`` is given), and the ``scatter_dimension`` axis is
|
``axis_index_groups`` is given), and the ``scatter_dimension`` axis is
|
||||||
preserved (so the result has the same rank as the input).
|
preserved (so the result has the same rank as the input).
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
"""A JIT-compatible library for QDWH-based singular value decomposition.
|
"""A JIT-compatible library for QDWH-based singular value decomposition.
|
||||||
|
|
||||||
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
|
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
|
||||||
iteration implemented through QR decmopositions is numerically stable and does
|
iteration implemented through QR decompositions is numerically stable and does
|
||||||
not require solving a linear system involving the iteration matrix or
|
not require solving a linear system involving the iteration matrix or
|
||||||
computing its inversion. This is desirable for multicore and heterogeneous
|
computing its inversion. This is desirable for multicore and heterogeneous
|
||||||
computing systems.
|
computing systems.
|
||||||
@ -59,7 +59,7 @@ def _svd_tall_and_square_input(
|
|||||||
Args:
|
Args:
|
||||||
a: A matrix of shape `m x n` with `m >= n`.
|
a: A matrix of shape `m x n` with `m >= n`.
|
||||||
hermitian: True if `a` is Hermitian.
|
hermitian: True if `a` is Hermitian.
|
||||||
compute_uv: Whether to compute also `u` and `v` in addition to `s`.
|
compute_uv: Whether to also compute `u` and `v` in addition to `s`.
|
||||||
max_iterations: The predefined maximum number of iterations of QDWH.
|
max_iterations: The predefined maximum number of iterations of QDWH.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -126,11 +126,11 @@ def svd(
|
|||||||
full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
|
full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
|
||||||
respectively. If False, the shapes are `m x k` and `k x n`, respectively,
|
respectively. If False, the shapes are `m x k` and `k x n`, respectively,
|
||||||
where `k = min(m, n)`.
|
where `k = min(m, n)`.
|
||||||
compute_uv: Whether to compute also `u` and `v` in addition to `s`.
|
compute_uv: Whether to also compute `u` and `v` in addition to `s`.
|
||||||
hermitian: True if `a` is Hermitian.
|
hermitian: True if `a` is Hermitian.
|
||||||
max_iterations: The predefined maximum number of iterations of QDWH.
|
max_iterations: The predefined maximum number of iterations of QDWH.
|
||||||
subset_by_index: Optional 2-tuple [start, end] indicating the range of
|
subset_by_index: Optional 2-tuple [start, end] indicating the range of
|
||||||
indices of singular componenets to compute. For example, if
|
indices of singular components to compute. For example, if
|
||||||
``subset_by_index`` = [0,2], then ``svd`` computes the two largest
|
``subset_by_index`` = [0,2], then ``svd`` computes the two largest
|
||||||
singular values (and their singular vectors if `compute_uv` is true.
|
singular values (and their singular vectors if `compute_uv` is true.
|
||||||
|
|
||||||
|
@ -795,7 +795,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
|
|||||||
Args:
|
Args:
|
||||||
a: array of shape ``(..., M, M)``, containing the Hermitian (if complex)
|
a: array of shape ``(..., M, M)``, containing the Hermitian (if complex)
|
||||||
or symmetric (if real) matrix.
|
or symmetric (if real) matrix.
|
||||||
UPLO: specifies whether the calculation isdone with the lower triangular
|
UPLO: specifies whether the calculation is done with the lower triangular
|
||||||
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
|
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
|
||||||
symmetrize_input: if True (default) then input is symmetrized, which leads
|
symmetrize_input: if True (default) then input is symmetrized, which leads
|
||||||
to better behavior under automatic differentiation.
|
to better behavior under automatic differentiation.
|
||||||
@ -1249,7 +1249,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
|
|||||||
|
|
||||||
See also:
|
See also:
|
||||||
- :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API
|
- :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API
|
||||||
- :func:`jax.lax.linalg.qr`: XLA-style QR decompositon API
|
- :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
Compute the QR decomposition of a matrix:
|
Compute the QR decomposition of a matrix:
|
||||||
@ -1443,7 +1443,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
|
|||||||
JAX implementation of :func:`numpy.linalg.cross`
|
JAX implementation of :func:`numpy.linalg.cross`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x1: N-dimesional array, with ``x1.shape[axis] == 3``
|
x1: N-dimensional array, with ``x1.shape[axis] == 3``
|
||||||
x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes
|
x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes
|
||||||
broadcast-compatible with ``x1``.
|
broadcast-compatible with ``x1``.
|
||||||
axis: axis along which to take the cross product (default: -1).
|
axis: axis along which to take the cross product (default: -1).
|
||||||
|
@ -528,7 +528,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
|||||||
return_index: if True, also return the indices in ``ar`` where each value occurs
|
return_index: if True, also return the indices in ``ar`` where each value occurs
|
||||||
return_inverse: if True, also return the indices that can be used to reconstruct
|
return_inverse: if True, also return the indices that can be used to reconstruct
|
||||||
``ar`` from the unique values.
|
``ar`` from the unique values.
|
||||||
return_counts: if True, also return the number of occurances of each unique value.
|
return_counts: if True, also return the number of occurrences of each unique value.
|
||||||
axis: if specified, compute unique values along the specified axis. If None (default),
|
axis: if specified, compute unique values along the specified axis. If None (default),
|
||||||
then flatten ``ar`` before computing the unique values.
|
then flatten ``ar`` before computing the unique values.
|
||||||
equal_nan: if True, consider NaN values equivalent when determining uniqueness.
|
equal_nan: if True, consider NaN values equivalent when determining uniqueness.
|
||||||
@ -546,8 +546,8 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
|||||||
specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``.
|
specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``.
|
||||||
- ``unique_index``:
|
- ``unique_index``:
|
||||||
*(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains
|
*(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains
|
||||||
the indices of the first occurance of each unique value in ``ar``. For 1D inputs,
|
the indices of the first occurrence of each unique value in ``ar``. For 1D inputs,
|
||||||
``ar[unique_index]`` is equivlent to ``unique_values``.
|
``ar[unique_index]`` is equivalent to ``unique_values``.
|
||||||
- ``unique_inverse``:
|
- ``unique_inverse``:
|
||||||
*(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis``
|
*(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis``
|
||||||
is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
|
is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
|
||||||
@ -555,7 +555,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
|||||||
``unique_values[unique_inverse]`` is equivalent to ``ar``.
|
``unique_values[unique_inverse]`` is equivalent to ``ar``.
|
||||||
- ``unique_counts``:
|
- ``unique_counts``:
|
||||||
*(returned only if return_counts is True)* An array of shape ``(n_unique,)``.
|
*(returned only if return_counts is True)* An array of shape ``(n_unique,)``.
|
||||||
Contains the number of occurances of each unique value in ``ar``.
|
Contains the number of occurrences of each unique value in ``ar``.
|
||||||
|
|
||||||
See also:
|
See also:
|
||||||
- :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``.
|
- :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``.
|
||||||
@ -619,7 +619,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
|||||||
**Returning indices**
|
**Returning indices**
|
||||||
|
|
||||||
If you set ``return_index=True``, then ``unique`` returns the indices of the
|
If you set ``return_index=True``, then ``unique`` returns the indices of the
|
||||||
first occurance of each unique value:
|
first occurrence of each unique value:
|
||||||
|
|
||||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||||
>>> values, indices = jnp.unique(x, return_index=True)
|
>>> values, indices = jnp.unique(x, return_index=True)
|
||||||
@ -660,7 +660,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
|||||||
|
|
||||||
**Returning counts**
|
**Returning counts**
|
||||||
|
|
||||||
If you set ``return_counts=True``, then ``unique`` returns the number of occurances
|
If you set ``return_counts=True``, then ``unique`` returns the number of occurrences
|
||||||
within the input for every unique value:
|
within the input for every unique value:
|
||||||
|
|
||||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||||
@ -671,7 +671,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
|||||||
[2 2 1]
|
[2 2 1]
|
||||||
|
|
||||||
For multi-dimensional arrays, this also returns a 1D array of counts
|
For multi-dimensional arrays, this also returns a 1D array of counts
|
||||||
indicating number of occurances along the specified axis:
|
indicating number of occurrences along the specified axis:
|
||||||
|
|
||||||
>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
|
>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
|
||||||
>>> print(values)
|
>>> print(values)
|
||||||
@ -748,13 +748,13 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None,
|
|||||||
- ``values``:
|
- ``values``:
|
||||||
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
||||||
- ``indices``:
|
- ``indices``:
|
||||||
An array of shape ``(n_unique,)``. Contains the indices of the first occurance of
|
An array of shape ``(n_unique,)``. Contains the indices of the first occurrence of
|
||||||
each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivlent to ``values``.
|
each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivalent to ``values``.
|
||||||
- ``inverse_indices``:
|
- ``inverse_indices``:
|
||||||
An array of shape ``x.shape``. Contains the indices within ``values`` of each value
|
An array of shape ``x.shape``. Contains the indices within ``values`` of each value
|
||||||
in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.
|
in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.
|
||||||
- ``counts``:
|
- ``counts``:
|
||||||
An array of shape ``(n_unique,)``. Contains the number of occurances of each unique
|
An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique
|
||||||
value in ``x``.
|
value in ``x``.
|
||||||
|
|
||||||
See also:
|
See also:
|
||||||
@ -770,7 +770,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None,
|
|||||||
>>> result = jnp.unique_all(x)
|
>>> result = jnp.unique_all(x)
|
||||||
|
|
||||||
The result is a :class:`~typing.NamedTuple` with four named attributes.
|
The result is a :class:`~typing.NamedTuple` with four named attributes.
|
||||||
The ``values`` attribue contains the unique values from the array:
|
The ``values`` attribute contains the unique values from the array:
|
||||||
|
|
||||||
>>> result.values
|
>>> result.values
|
||||||
Array([1, 3, 4], dtype=int32)
|
Array([1, 3, 4], dtype=int32)
|
||||||
@ -829,7 +829,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None,
|
|||||||
- ``values``:
|
- ``values``:
|
||||||
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
||||||
- ``counts``:
|
- ``counts``:
|
||||||
An array of shape ``(n_unique,)``. Contains the number of occurances of each unique
|
An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique
|
||||||
value in ``x``.
|
value in ``x``.
|
||||||
|
|
||||||
See also:
|
See also:
|
||||||
@ -846,7 +846,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None,
|
|||||||
>>> result = jnp.unique_counts(x)
|
>>> result = jnp.unique_counts(x)
|
||||||
|
|
||||||
The result is a :class:`~typing.NamedTuple` with two named attributes.
|
The result is a :class:`~typing.NamedTuple` with two named attributes.
|
||||||
The ``values`` attribue contains the unique values from the array:
|
The ``values`` attribute contains the unique values from the array:
|
||||||
|
|
||||||
>>> result.values
|
>>> result.values
|
||||||
Array([1, 3, 4], dtype=int32)
|
Array([1, 3, 4], dtype=int32)
|
||||||
|
@ -354,7 +354,7 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
|
|||||||
"""Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
|
"""Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func : a callable that takes `nin` scalar arguments and return `nout` outputs.
|
func : a callable that takes `nin` scalar arguments and returns `nout` outputs.
|
||||||
nin: integer specifying the number of scalar inputs
|
nin: integer specifying the number of scalar inputs
|
||||||
nout: integer specifying the number of scalar outputs
|
nout: integer specifying the number of scalar outputs
|
||||||
identity: (optional) a scalar specifying the identity of the operation, if any.
|
identity: (optional) a scalar specifying the identity of the operation, if any.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user