Fix duplicate word occurrences

Author: Jan Hrček
Date: 2023-12-13 07:45:52 +01:00
Parent: e82807297b
Commit: 4da56dcdd7

23 changed files with 27 additions and 27 deletions

@@ -117,7 +117,7 @@
 "id": "iZgTmx5pFd6z"
 },
 "source": [
-"But `pmap` and `vmap` differ in in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
+"But `pmap` and `vmap` differ in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
 ]
 },
 {
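
The paragraph fixed above describes how `vmap` and `pmap` compute the same values differently. A minimal sketch of that difference, sized to the available devices so it also runs on a single-device machine:

```python
import jax
import jax.numpy as jnp

W = jnp.ones((4, 4))
# Leading axis sized to the device count so the same input works for both.
xs = jnp.ones((jax.device_count(), 4))

# vmap: one function, with a batch dimension added to every primitive
# (the matrix-vector multiply becomes a matrix-matrix multiply).
batched = jax.vmap(lambda x: W @ x)(xs)

# pmap: the function is replicated and each replica runs on its own device.
replicated = jax.pmap(lambda x: W @ x)(xs)

print(batched.shape, replicated.shape)  # both (jax.device_count(), 4)
```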

@@ -1938,7 +1938,7 @@
 "source": [
 "With any new primitive, we need to give it transformation rules, starting with\n",
 "its evaluation rule. When we evaluate an application of the `xla_call`\n",
-"primitive, we want to stage out out the computation to XLA. That involves\n",
+"primitive, we want to stage out the computation to XLA. That involves\n",
 "translating the jaxpr to an XLA HLO program, transferring the argument values\n",
 "to the XLA device, executing the XLA program, and transferring back the\n",
 "results. We'll cache the XLA HLO compilation so that for each `jit`ted\n",

@@ -1524,7 +1524,7 @@ xla_call_p = Primitive('xla_call')
 With any new primitive, we need to give it transformation rules, starting with
 its evaluation rule. When we evaluate an application of the `xla_call`
-primitive, we want to stage out out the computation to XLA. That involves
+primitive, we want to stage out the computation to XLA. That involves
 translating the jaxpr to an XLA HLO program, transferring the argument values
 to the XLA device, executing the XLA program, and transferring back the
 results. We'll cache the XLA HLO compilation so that for each `jit`ted

@@ -1517,7 +1517,7 @@ xla_call_p = Primitive('xla_call')
 # With any new primitive, we need to give it transformation rules, starting with
 # its evaluation rule. When we evaluate an application of the `xla_call`
-# primitive, we want to stage out out the computation to XLA. That involves
+# primitive, we want to stage out the computation to XLA. That involves
 # translating the jaxpr to an XLA HLO program, transferring the argument values
 # to the XLA device, executing the XLA program, and transferring back the
 # results. We'll cache the XLA HLO compilation so that for each `jit`ted

@@ -87,7 +87,7 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
 These flags tune when to combine multiple small
 `AllGather`/`ReduceScatter`/`AllReduce` into one big
 `AllGather`/`ReduceScatter`/`AllReduce` to reduce time spent on cross-device
-communication. For example, for the the `AllGather`/`ReduceScatter` thresholds
+communication. For example, for the `AllGather`/`ReduceScatter` thresholds
 on a Transformer-based workload, consider tuning them high enough so as to
 combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By
 default, the `combine_threshold_bytes` is set to 256.
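
These XLA flags are usually passed through the `XLA_FLAGS` environment variable before JAX initializes its backend. A hedged sketch of the mechanism; the combiner flag names and byte values below are illustrative assumptions, not recommendations:

```python
import os

# Combine small collectives into larger ones. The flag names and thresholds
# here are assumptions for illustration; tune them per workload.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_all_gather_combine_threshold_bytes=134217728",
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728",
])

import jax  # backend initialization must happen after the flags are set
```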

@@ -74,7 +74,7 @@ before `"hello"`. The reordering of the print side-effects breaks the illusion
 of a single-threaded execution model.
 Another example of where side-effects can "reveal" out-of-order execution is
-when we we compile JAX programs. Consider the following JAX code:
+when we compile JAX programs. Consider the following JAX code:
 ```python
 @jax.jit
 def f(x):

@@ -36,7 +36,7 @@ def ravel_pytree(pytree):
 A pair where the first element is a 1D array representing the flattened and
 concatenated leaf values, with dtype determined by promoting the dtypes of
 leaf values, and the second element is a callable for unflattening a 1D
-vector of the same length back to a pytree of of the same structure as the
+vector of the same length back to a pytree of the same structure as the
 input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as
 a convention a 1D empty array of dtype float32 is returned in the first
 component of the output.
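
For reference, a short usage sketch of the documented return values: the first element is the flattened 1D vector and the second is the callable that unflattens it back to the original structure.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"b": jnp.zeros(3), "w": jnp.ones((2, 3))}
flat, unravel = ravel_pytree(params)

print(flat.shape)         # (9,): all leaves concatenated into one 1D array
restored = unravel(flat)  # pytree with the same structure and shapes as `params`
```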

@@ -723,7 +723,7 @@ def _eigh_cpu_gpu_lowering(
 # Therefore, we cannot yet support dynamic non-batch dimensions.
 if not is_constant_shape(operand_aval.shape[-2:]):
 raise NotImplementedError(
-"Shape polymorphism for for native lowering for eigh is implemented "
+"Shape polymorphism for native lowering for eigh is implemented "
 f"only for the batch dimensions: {operand_aval.shape}")
 if not (subset_by_index is None or subset_by_index == (0, n)):
@@ -760,7 +760,7 @@ def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index):
 if not is_constant_dim(m):
 # TODO: maybe we can relax the check below for shape polymorphism?
 raise NotImplementedError(
-"Shape polymorphism for for native lowering for eigh is implemented "
+"Shape polymorphism for native lowering for eigh is implemented "
 f"only for the batch dimensions: {x.shape}")
 if m <= termination_size and (
 subset_by_index is None or subset_by_index == (0, n)

@@ -288,7 +288,7 @@ def pshuffle(x, axis_name, perm):
 x: array(s) with a mapped axis named ``axis_name``.
 axis_name: hashable Python object used to name a pmapped axis (see the
 :func:`jax.pmap` documentation for more details).
-perm: list of of ints encoding sources for the permutation to be applied to
+perm: list of ints encoding sources for the permutation to be applied to
 the axis named ``axis_name``, so that the output at axis index i
 comes from the input at axis index perm[i]. Every integer in [0, N) should
 be included exactly once for axis size N.
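
A small sketch of the `perm` semantics described in this docstring: the output at axis index `i` takes its value from the input at index `perm[i]` (here a rotation, sized to however many devices are available):

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
perm = [(i + 1) % n for i in range(n)]   # output i comes from input (i + 1) % n

shuffled = jax.pmap(lambda x: lax.pshuffle(x, axis_name='i', perm=perm),
                    axis_name='i')(jnp.arange(n))
print(shuffled)  # e.g. with 4 devices: [1 2 3 0]
```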

@@ -1028,7 +1028,7 @@ def dirichlet(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Dirichlet random values with given shape and float dtype.
-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:
 .. math::
 f(\{x_i\}; \{\alpha_i\}) = \propto \prod_{i=1}^k x_i^{\alpha_i - 1}
@@ -1099,7 +1099,7 @@ def exponential(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Exponential random values with given shape and float dtype.
-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:
 .. math::
 f(x) = e^{-x}
@@ -1267,7 +1267,7 @@ def gamma(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Gamma random values with given shape and float dtype.
-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:
 .. math::
 f(x;a) \propto x^{a - 1} e^{-x}
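
All three docstrings touched here describe sampling against a density; a minimal sketch of drawing from each with the public `jax.random` API:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

d = jax.random.dirichlet(k1, alpha=jnp.array([1.0, 2.0, 3.0]))  # shape (3,), entries sum to 1
e = jax.random.exponential(k2, shape=(4,))                      # unit-rate exponential draws
g = jax.random.gamma(k3, a=2.0, shape=(4,))                     # shape parameter a = 2
```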

@@ -135,7 +135,7 @@ def with_jax_dtype_defaults(func, use_defaults=True):
 Args:
 use_defaults : whether to convert any given output to the default dtype. May be
 a single boolean, in which case it specifies the conversion for all outputs,
-or may be a a pytree with the same structure as the function output.
+or may be a pytree with the same structure as the function output.
 """
 @functools.wraps(func)
 def wrapped(*args, **kwargs):

@@ -384,7 +384,7 @@ class custom_partitioning:
 -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
 Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays.
-However, in this case, the HLO of ``my_fft`` does show a a dynamic-slice, since the last dimension
+However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension
 is the dimension along which FFTs are calculated and needs to be replicated on all devices before
 the computation can be done.

@@ -1304,7 +1304,7 @@ def _instantiate_zeros(tan, arg):
 tan: the tangent.
 arg: the argument for which we need to instantiate the tangent
-Returns: tan if is is not ad.Zero, otherwise a 0 array of appropriate type
+Returns: tan if it is not ad.Zero, otherwise a 0 array of appropriate type
 and shape
 """
 if type(tan) is not ad.Zero:

@@ -47,7 +47,7 @@ def convert_and_save_model(
 """Convert a JAX function and saves a SavedModel.
 This is an example, we do not promise backwards compatibility for this code.
-For serious uses, please copy and and expand it as needed (see note at the top
+For serious uses, please copy and expand it as needed (see note at the top
 of the module).
 Use this function if you have a trained ML model that has both a prediction

@@ -459,7 +459,7 @@ tf_impl_no_xla[lax.dot_general_p] = _dot_general
 def _interior_padding(operand, padding_value, padding_config, operand_shape):
 # Used only when enable_xla=False
 # Applies only the interior padding from the padding_config.
-# We do this somewhat inefficiently, as as a scatter.
+# We do this somewhat inefficiently, as a scatter.
 # For each dimension we compute the indices_by_dim as [0, f, 2f, 3f, ...] where
 # f is the dilation factor for the dimension, i.e., 1 + interior_padding.
 # Then we compute the cartesian production of the indices (using broadcast

@@ -1165,7 +1165,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
 @classmethod
 def select_and_gather_add(cls, harness):
 return [
-# This JAX primitives is not not exposed directly in the JAX API
+# This JAX primitives is not exposed directly in the JAX API
 # but arises from JVP of `lax.reduce_window` for reducers
 # `lax.max` or `lax.min`. It also arises from second-order
 # VJP of the same. Implemented using XlaReduceWindow.

@@ -559,7 +559,7 @@ def _csr_matmat(data: Array, indices: Array, indptr: Array, B: Array,
 Returns:
 C : array of shape ``(shape[1] if transpose else shape[0], cols)``
-representing the matrix-matrix product product.
+representing the matrix-matrix product.
 """
 return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)
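
The docstring fixed here belongs to an internal CSR primitive wrapper. As a hedged illustration, the same sparse-times-dense matrix-matrix product is available through the public `jax.experimental.sparse` API (shown with the BCOO format, which supports `@` directly):

```python
import jax.numpy as jnp
from jax.experimental import sparse

A = sparse.BCOO.fromdense(jnp.array([[1.0, 0.0, 2.0],
                                     [0.0, 3.0, 0.0]]))  # sparse 2x3 matrix
B = jnp.ones((3, 4))

C = A @ B            # sparse @ dense matrix-matrix product, shape (2, 4)
print(C.shape)
```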

@@ -345,7 +345,7 @@ def _svqb(X):
 tau = jnp.finfo(X.dtype).eps * w[0]
 padded = jnp.maximum(w, tau)
-# Note the the tau == 0 edge case where X was all zeros.
+# Note the tau == 0 edge case where X was all zeros.
 sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)
 # X^T X = V diag(w) V^T, so
@@ -391,7 +391,7 @@ def _project_out(basis, U):
 #
 # Interspersing with orthonormalization isn't directly grounded in the
 # original analysis, but taken from Algorithm 5 of [3]. In practice, due to
-# normalization, I have noticed that that the orthonormalized basis
+# normalization, I have noticed that the orthonormalized basis
 # does not always end up as a subspace of the starting basis in practice.
 # There may be room to refine this procedure further, but the adjustment
 # in the subsequent block handles this edge case well enough for now.

@@ -3189,7 +3189,7 @@ xla::Array<Value> retileToReducedSublanes(
 xla::Array<Value> dst_vreg_array(
 dst_layout.tileArrayShape(value_shape, target_shape));
-// We need to rotate each src tile in each src vreg once so that that they can
+// We need to rotate each src tile in each src vreg once so that they can
 // be merged to form new vregs. If a src vreg contains more than one src tile,
 // it will be rotated once per src tile. Consider (8,512) tensor stored with
 // layout (8,128) in a vreg array of shape (1, 4). Each src vreg

@@ -966,7 +966,7 @@ def retile_to_reduced_sublanes(
 dst_layout.tile_array_shape(value_shape), dtype=object
 )
-# We need to rotate each src tile in each src vreg once so that that they can
+# We need to rotate each src tile in each src vreg once so that they can
 # be merged to form new vregs. If a src vreg contains more than one src tile,
 # it will be rotated once per src tile. Consider (8,512) tensor stored with
 # layout (8,128) in a vreg array of shape (1, 4). Each src vreg

@@ -8083,7 +8083,7 @@ class CustomVJPTest(jtu.JaxTestCase):
 def test_nondiff_arg_tracer_error(self):
 # This is similar to the old (now skipped) test_nondiff_arg_tracer, except
-# we're testing for the error message that that usage pattern now raises.
+# we're testing for the error message that usage pattern now raises.
 @partial(jax.custom_vjp, nondiff_argnums=(0,))
 def f(x, y):

@@ -1445,7 +1445,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
 expected = jnp.array([5, 2, 3, 3])
 self.assertAllClose(ans, expected, check_dtypes=False)
-# test with negative segment ids and without without explicit num_segments
+# test with negative segment ids and without explicit num_segments
 # such as num_segments is defined by the smaller index.
 segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
 ans = ops.segment_sum(data, segment_ids)

@@ -593,7 +593,7 @@ class KeyArrayTest(jtu.JaxTestCase):
 # lax_tests.py as an example. If you add a test here (e.g. testing
 # lowering of a key-dtyped shaped array), consider whether it
 # might also be a more general test of opaque element types. If
-# so, add a corresponding test to to CustomElementTypesTest as well.
+# so, add a corresponding test to CustomElementTypesTest as well.
 def assertKeysEqual(self, key1, key2):
 self.assertEqual(key1.dtype, key2.dtype)