From b93da3873b47222c64add5818b14400ebff71ec7 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 17 Jun 2024 13:55:46 +0530 Subject: [PATCH] Fix Typos --- jax/_src/lax/lax.py | 6 +++--- jax/_src/lax/parallel.py | 4 ++-- jax/_src/lax/svd.py | 8 ++++---- jax/_src/numpy/linalg.py | 6 +++--- jax/_src/numpy/setops.py | 26 +++++++++++++------------- jax/_src/numpy/ufunc_api.py | 2 +- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 476cd696f..def8a0c3c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -801,7 +801,7 @@ def ragged_dot( group_sizes: (g,) shaped array with integer element type, where g denotes number of groups. The ith element indicates the size of ith group. precision: Optional. Consistent with precision argument for :func:`jax.lax.dot`. preferred_element_type: Optional. Consistent with precision argument for :func:`jax.lax.dot`. - group_offset: Optional. (1,) shaped array that ndicates the group in group_sizes to start computing from. If not specified, defaults to [0]. + group_offset: Optional. (1,) shaped array that indicates the group in group_sizes to start computing from. If not specified, defaults to [0]. Results: (m, n) shaped array with preferred_element_type element type. @@ -1444,7 +1444,7 @@ def full_like(x: ArrayLike | DuckTypedArray, If not specified, the output will have the same sharding as the input, with a few exceptions/limitations in particular: 1. Sharding is not available during tracing, thus this will rely on jit. - 2. If x is weakly typed or uncomitted, will use default sharding. + 2. If x is weakly typed or uncommitted, will use default sharding. 3. Shape is not None and is different from x.shape, default will be used. Returns: @@ -5116,7 +5116,7 @@ def remaining(original, *removed_lists): def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precision] | None: - """Turns an API precision specification, into a pair of enumeration values. + """Turns an API precision specification into a pair of enumeration values. The API can take the precision as a string, or int, and either as a single value to apply to both operands, or as a sequence of two values. diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index cd1dbd313..47386cb4a 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1540,10 +1540,10 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, When ``False`` (the default value), the size of dimension in ``scatter_dimension`` must match the size of axis ``axis_name`` (or the group size if ``axis_index_groups`` is given). After scattering the - all-reduce result along ``scatter_dimension``, the output is sequeezed by + all-reduce result along ``scatter_dimension``, the output is squeezed by removing ``scatter_dimension``, so the result has lower rank than the input. When ``True``, the size of dimension in ``scatter_dimension`` must - be dividible by the size of axis ``axis_name`` (or the group size if + be divisible by the size of axis ``axis_name`` (or the group size if ``axis_index_groups`` is given), and the ``scatter_dimension`` axis is preserved (so the result has the same rank as the input). diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index 607632ad0..77ff4297e 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -15,7 +15,7 @@ """A JIT-compatible library for QDWH-based singular value decomposition. QDWH is short for QR-based dynamically weighted Halley iteration. The Halley -iteration implemented through QR decmopositions is numerically stable and does +iteration implemented through QR decompositions is numerically stable and does not require solving a linear system involving the iteration matrix or computing its inversion. This is desirable for multicore and heterogeneous computing systems. @@ -59,7 +59,7 @@ def _svd_tall_and_square_input( Args: a: A matrix of shape `m x n` with `m >= n`. hermitian: True if `a` is Hermitian. - compute_uv: Whether to compute also `u` and `v` in addition to `s`. + compute_uv: Whether to also compute `u` and `v` in addition to `s`. max_iterations: The predefined maximum number of iterations of QDWH. Returns: @@ -126,11 +126,11 @@ def svd( full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`, respectively. If False, the shapes are `m x k` and `k x n`, respectively, where `k = min(m, n)`. - compute_uv: Whether to compute also `u` and `v` in addition to `s`. + compute_uv: Whether to also compute `u` and `v` in addition to `s`. hermitian: True if `a` is Hermitian. max_iterations: The predefined maximum number of iterations of QDWH. subset_by_index: Optional 2-tuple [start, end] indicating the range of - indices of singular componenets to compute. For example, if + indices of singular components to compute. For example, if ``subset_by_index`` = [0,2], then ``svd`` computes the two largest singular values (and their singular vectors if `compute_uv` is true. diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index e67c447bb..b5970a303 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -795,7 +795,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None, Args: a: array of shape ``(..., M, M)``, containing the Hermitian (if complex) or symmetric (if real) matrix. - UPLO: specifies whether the calculation isdone with the lower triangular + UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. @@ -1249,7 +1249,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: See also: - :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API - - :func:`jax.lax.linalg.qr`: XLA-style QR decompositon API + - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API Examples: Compute the QR decomposition of a matrix: @@ -1443,7 +1443,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): JAX implementation of :func:`numpy.linalg.cross` Args: - x1: N-dimesional array, with ``x1.shape[axis] == 3`` + x1: N-dimensional array, with ``x1.shape[axis] == 3`` x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes broadcast-compatible with ``x1``. axis: axis along which to take the cross product (default: -1). diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 8adbbe8ab..34968b2c7 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -528,7 +528,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal return_index: if True, also return the indices in ``ar`` where each value occurs return_inverse: if True, also return the indices that can be used to reconstruct ``ar`` from the unique values. - return_counts: if True, also return the number of occurances of each unique value. + return_counts: if True, also return the number of occurrences of each unique value. axis: if specified, compute unique values along the specified axis. If None (default), then flatten ``ar`` before computing the unique values. equal_nan: if True, consider NaN values equivalent when determining uniqueness. @@ -546,8 +546,8 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``. - ``unique_index``: *(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains - the indices of the first occurance of each unique value in ``ar``. For 1D inputs, - ``ar[unique_index]`` is equivlent to ``unique_values``. + the indices of the first occurrence of each unique value in ``ar``. For 1D inputs, + ``ar[unique_index]`` is equivalent to ``unique_values``. - ``unique_inverse``: *(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis`` is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified. @@ -555,7 +555,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal ``unique_values[unique_inverse]`` is equivalent to ``ar``. - ``unique_counts``: *(returned only if return_counts is True)* An array of shape ``(n_unique,)``. - Contains the number of occurances of each unique value in ``ar``. + Contains the number of occurrences of each unique value in ``ar``. See also: - :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``. @@ -619,7 +619,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal **Returning indices** If you set ``return_index=True``, then ``unique`` returns the indices of the - first occurance of each unique value: + first occurrence of each unique value: >>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, indices = jnp.unique(x, return_index=True) @@ -660,7 +660,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal **Returning counts** - If you set ``return_counts=True``, then ``unique`` returns the number of occurances + If you set ``return_counts=True``, then ``unique`` returns the number of occurrences within the input for every unique value: >>> x = jnp.array([3, 4, 1, 3, 1]) @@ -671,7 +671,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal [2 2 1] For multi-dimensional arrays, this also returns a 1D array of counts - indicating number of occurances along the specified axis: + indicating number of occurrences along the specified axis: >>> values, counts = jnp.unique(M, axis=0, return_counts=True) >>> print(values) @@ -748,13 +748,13 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, - ``values``: an array of shape ``(n_unique,)`` containing the unique values from ``x``. - ``indices``: - An array of shape ``(n_unique,)``. Contains the indices of the first occurance of - each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivlent to ``values``. + An array of shape ``(n_unique,)``. Contains the indices of the first occurrence of + each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivalent to ``values``. - ``inverse_indices``: An array of shape ``x.shape``. Contains the indices within ``values`` of each value in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``. - ``counts``: - An array of shape ``(n_unique,)``. Contains the number of occurances of each unique + An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique value in ``x``. See also: @@ -770,7 +770,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, >>> result = jnp.unique_all(x) The result is a :class:`~typing.NamedTuple` with four named attributes. - The ``values`` attribue contains the unique values from the array: + The ``values`` attribute contains the unique values from the array: >>> result.values Array([1, 3, 4], dtype=int32) @@ -829,7 +829,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, - ``values``: an array of shape ``(n_unique,)`` containing the unique values from ``x``. - ``counts``: - An array of shape ``(n_unique,)``. Contains the number of occurances of each unique + An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique value in ``x``. See also: @@ -846,7 +846,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, >>> result = jnp.unique_counts(x) The result is a :class:`~typing.NamedTuple` with two named attributes. - The ``values`` attribue contains the unique values from the array: + The ``values`` attribute contains the unique values from the array: >>> result.values Array([1, 3, 4], dtype=int32) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 29f5278bc..7d0769a19 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -354,7 +354,7 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, """Create a JAX ufunc from an arbitrary JAX-compatible scalar function. Args: - func : a callable that takes `nin` scalar arguments and return `nout` outputs. + func : a callable that takes `nin` scalar arguments and returns `nout` outputs. nin: integer specifying the number of scalar inputs nout: integer specifying the number of scalar outputs identity: (optional) a scalar specifying the identity of the operation, if any.