Fix Typos

This commit is contained in:
rajasekharporeddy 2024-06-17 13:55:46 +05:30
parent 595a620804
commit b93da3873b
6 changed files with 26 additions and 26 deletions

View File

@ -801,7 +801,7 @@ def ragged_dot(
group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group. group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group.
precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`. precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`. preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`.
group_offset: Optional. (1,) shaped array that ndicates the group in group_sizes to start computing from. If not specified, defaults to [0]. group_offset: Optional. (1,) shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0].
Results: Results:
(m, n) shaped array with preferred_element_type element type. (m, n) shaped array with preferred_element_type element type.
@ -1444,7 +1444,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
If not specified, the output will have the same sharding as the input, If not specified, the output will have the same sharding as the input,
with a few exceptions/limitations in particular: with a few exceptions/limitations in particular:
1. Sharding is not available during tracing, thus this will rely on jit. 1. Sharding is not available during tracing, thus this will rely on jit.
2. If x is weakly typed or uncomitted, will use default sharding. 2. If x is weakly typed or uncommitted, will use default sharding.
3. Shape is not None and is different from x.shape, default will be used. 3. Shape is not None and is different from x.shape, default will be used.
Returns: Returns:
@ -5116,7 +5116,7 @@ def remaining(original, *removed_lists):
def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None: def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None:
"""Turns an API precision specification, into a pair of enumeration values. """Turns an API precision specification into a pair of enumeration values.
The API can take the precision as a string, or int, and either as a single The API can take the precision as a string, or int, and either as a single
value to apply to both operands, or as a sequence of two values. value to apply to both operands, or as a sequence of two values.

View File

@ -1540,10 +1540,10 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
When ``False`` (the default value), the size of dimension in When ``False`` (the default value), the size of dimension in
``scatter_dimension`` must match the size of axis ``axis_name`` (or the ``scatter_dimension`` must match the size of axis ``axis_name`` (or the
group size if ``axis_index_groups`` is given). After scattering the group size if ``axis_index_groups`` is given). After scattering the
all-reduce result along ``scatter_dimension``, the output is sequeezed by all-reduce result along ``scatter_dimension``, the output is squeezed by
removing ``scatter_dimension``, so the result has lower rank than the removing ``scatter_dimension``, so the result has lower rank than the
input. When ``True``, the size of dimension in ``scatter_dimension`` must input. When ``True``, the size of dimension in ``scatter_dimension`` must
be dividible by the size of axis ``axis_name`` (or the group size if be divisible by the size of axis ``axis_name`` (or the group size if
``axis_index_groups`` is given), and the ``scatter_dimension`` axis is ``axis_index_groups`` is given), and the ``scatter_dimension`` axis is
preserved (so the result has the same rank as the input). preserved (so the result has the same rank as the input).

View File

@ -15,7 +15,7 @@
"""A JIT-compatible library for QDWH-based singular value decomposition. """A JIT-compatible library for QDWH-based singular value decomposition.
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decmopositions is numerically stable and does iteration implemented through QR decompositions is numerically stable and does
not require solving a linear system involving the iteration matrix or not require solving a linear system involving the iteration matrix or
computing its inversion. This is desirable for multicore and heterogeneous computing its inversion. This is desirable for multicore and heterogeneous
computing systems. computing systems.
@ -59,7 +59,7 @@ def _svd_tall_and_square_input(
Args: Args:
a: A matrix of shape `m x n` with `m >= n`. a: A matrix of shape `m x n` with `m >= n`.
hermitian: True if `a` is Hermitian. hermitian: True if `a` is Hermitian.
compute_uv: Whether to compute also `u` and `v` in addition to `s`. compute_uv: Whether to also compute `u` and `v` in addition to `s`.
max_iterations: The predefined maximum number of iterations of QDWH. max_iterations: The predefined maximum number of iterations of QDWH.
Returns: Returns:
@ -126,11 +126,11 @@ def svd(
full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`, full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
respectively. If False, the shapes are `m x k` and `k x n`, respectively, respectively. If False, the shapes are `m x k` and `k x n`, respectively,
where `k = min(m, n)`. where `k = min(m, n)`.
compute_uv: Whether to compute also `u` and `v` in addition to `s`. compute_uv: Whether to also compute `u` and `v` in addition to `s`.
hermitian: True if `a` is Hermitian. hermitian: True if `a` is Hermitian.
max_iterations: The predefined maximum number of iterations of QDWH. max_iterations: The predefined maximum number of iterations of QDWH.
subset_by_index: Optional 2-tuple [start, end] indicating the range of subset_by_index: Optional 2-tuple [start, end] indicating the range of
indices of singular componenets to compute. For example, if indices of singular components to compute. For example, if
``subset_by_index`` = [0,2], then ``svd`` computes the two largest ``subset_by_index`` = [0,2], then ``svd`` computes the two largest
singular values (and their singular vectors if `compute_uv` is true. singular values (and their singular vectors if `compute_uv` is true.

View File

@ -795,7 +795,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
Args: Args:
a: array of shape ``(..., M, M)``, containing the Hermitian (if complex) a: array of shape ``(..., M, M)``, containing the Hermitian (if complex)
or symmetric (if real) matrix. or symmetric (if real) matrix.
UPLO: specifies whether the calculation isdone with the lower triangular UPLO: specifies whether the calculation is done with the lower triangular
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
symmetrize_input: if True (default) then input is symmetrized, which leads symmetrize_input: if True (default) then input is symmetrized, which leads
to better behavior under automatic differentiation. to better behavior under automatic differentiation.
@ -1249,7 +1249,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
See also: See also:
- :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API - :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API
- :func:`jax.lax.linalg.qr`: XLA-style QR decompositon API - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API
Examples: Examples:
Compute the QR decomposition of a matrix: Compute the QR decomposition of a matrix:
@ -1443,7 +1443,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
JAX implementation of :func:`numpy.linalg.cross` JAX implementation of :func:`numpy.linalg.cross`
Args: Args:
x1: N-dimesional array, with ``x1.shape[axis] == 3`` x1: N-dimensional array, with ``x1.shape[axis] == 3``
x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes
broadcast-compatible with ``x1``. broadcast-compatible with ``x1``.
axis: axis along which to take the cross product (default: -1). axis: axis along which to take the cross product (default: -1).

View File

@ -528,7 +528,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
return_index: if True, also return the indices in ``ar`` where each value occurs return_index: if True, also return the indices in ``ar`` where each value occurs
return_inverse: if True, also return the indices that can be used to reconstruct return_inverse: if True, also return the indices that can be used to reconstruct
``ar`` from the unique values. ``ar`` from the unique values.
return_counts: if True, also return the number of occurances of each unique value. return_counts: if True, also return the number of occurrences of each unique value.
axis: if specified, compute unique values along the specified axis. If None (default), axis: if specified, compute unique values along the specified axis. If None (default),
then flatten ``ar`` before computing the unique values. then flatten ``ar`` before computing the unique values.
equal_nan: if True, consider NaN values equivalent when determining uniqueness. equal_nan: if True, consider NaN values equivalent when determining uniqueness.
@ -546,8 +546,8 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``. specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``.
- ``unique_index``: - ``unique_index``:
*(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains *(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains
the indices of the first occurance of each unique value in ``ar``. For 1D inputs, the indices of the first occurrence of each unique value in ``ar``. For 1D inputs,
``ar[unique_index]`` is equivlent to ``unique_values``. ``ar[unique_index]`` is equivalent to ``unique_values``.
- ``unique_inverse``: - ``unique_inverse``:
*(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis`` *(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis``
is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified. is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
@ -555,7 +555,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
``unique_values[unique_inverse]`` is equivalent to ``ar``. ``unique_values[unique_inverse]`` is equivalent to ``ar``.
- ``unique_counts``: - ``unique_counts``:
*(returned only if return_counts is True)* An array of shape ``(n_unique,)``. *(returned only if return_counts is True)* An array of shape ``(n_unique,)``.
Contains the number of occurances of each unique value in ``ar``. Contains the number of occurrences of each unique value in ``ar``.
See also: See also:
- :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``. - :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``.
@ -619,7 +619,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
**Returning indices** **Returning indices**
If you set ``return_index=True``, then ``unique`` returns the indices of the If you set ``return_index=True``, then ``unique`` returns the indices of the
first occurance of each unique value: first occurrence of each unique value:
>>> x = jnp.array([3, 4, 1, 3, 1]) >>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, indices = jnp.unique(x, return_index=True) >>> values, indices = jnp.unique(x, return_index=True)
@ -660,7 +660,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
**Returning counts** **Returning counts**
If you set ``return_counts=True``, then ``unique`` returns the number of occurances If you set ``return_counts=True``, then ``unique`` returns the number of occurrences
within the input for every unique value: within the input for every unique value:
>>> x = jnp.array([3, 4, 1, 3, 1]) >>> x = jnp.array([3, 4, 1, 3, 1])
@ -671,7 +671,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
[2 2 1] [2 2 1]
For multi-dimensional arrays, this also returns a 1D array of counts For multi-dimensional arrays, this also returns a 1D array of counts
indicating number of occurances along the specified axis: indicating number of occurrences along the specified axis:
>>> values, counts = jnp.unique(M, axis=0, return_counts=True) >>> values, counts = jnp.unique(M, axis=0, return_counts=True)
>>> print(values) >>> print(values)
@ -748,13 +748,13 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None,
- ``values``: - ``values``:
an array of shape ``(n_unique,)`` containing the unique values from ``x``. an array of shape ``(n_unique,)`` containing the unique values from ``x``.
- ``indices``: - ``indices``:
An array of shape ``(n_unique,)``. Contains the indices of the first occurance of An array of shape ``(n_unique,)``. Contains the indices of the first occurrence of
each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivlent to ``values``. each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivalent to ``values``.
- ``inverse_indices``: - ``inverse_indices``:
An array of shape ``x.shape``. Contains the indices within ``values`` of each value An array of shape ``x.shape``. Contains the indices within ``values`` of each value
in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``. in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.
- ``counts``: - ``counts``:
An array of shape ``(n_unique,)``. Contains the number of occurances of each unique An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique
value in ``x``. value in ``x``.
See also: See also:
@ -770,7 +770,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None,
>>> result = jnp.unique_all(x) >>> result = jnp.unique_all(x)
The result is a :class:`~typing.NamedTuple` with four named attributes. The result is a :class:`~typing.NamedTuple` with four named attributes.
The ``values`` attribue contains the unique values from the array: The ``values`` attribute contains the unique values from the array:
>>> result.values >>> result.values
Array([1, 3, 4], dtype=int32) Array([1, 3, 4], dtype=int32)
@ -829,7 +829,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None,
- ``values``: - ``values``:
an array of shape ``(n_unique,)`` containing the unique values from ``x``. an array of shape ``(n_unique,)`` containing the unique values from ``x``.
- ``counts``: - ``counts``:
An array of shape ``(n_unique,)``. Contains the number of occurances of each unique An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique
value in ``x``. value in ``x``.
See also: See also:
@ -846,7 +846,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None,
>>> result = jnp.unique_counts(x) >>> result = jnp.unique_counts(x)
The result is a :class:`~typing.NamedTuple` with two named attributes. The result is a :class:`~typing.NamedTuple` with two named attributes.
The ``values`` attribue contains the unique values from the array: The ``values`` attribute contains the unique values from the array:
>>> result.values >>> result.values
Array([1, 3, 4], dtype=int32) Array([1, 3, 4], dtype=int32)

View File

@ -354,7 +354,7 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
"""Create a JAX ufunc from an arbitrary JAX-compatible scalar function. """Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
Args: Args:
func : a callable that takes `nin` scalar arguments and return `nout` outputs. func : a callable that takes `nin` scalar arguments and returns `nout` outputs.
nin: integer specifying the number of scalar inputs nin: integer specifying the number of scalar inputs
nout: integer specifying the number of scalar outputs nout: integer specifying the number of scalar outputs
identity: (optional) a scalar specifying the identity of the operation, if any. identity: (optional) a scalar specifying the identity of the operation, if any.