Fix duplicate word occurrences

Jan Hrček 2023-12-13 07:45:52 +01:00
parent e82807297b
commit 4da56dcdd7
23 changed files with 27 additions and 27 deletions

View File

@@ -117,7 +117,7 @@
"id": "iZgTmx5pFd6z"
},
"source": [
-"But `pmap` and `vmap` differ in in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
+"But `pmap` and `vmap` differ in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
]
},
{
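
For orientation, a minimal, self-contained sketch of the distinction the line above draws (illustrative only; the array shapes and the closed-over `W` are arbitrary choices, not from the notebook):

```python
import jax
import jax.numpy as jnp

W = jnp.ones((4, 3))

def predict(x):            # x has shape (3,)
    return W @ x           # matrix-vector product, result has shape (4,)

xs = jnp.ones((8, 3))      # a batch of 8 examples

# vmap adds a batch dimension to every primitive inside `predict`, so the
# matrix-vector multiply effectively becomes one matrix-matrix multiply.
batched = jax.vmap(predict)(xs)                        # shape (8, 4)

# pmap instead replicates `predict`, running one copy per XLA device; the
# leading axis of the input must match the device count.
n_dev = jax.local_device_count()
replicated = jax.pmap(predict)(jnp.ones((n_dev, 3)))   # shape (n_dev, 4)

print(batched.shape, replicated.shape)
```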

View File

@@ -1938,7 +1938,7 @@
"source": [
"With any new primitive, we need to give it transformation rules, starting with\n",
"its evaluation rule. When we evaluate an application of the `xla_call`\n",
-"primitive, we want to stage out out the computation to XLA. That involves\n",
+"primitive, we want to stage out the computation to XLA. That involves\n",
"translating the jaxpr to an XLA HLO program, transferring the argument values\n",
"to the XLA device, executing the XLA program, and transferring back the\n",
"results. We'll cache the XLA HLO compilation so that for each `jit`ted\n",

View File

@@ -1524,7 +1524,7 @@ xla_call_p = Primitive('xla_call')
With any new primitive, we need to give it transformation rules, starting with
its evaluation rule. When we evaluate an application of the `xla_call`
-primitive, we want to stage out out the computation to XLA. That involves
+primitive, we want to stage out the computation to XLA. That involves
translating the jaxpr to an XLA HLO program, transferring the argument values
to the XLA device, executing the XLA program, and transferring back the
results. We'll cache the XLA HLO compilation so that for each `jit`ted

View File

@@ -1517,7 +1517,7 @@ xla_call_p = Primitive('xla_call')
# With any new primitive, we need to give it transformation rules, starting with
# its evaluation rule. When we evaluate an application of the `xla_call`
-# primitive, we want to stage out out the computation to XLA. That involves
+# primitive, we want to stage out the computation to XLA. That involves
# translating the jaxpr to an XLA HLO program, transferring the argument values
# to the XLA device, executing the XLA program, and transferring back the
# results. We'll cache the XLA HLO compilation so that for each `jit`ted
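
From the user's side, the stage-out-and-cache behaviour described in this passage can be observed with the public `jax.jit` (a rough sketch; it does not use the tutorial's own `xla_call` machinery):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("staging f out to XLA")   # Python side effect: runs only while tracing
    return x * 2.0

f(jnp.arange(3.0))   # first call: trace, translate to XLA HLO, compile, execute
f(jnp.arange(3.0))   # same input shapes/dtypes: the cached executable is reused
f(jnp.arange(4.0))   # new shape: triggers a fresh trace and compilation
```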

View File

@@ -87,7 +87,7 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
These flags tune when to combine multiple small
`AllGather`/`ReduceScatter`/`AllReduce` into one big
`AllGather`/`ReduceScatter`/`AllReduce` to reduce time spent on cross-device
-communication. For example, for the the `AllGather`/`ReduceScatter` thresholds
+communication. For example, for the `AllGather`/`ReduceScatter` thresholds
on a Transformer-based workload, consider tuning them high enough so as to
combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By
default, the `combine_threshold_bytes` is set to 256.
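
As a concrete illustration, these thresholds are typically passed through the `XLA_FLAGS` environment variable before JAX initializes its backends. The flag names below follow the `--xla_gpu_*_combine_threshold_bytes` pattern used by XLA's GPU backend; confirm them against your XLA version, and treat the byte values as placeholders to tune per workload:

```python
import os

# Combine small collectives into larger ones, up to the given size in bytes.
# Tune these high enough to cover at least one Transformer layer's weights.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_all_gather_combine_threshold_bytes=134217728",
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728",
    "--xla_gpu_all_reduce_combine_threshold_bytes=134217728",
])

import jax  # import JAX only after the flags are set
```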

View File

@@ -74,7 +74,7 @@ before `"hello"`. The reordering of the print side-effects breaks the illusion
of a single-threaded execution model.
Another example of where side-effects can "reveal" out-of-order execution is
-when we we compile JAX programs. Consider the following JAX code:
+when we compile JAX programs. Consider the following JAX code:
```python
@jax.jit
def f(x):

View File

@@ -36,7 +36,7 @@ def ravel_pytree(pytree):
A pair where the first element is a 1D array representing the flattened and
concatenated leaf values, with dtype determined by promoting the dtypes of
leaf values, and the second element is a callable for unflattening a 1D
-vector of the same length back to a pytree of of the same structure as the
+vector of the same length back to a pytree of the same structure as the
input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as
a convention a 1D empty array of dtype float32 is returned in the first
component of the output.
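
A minimal usage sketch of the behaviour this docstring describes, via the public `jax.flatten_util.ravel_pytree`:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"b": jnp.zeros((3,)), "w": jnp.ones((2, 3))}

flat, unravel = ravel_pytree(params)   # flat: 1D array of length 9
restored = unravel(flat)               # pytree with the same structure as `params`

print(flat.shape, restored["w"].shape)   # (9,) (2, 3)
```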

View File

@@ -723,7 +723,7 @@ def _eigh_cpu_gpu_lowering(
# Therefore, we cannot yet support dynamic non-batch dimensions.
if not is_constant_shape(operand_aval.shape[-2:]):
raise NotImplementedError(
-"Shape polymorphism for for native lowering for eigh is implemented "
+"Shape polymorphism for native lowering for eigh is implemented "
f"only for the batch dimensions: {operand_aval.shape}")
if not (subset_by_index is None or subset_by_index == (0, n)):
@@ -760,7 +760,7 @@ def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index):
if not is_constant_dim(m):
# TODO: maybe we can relax the check below for shape polymorphism?
raise NotImplementedError(
-"Shape polymorphism for for native lowering for eigh is implemented "
+"Shape polymorphism for native lowering for eigh is implemented "
f"only for the batch dimensions: {x.shape}")
if m <= termination_size and (
subset_by_index is None or subset_by_index == (0, n)

View File

@@ -288,7 +288,7 @@ def pshuffle(x, axis_name, perm):
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
:func:`jax.pmap` documentation for more details).
-perm: list of of ints encoding sources for the permutation to be applied to
+perm: list of ints encoding sources for the permutation to be applied to
the axis named ``axis_name``, so that the output at axis index i
comes from the input at axis index perm[i]. Every integer in [0, N) should
be included exactly once for axis size N.
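
A small sketch of the `perm` semantics described above, using `jax.lax.pshuffle` under `pmap` (on a single-device machine the permutation degenerates to `[0]`):

```python
from functools import partial

import jax
import jax.numpy as jnp

n = jax.local_device_count()
perm = list(reversed(range(n)))            # every integer in [0, n) exactly once

@partial(jax.pmap, axis_name="i")
def shuffle(x):
    return jax.lax.pshuffle(x, "i", perm)

out = shuffle(jnp.arange(n))               # output index i comes from input index perm[i]
print(out)
```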

View File

@@ -1028,7 +1028,7 @@ def dirichlet(key: KeyArrayLike,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Dirichlet random values with given shape and float dtype.
-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:
.. math::
f(\{x_i\}; \{\alpha_i\}) = \propto \prod_{i=1}^k x_i^{\alpha_i - 1}
@@ -1099,7 +1099,7 @@ def exponential(key: KeyArrayLike,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Exponential random values with given shape and float dtype.
-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:
.. math::
f(x) = e^{-x}
@@ -1267,7 +1267,7 @@ def gamma(key: KeyArrayLike,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample Gamma random values with given shape and float dtype.
-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:
.. math::
f(x;a) \propto x^{a - 1} e^{-x}
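
A quick sampling sketch for the three distributions whose docstrings are touched above (parameter values and sample counts are arbitrary, chosen only to sanity-check the documented densities):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

d = jax.random.dirichlet(k1, jnp.array([1.0, 2.0, 3.0]), shape=(1000,))
e = jax.random.exponential(k2, shape=(1000,))
g = jax.random.gamma(k3, 2.0, shape=(1000,))

# Dirichlet samples sum to 1 along the last axis; Exponential has mean 1;
# Gamma(a) has mean a.
print(d.sum(axis=-1)[:3], e.mean(), g.mean())
```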

View File

@@ -135,7 +135,7 @@ def with_jax_dtype_defaults(func, use_defaults=True):
Args:
use_defaults : whether to convert any given output to the default dtype. May be
a single boolean, in which case it specifies the conversion for all outputs,
-or may be a a pytree with the same structure as the function output.
+or may be a pytree with the same structure as the function output.
"""
@functools.wraps(func)
def wrapped(*args, **kwargs):

View File

@@ -384,7 +384,7 @@ class custom_partitioning:
-1.6937828 +0.8402481j 15.999859 -4.0156755j]]
Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays.
-However, in this case, the HLO of ``my_fft`` does show a a dynamic-slice, since the last dimension
+However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension
is the dimension along which FFTs are calculated and needs to be replicated on all devices before
the computation can be done.

View File

@@ -1304,7 +1304,7 @@ def _instantiate_zeros(tan, arg):
tan: the tangent.
arg: the argument for which we need to instantiate the tangent
-Returns: tan if is is not ad.Zero, otherwise a 0 array of appropriate type
+Returns: tan if it is not ad.Zero, otherwise a 0 array of appropriate type
and shape
"""
if type(tan) is not ad.Zero:

View File

@@ -47,7 +47,7 @@ def convert_and_save_model(
"""Convert a JAX function and saves a SavedModel.
This is an example, we do not promise backwards compatibility for this code.
-For serious uses, please copy and and expand it as needed (see note at the top
+For serious uses, please copy and expand it as needed (see note at the top
of the module).
Use this function if you have a trained ML model that has both a prediction

View File

@@ -459,7 +459,7 @@ tf_impl_no_xla[lax.dot_general_p] = _dot_general
def _interior_padding(operand, padding_value, padding_config, operand_shape):
# Used only when enable_xla=False
# Applies only the interior padding from the padding_config.
-# We do this somewhat inefficiently, as as a scatter.
+# We do this somewhat inefficiently, as a scatter.
# For each dimension we compute the indices_by_dim as [0, f, 2f, 3f, ...] where
# f is the dilation factor for the dimension, i.e., 1 + interior_padding.
# Then we compute the cartesian production of the indices (using broadcast

View File

@@ -1165,7 +1165,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
@classmethod
def select_and_gather_add(cls, harness):
return [
-# This JAX primitives is not not exposed directly in the JAX API
+# This JAX primitives is not exposed directly in the JAX API
# but arises from JVP of `lax.reduce_window` for reducers
# `lax.max` or `lax.min`. It also arises from second-order
# VJP of the same. Implemented using XlaReduceWindow.

View File

@@ -559,7 +559,7 @@ def _csr_matmat(data: Array, indices: Array, indptr: Array, B: Array,
Returns:
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
-representing the matrix-matrix product product.
+representing the matrix-matrix product.
"""
return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)
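
The docstring above belongs to the private CSR kernel; as an illustrative public-API counterpart (my assumption of an equivalent entry point, not the `csr_matmat_p` code path), the same sparse × dense matrix-matrix product is reachable through `jax.experimental.sparse`:

```python
import jax.numpy as jnp
from jax.experimental import sparse

A = jnp.array([[0.0, 1.0, 0.0],
               [2.0, 0.0, 3.0]])
A_sp = sparse.BCOO.fromdense(A)   # sparse 2x3 matrix

B = jnp.ones((3, 4))
C = A_sp @ B                      # sparse-dense matrix-matrix product, shape (2, 4)
print(C.shape)
```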

View File

@@ -345,7 +345,7 @@ def _svqb(X):
tau = jnp.finfo(X.dtype).eps * w[0]
padded = jnp.maximum(w, tau)
-# Note the the tau == 0 edge case where X was all zeros.
+# Note the tau == 0 edge case where X was all zeros.
sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)
# X^T X = V diag(w) V^T, so
@@ -391,7 +391,7 @@ def _project_out(basis, U):
#
# Interspersing with orthonormalization isn't directly grounded in the
# original analysis, but taken from Algorithm 5 of [3]. In practice, due to
-# normalization, I have noticed that that the orthonormalized basis
+# normalization, I have noticed that the orthonormalized basis
# does not always end up as a subspace of the starting basis in practice.
# There may be room to refine this procedure further, but the adjustment
# in the subsequent block handles this edge case well enough for now.

View File

@@ -3189,7 +3189,7 @@ xla::Array<Value> retileToReducedSublanes(
xla::Array<Value> dst_vreg_array(
dst_layout.tileArrayShape(value_shape, target_shape));
-// We need to rotate each src tile in each src vreg once so that that they can
+// We need to rotate each src tile in each src vreg once so that they can
// be merged to form new vregs. If a src vreg contains more than one src tile,
// it will be rotated once per src tile. Consider (8,512) tensor stored with
// layout (8,128) in a vreg array of shape (1, 4). Each src vreg

View File

@@ -966,7 +966,7 @@ def retile_to_reduced_sublanes(
dst_layout.tile_array_shape(value_shape), dtype=object
)
-# We need to rotate each src tile in each src vreg once so that that they can
+# We need to rotate each src tile in each src vreg once so that they can
# be merged to form new vregs. If a src vreg contains more than one src tile,
# it will be rotated once per src tile. Consider (8,512) tensor stored with
# layout (8,128) in a vreg array of shape (1, 4). Each src vreg

View File

@@ -8083,7 +8083,7 @@ class CustomVJPTest(jtu.JaxTestCase):
def test_nondiff_arg_tracer_error(self):
# This is similar to the old (now skipped) test_nondiff_arg_tracer, except
-# we're testing for the error message that that usage pattern now raises.
+# we're testing for the error message that usage pattern now raises.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def f(x, y):

View File

@@ -1445,7 +1445,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
expected = jnp.array([5, 2, 3, 3])
self.assertAllClose(ans, expected, check_dtypes=False)
-# test with negative segment ids and without without explicit num_segments
+# test with negative segment ids and without explicit num_segments
# such as num_segments is defined by the smaller index.
segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
ans = ops.segment_sum(data, segment_ids)

View File

@@ -593,7 +593,7 @@ class KeyArrayTest(jtu.JaxTestCase):
# lax_tests.py as an example. If you add a test here (e.g. testing
# lowering of a key-dtyped shaped array), consider whether it
# might also be a more general test of opaque element types. If
-# so, add a corresponding test to to CustomElementTypesTest as well.
+# so, add a corresponding test to CustomElementTypesTest as well.
def assertKeysEqual(self, key1, key2):
self.assertEqual(key1.dtype, key2.dtype)