From 4da56dcdd724c2a6c0aa262190afc188883536d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Hr=C4=8Dek?= Date: Wed, 13 Dec 2023 07:45:52 +0100 Subject: [PATCH] Fix duplicate word occurrences --- cloud_tpu_colabs/Pmap_Cookbook.ipynb | 2 +- docs/autodidax.ipynb | 2 +- docs/autodidax.md | 2 +- docs/autodidax.py | 2 +- docs/gpu_performance_tips.md | 2 +- docs/jep/10657-sequencing-effects.md | 2 +- jax/_src/flatten_util.py | 2 +- jax/_src/lax/linalg.py | 4 ++-- jax/_src/lax/parallel.py | 2 +- jax/_src/random.py | 6 +++--- jax/_src/test_util.py | 2 +- jax/experimental/custom_partitioning.py | 2 +- jax/experimental/host_callback.py | 2 +- jax/experimental/jax2tf/examples/saved_model_lib.py | 2 +- jax/experimental/jax2tf/impl_no_xla.py | 2 +- jax/experimental/jax2tf/tests/jax2tf_limitations.py | 2 +- jax/experimental/sparse/csr.py | 2 +- jax/experimental/sparse/linalg.py | 4 ++-- jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc | 2 +- jaxlib/mosaic/python/apply_vector_layout.py | 2 +- tests/api_test.py | 2 +- tests/lax_numpy_indexing_test.py | 2 +- tests/random_test.py | 2 +- 23 files changed, 27 insertions(+), 27 deletions(-) diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index 832ae49db..cb2bd79ef 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -117,7 +117,7 @@ "id": "iZgTmx5pFd6z" }, "source": [ - "But `pmap` and `vmap` differ in in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel." + "But `pmap` and `vmap` differ in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel." ] }, { diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index af92a7e58..24980cf30 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -1938,7 +1938,7 @@ "source": [ "With any new primitive, we need to give it transformation rules, starting with\n", "its evaluation rule. When we evaluate an application of the `xla_call`\n", - "primitive, we want to stage out out the computation to XLA. That involves\n", + "primitive, we want to stage out the computation to XLA. That involves\n", "translating the jaxpr to an XLA HLO program, transferring the argument values\n", "to the XLA device, executing the XLA program, and transferring back the\n", "results. We'll cache the XLA HLO compilation so that for each `jit`ted\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 42906c22e..1121f0833 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1524,7 +1524,7 @@ xla_call_p = Primitive('xla_call') With any new primitive, we need to give it transformation rules, starting with its evaluation rule. When we evaluate an application of the `xla_call` -primitive, we want to stage out out the computation to XLA. That involves +primitive, we want to stage out the computation to XLA. That involves translating the jaxpr to an XLA HLO program, transferring the argument values to the XLA device, executing the XLA program, and transferring back the results. We'll cache the XLA HLO compilation so that for each `jit`ted diff --git a/docs/autodidax.py b/docs/autodidax.py index 784b28b3a..261574310 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1517,7 +1517,7 @@ xla_call_p = Primitive('xla_call') # With any new primitive, we need to give it transformation rules, starting with # its evaluation rule. When we evaluate an application of the `xla_call` -# primitive, we want to stage out out the computation to XLA. That involves +# primitive, we want to stage out the computation to XLA. That involves # translating the jaxpr to an XLA HLO program, transferring the argument values # to the XLA device, executing the XLA program, and transferring back the # results. We'll cache the XLA HLO compilation so that for each `jit`ted diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index 43e17779d..5aa4b0ecb 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -87,7 +87,7 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta These flags tune when to combine multiple small `AllGather`/`ReduceScatter`/`AllReduce` into one big `AllGather`/`ReduceScatter`/`AllReduce` to reduce time spent on cross-device - communication. For example, for the the `AllGather`/`ReduceScatter` thresholds + communication. For example, for the `AllGather`/`ReduceScatter` thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By default, the `combine_threshold_bytes` is set to 256. diff --git a/docs/jep/10657-sequencing-effects.md b/docs/jep/10657-sequencing-effects.md index 97bde9a1b..12375c873 100644 --- a/docs/jep/10657-sequencing-effects.md +++ b/docs/jep/10657-sequencing-effects.md @@ -74,7 +74,7 @@ before `"hello"`. The reordering of the print side-effects breaks the illusion of a single-threaded execution model. Another example of where side-effects can "reveal" out-of-order execution is -when we we compile JAX programs. Consider the following JAX code: +when we compile JAX programs. Consider the following JAX code: ```python @jax.jit def f(x): diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index d7c124af2..e18ad1f6e 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -36,7 +36,7 @@ def ravel_pytree(pytree): A pair where the first element is a 1D array representing the flattened and concatenated leaf values, with dtype determined by promoting the dtypes of leaf values, and the second element is a callable for unflattening a 1D - vector of the same length back to a pytree of of the same structure as the + vector of the same length back to a pytree of the same structure as the input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as a convention a 1D empty array of dtype float32 is returned in the first component of the output. diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 8c1a06492..e35afdea4 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -723,7 +723,7 @@ def _eigh_cpu_gpu_lowering( # Therefore, we cannot yet support dynamic non-batch dimensions. if not is_constant_shape(operand_aval.shape[-2:]): raise NotImplementedError( - "Shape polymorphism for for native lowering for eigh is implemented " + "Shape polymorphism for native lowering for eigh is implemented " f"only for the batch dimensions: {operand_aval.shape}") if not (subset_by_index is None or subset_by_index == (0, n)): @@ -760,7 +760,7 @@ def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index): if not is_constant_dim(m): # TODO: maybe we can relax the check below for shape polymorphism? raise NotImplementedError( - "Shape polymorphism for for native lowering for eigh is implemented " + "Shape polymorphism for native lowering for eigh is implemented " f"only for the batch dimensions: {x.shape}") if m <= termination_size and ( subset_by_index is None or subset_by_index == (0, n) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 278e1d070..d941ac663 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -288,7 +288,7 @@ def pshuffle(x, axis_name, perm): x: array(s) with a mapped axis named ``axis_name``. axis_name: hashable Python object used to name a pmapped axis (see the :func:`jax.pmap` documentation for more details). - perm: list of of ints encoding sources for the permutation to be applied to + perm: list of ints encoding sources for the permutation to be applied to the axis named ``axis_name``, so that the output at axis index i comes from the input at axis index perm[i]. Every integer in [0, N) should be included exactly once for axis size N. diff --git a/jax/_src/random.py b/jax/_src/random.py index c9761e97d..bf16cd80b 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -1028,7 +1028,7 @@ def dirichlet(key: KeyArrayLike, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Dirichlet random values with given shape and float dtype. - The values are distributed according the probability density function: + The values are distributed according to the probability density function: .. math:: f(\{x_i\}; \{\alpha_i\}) = \propto \prod_{i=1}^k x_i^{\alpha_i - 1} @@ -1099,7 +1099,7 @@ def exponential(key: KeyArrayLike, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Exponential random values with given shape and float dtype. - The values are distributed according the probability density function: + The values are distributed according to the probability density function: .. math:: f(x) = e^{-x} @@ -1267,7 +1267,7 @@ def gamma(key: KeyArrayLike, dtype: DTypeLikeFloat = float) -> Array: r"""Sample Gamma random values with given shape and float dtype. - The values are distributed according the probability density function: + The values are distributed according to the probability density function: .. math:: f(x;a) \propto x^{a - 1} e^{-x} diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index b8594fcb4..02d1ffd4e 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -135,7 +135,7 @@ def with_jax_dtype_defaults(func, use_defaults=True): Args: use_defaults : whether to convert any given output to the default dtype. May be a single boolean, in which case it specifies the conversion for all outputs, - or may be a a pytree with the same structure as the function output. + or may be a pytree with the same structure as the function output. """ @functools.wraps(func) def wrapped(*args, **kwargs): diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 22cdff8c6..cf57ef27d 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -384,7 +384,7 @@ class custom_partitioning: -1.6937828 +0.8402481j 15.999859 -4.0156755j]] Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays. - However, in this case, the HLO of ``my_fft`` does show a a dynamic-slice, since the last dimension + However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension is the dimension along which FFTs are calculated and needs to be replicated on all devices before the computation can be done. diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 856ab266c..40ebe7eea 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1304,7 +1304,7 @@ def _instantiate_zeros(tan, arg): tan: the tangent. arg: the argument for which we need to instantiate the tangent - Returns: tan if is is not ad.Zero, otherwise a 0 array of appropriate type + Returns: tan if it is not ad.Zero, otherwise a 0 array of appropriate type and shape """ if type(tan) is not ad.Zero: diff --git a/jax/experimental/jax2tf/examples/saved_model_lib.py b/jax/experimental/jax2tf/examples/saved_model_lib.py index b3c94b5e1..23add30d7 100644 --- a/jax/experimental/jax2tf/examples/saved_model_lib.py +++ b/jax/experimental/jax2tf/examples/saved_model_lib.py @@ -47,7 +47,7 @@ def convert_and_save_model( """Convert a JAX function and saves a SavedModel. This is an example, we do not promise backwards compatibility for this code. - For serious uses, please copy and and expand it as needed (see note at the top + For serious uses, please copy and expand it as needed (see note at the top of the module). Use this function if you have a trained ML model that has both a prediction diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index b449aeb70..90b26d2c4 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -459,7 +459,7 @@ tf_impl_no_xla[lax.dot_general_p] = _dot_general def _interior_padding(operand, padding_value, padding_config, operand_shape): # Used only when enable_xla=False # Applies only the interior padding from the padding_config. - # We do this somewhat inefficiently, as as a scatter. + # We do this somewhat inefficiently, as a scatter. # For each dimension we compute the indices_by_dim as [0, f, 2f, 3f, ...] where # f is the dilation factor for the dimension, i.e., 1 + interior_padding. # Then we compute the cartesian production of the indices (using broadcast diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 1d69d454e..5d90ae311 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -1165,7 +1165,7 @@ class Jax2TfLimitation(test_harnesses.Limitation): @classmethod def select_and_gather_add(cls, harness): return [ - # This JAX primitives is not not exposed directly in the JAX API + # This JAX primitives is not exposed directly in the JAX API # but arises from JVP of `lax.reduce_window` for reducers # `lax.max` or `lax.min`. It also arises from second-order # VJP of the same. Implemented using XlaReduceWindow. diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 3eadad01e..c1178943c 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -559,7 +559,7 @@ def _csr_matmat(data: Array, indices: Array, indptr: Array, B: Array, Returns: C : array of shape ``(shape[1] if transpose else shape[0], cols)`` - representing the matrix-matrix product product. + representing the matrix-matrix product. """ return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose) diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index a23a491ab..c2990d3fe 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -345,7 +345,7 @@ def _svqb(X): tau = jnp.finfo(X.dtype).eps * w[0] padded = jnp.maximum(w, tau) - # Note the the tau == 0 edge case where X was all zeros. + # Note the tau == 0 edge case where X was all zeros. sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5) # X^T X = V diag(w) V^T, so @@ -391,7 +391,7 @@ def _project_out(basis, U): # # Interspersing with orthonormalization isn't directly grounded in the # original analysis, but taken from Algorithm 5 of [3]. In practice, due to - # normalization, I have noticed that that the orthonormalized basis + # normalization, I have noticed that the orthonormalized basis # does not always end up as a subspace of the starting basis in practice. # There may be room to refine this procedure further, but the adjustment # in the subsequent block handles this edge case well enough for now. diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 473cff5b0..797e27006 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3189,7 +3189,7 @@ xla::Array retileToReducedSublanes( xla::Array dst_vreg_array( dst_layout.tileArrayShape(value_shape, target_shape)); - // We need to rotate each src tile in each src vreg once so that that they can + // We need to rotate each src tile in each src vreg once so that they can // be merged to form new vregs. If a src vreg contains more than one src tile, // it will be rotated once per src tile. Consider (8,512) tensor stored with // layout (8,128) in a vreg array of shape (1, 4). Each src vreg diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index c231ac39a..aab4fe8cd 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -966,7 +966,7 @@ def retile_to_reduced_sublanes( dst_layout.tile_array_shape(value_shape), dtype=object ) - # We need to rotate each src tile in each src vreg once so that that they can + # We need to rotate each src tile in each src vreg once so that they can # be merged to form new vregs. If a src vreg contains more than one src tile, # it will be rotated once per src tile. Consider (8,512) tensor stored with # layout (8,128) in a vreg array of shape (1, 4). Each src vreg diff --git a/tests/api_test.py b/tests/api_test.py index 00758757e..05f506abc 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -8083,7 +8083,7 @@ class CustomVJPTest(jtu.JaxTestCase): def test_nondiff_arg_tracer_error(self): # This is similar to the old (now skipped) test_nondiff_arg_tracer, except - # we're testing for the error message that that usage pattern now raises. + # we're testing for the error message that usage pattern now raises. @partial(jax.custom_vjp, nondiff_argnums=(0,)) def f(x, y): diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 55413edd4..9e326c5f0 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1445,7 +1445,7 @@ class IndexedUpdateTest(jtu.JaxTestCase): expected = jnp.array([5, 2, 3, 3]) self.assertAllClose(ans, expected, check_dtypes=False) - # test with negative segment ids and without without explicit num_segments + # test with negative segment ids and without explicit num_segments # such as num_segments is defined by the smaller index. segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6]) ans = ops.segment_sum(data, segment_ids) diff --git a/tests/random_test.py b/tests/random_test.py index c9a0c150f..88f604f2f 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -593,7 +593,7 @@ class KeyArrayTest(jtu.JaxTestCase): # lax_tests.py as an example. If you add a test here (e.g. testing # lowering of a key-dtyped shaped array), consider whether it # might also be a more general test of opaque element types. If - # so, add a corresponding test to to CustomElementTypesTest as well. + # so, add a corresponding test to CustomElementTypesTest as well. def assertKeysEqual(self, key1, key2): self.assertEqual(key1.dtype, key2.dtype)