Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Fix duplicate word occurrences

commit 4da56dcdd7 (parent e82807297b)
@@ -117,7 +117,7 @@
 "id": "iZgTmx5pFd6z"
 },
 "source": [
-"But `pmap` and `vmap` differ in in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
+"But `pmap` and `vmap` differ in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
 ]
 },
 {
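The hunk above contrasts `vmap` (vectorize every primitive over a batch axis) with `pmap` (replicate the function across devices). A minimal sketch of that contrast, assuming at least four XLA devices are visible for the `pmap` call (e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=4` on CPU):

```python
import jax
import jax.numpy as jnp

def matvec(A, x):
    return A @ x  # a single matrix-vector multiply

A = jnp.ones((4, 3, 3))   # a batch of 4 matrices
x = jnp.ones((4, 3))      # a batch of 4 vectors

# vmap: one traced function, with a batch dimension threaded through each
# primitive (matrix-vector multiplies become matrix-matrix multiplies).
y_vmap = jax.vmap(matvec)(A, x)

# pmap: the function is replicated and each replica runs on its own XLA
# device; this assumes at least 4 devices are available.
y_pmap = jax.pmap(matvec)(A, x)

print(y_vmap.shape, y_pmap.shape)  # (4, 3) (4, 3)
```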
@@ -1938,7 +1938,7 @@
 "source": [
 "With any new primitive, we need to give it transformation rules, starting with\n",
 "its evaluation rule. When we evaluate an application of the `xla_call`\n",
-"primitive, we want to stage out out the computation to XLA. That involves\n",
+"primitive, we want to stage out the computation to XLA. That involves\n",
 "translating the jaxpr to an XLA HLO program, transferring the argument values\n",
 "to the XLA device, executing the XLA program, and transferring back the\n",
 "results. We'll cache the XLA HLO compilation so that for each `jit`ted\n",
@@ -1524,7 +1524,7 @@ xla_call_p = Primitive('xla_call')

 With any new primitive, we need to give it transformation rules, starting with
 its evaluation rule. When we evaluate an application of the `xla_call`
-primitive, we want to stage out out the computation to XLA. That involves
+primitive, we want to stage out the computation to XLA. That involves
 translating the jaxpr to an XLA HLO program, transferring the argument values
 to the XLA device, executing the XLA program, and transferring back the
 results. We'll cache the XLA HLO compilation so that for each `jit`ted
@@ -1517,7 +1517,7 @@ xla_call_p = Primitive('xla_call')

 # With any new primitive, we need to give it transformation rules, starting with
 # its evaluation rule. When we evaluate an application of the `xla_call`
-# primitive, we want to stage out out the computation to XLA. That involves
+# primitive, we want to stage out the computation to XLA. That involves
 # translating the jaxpr to an XLA HLO program, transferring the argument values
 # to the XLA device, executing the XLA program, and transferring back the
 # results. We'll cache the XLA HLO compilation so that for each `jit`ted
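The `xla_call` evaluation rule in these three hunks belongs to the autodidax walkthrough rather than the public API, but the "stage out to XLA, compile once, cache" idea can be observed from the outside. A small sketch using only public entry points:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

# Staging: trace f to a jaxpr, the IR that jit translates to an XLA HLO
# program, compiles, and caches per input shape/dtype signature.
print(jax.make_jaxpr(f)(3.0))

g = jax.jit(f)
g(3.0)  # first call: compile and execute
g(4.0)  # same shapes/dtypes: the cached executable is reused
```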
@@ -87,7 +87,7 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
 These flags tune when to combine multiple small
 `AllGather`/`ReduceScatter`/`AllReduce` into one big
 `AllGather`/`ReduceScatter`/`AllReduce` to reduce time spent on cross-device
-communication. For example, for the the `AllGather`/`ReduceScatter` thresholds
+communication. For example, for the `AllGather`/`ReduceScatter` thresholds
 on a Transformer-based workload, consider tuning them high enough so as to
 combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By
 default, the `combine_threshold_bytes` is set to 256.
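One common way to experiment with these thresholds is to set `XLA_FLAGS` before JAX initializes. The sketch below is an assumption-laden example, not a recipe from this commit: the flag names follow recent XLA GPU builds and may differ across XLA versions.

```python
import os

# Raise the combiner thresholds so that roughly one Transformer layer's worth
# of weight AllGather/ReduceScatter is fused into a single collective.
# Flag names are as found in recent XLA GPU releases; check your XLA version.
os.environ["XLA_FLAGS"] = " ".join([
    os.environ.get("XLA_FLAGS", ""),
    "--xla_gpu_all_gather_combine_threshold_bytes=134217728",
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728",
    "--xla_gpu_all_reduce_combine_threshold_bytes=134217728",
])

import jax  # must be imported after XLA_FLAGS is set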
@@ -74,7 +74,7 @@ before `"hello"`. The reordering of the print side-effects breaks the illusion
 of a single-threaded execution model.

 Another example of where side-effects can "reveal" out-of-order execution is
-when we we compile JAX programs. Consider the following JAX code:
+when we compile JAX programs. Consider the following JAX code:
 ```python
 @jax.jit
 def f(x):
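The excerpt cuts off at the start of the code block; a minimal self-contained version of the idea (not necessarily the document's exact example) is the trace-time print: under `jit`, a Python `print` runs while the function is traced and compiled, not on every call.

```python
import jax

@jax.jit
def f(x):
    print("tracing f with", x)  # Python side effect: fires at trace time only
    return x + 1

f(1)  # compiles: the print fires once, showing a tracer rather than a value
f(2)  # cached executable is reused: no print at all
jax.debug.print("runtime value: {}", f(3))  # runtime printing uses jax.debug.print
```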
@@ -36,7 +36,7 @@ def ravel_pytree(pytree):
 A pair where the first element is a 1D array representing the flattened and
 concatenated leaf values, with dtype determined by promoting the dtypes of
 leaf values, and the second element is a callable for unflattening a 1D
-vector of the same length back to a pytree of of the same structure as the
+vector of the same length back to a pytree of the same structure as the
 input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as
 a convention a 1D empty array of dtype float32 is returned in the first
 component of the output.
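As a usage sketch of the docstring above: `jax.flatten_util.ravel_pytree` returns exactly this pair, the flat vector and an unflatten callable.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

flat, unravel = ravel_pytree(params)
print(flat.shape)           # (9,): 6 weights + 3 biases, concatenated

restored = unravel(flat)    # same {"w": ..., "b": ...} structure as the input
print(restored["w"].shape)  # (2, 3)
```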
@@ -723,7 +723,7 @@ def _eigh_cpu_gpu_lowering(
 # Therefore, we cannot yet support dynamic non-batch dimensions.
 if not is_constant_shape(operand_aval.shape[-2:]):
 raise NotImplementedError(
-"Shape polymorphism for for native lowering for eigh is implemented "
+"Shape polymorphism for native lowering for eigh is implemented "
 f"only for the batch dimensions: {operand_aval.shape}")

 if not (subset_by_index is None or subset_by_index == (0, n)):
@@ -760,7 +760,7 @@ def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index):
 if not is_constant_dim(m):
 # TODO: maybe we can relax the check below for shape polymorphism?
 raise NotImplementedError(
-"Shape polymorphism for for native lowering for eigh is implemented "
+"Shape polymorphism for native lowering for eigh is implemented "
 f"only for the batch dimensions: {x.shape}")
 if m <= termination_size and (
 subset_by_index is None or subset_by_index == (0, n)
@@ -288,7 +288,7 @@ def pshuffle(x, axis_name, perm):
 x: array(s) with a mapped axis named ``axis_name``.
 axis_name: hashable Python object used to name a pmapped axis (see the
 :func:`jax.pmap` documentation for more details).
-perm: list of of ints encoding sources for the permutation to be applied to
+perm: list of ints encoding sources for the permutation to be applied to
 the axis named ``axis_name``, so that the output at axis index i
 comes from the input at axis index perm[i]. Every integer in [0, N) should
 be included exactly once for axis size N.
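To make the `perm` semantics concrete, a minimal sketch assuming four visible XLA devices (again, `--xla_force_host_platform_device_count=4` works for CPU experiments): output index `i` receives the value held at input index `perm[i]`.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4.0)       # one scalar per device
perm = [1, 2, 3, 0]       # output at index i comes from input at index perm[i]

out = jax.pmap(lambda v: lax.pshuffle(v, axis_name="i", perm=perm),
               axis_name="i")(x)
print(out)  # [1. 2. 3. 0.]
```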
@@ -1028,7 +1028,7 @@ def dirichlet(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Dirichlet random values with given shape and float dtype.

-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:

 .. math::
 f(\{x_i\}; \{\alpha_i\}) = \propto \prod_{i=1}^k x_i^{\alpha_i - 1}
@@ -1099,7 +1099,7 @@ def exponential(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Exponential random values with given shape and float dtype.

-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:

 .. math::
 f(x) = e^{-x}
@@ -1267,7 +1267,7 @@ def gamma(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Gamma random values with given shape and float dtype.

-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:

 .. math::
 f(x;a) \propto x^{a - 1} e^{-x}
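A brief usage sketch of the three samplers touched by these hunks, using the new-style `jax.random.key` API (assuming a recent JAX release); every draw is driven by an explicit PRNG key.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
k1, k2, k3 = jax.random.split(key, 3)

e = jax.random.exponential(k1, shape=(1000,))          # density f(x) = e^{-x}
g = jax.random.gamma(k2, a=2.0, shape=(1000,))         # density f(x; a) proportional to x^{a-1} e^{-x}
d = jax.random.dirichlet(k3, alpha=jnp.ones(3) * 0.5)  # one draw on the 3-simplex

print(e.mean(), g.mean(), d.sum())  # roughly 1.0, roughly 2.0, 1.0 up to rounding
```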
@@ -135,7 +135,7 @@ def with_jax_dtype_defaults(func, use_defaults=True):
 Args:
 use_defaults : whether to convert any given output to the default dtype. May be
 a single boolean, in which case it specifies the conversion for all outputs,
-or may be a a pytree with the same structure as the function output.
+or may be a pytree with the same structure as the function output.
 """
 @functools.wraps(func)
 def wrapped(*args, **kwargs):
@@ -384,7 +384,7 @@ class custom_partitioning:
 -1.6937828 +0.8402481j 15.999859 -4.0156755j]]

 Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays.
-However, in this case, the HLO of ``my_fft`` does show a a dynamic-slice, since the last dimension
+However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension
 is the dimension along which FFTs are calculated and needs to be replicated on all devices before
 the computation can be done.

@@ -1304,7 +1304,7 @@ def _instantiate_zeros(tan, arg):
 tan: the tangent.
 arg: the argument for which we need to instantiate the tangent

-Returns: tan if is is not ad.Zero, otherwise a 0 array of appropriate type
+Returns: tan if it is not ad.Zero, otherwise a 0 array of appropriate type
 and shape
 """
 if type(tan) is not ad.Zero:
@@ -47,7 +47,7 @@ def convert_and_save_model(
 """Convert a JAX function and saves a SavedModel.

 This is an example, we do not promise backwards compatibility for this code.
-For serious uses, please copy and and expand it as needed (see note at the top
+For serious uses, please copy and expand it as needed (see note at the top
 of the module).

 Use this function if you have a trained ML model that has both a prediction
@@ -459,7 +459,7 @@ tf_impl_no_xla[lax.dot_general_p] = _dot_general
 def _interior_padding(operand, padding_value, padding_config, operand_shape):
 # Used only when enable_xla=False
 # Applies only the interior padding from the padding_config.
-# We do this somewhat inefficiently, as as a scatter.
+# We do this somewhat inefficiently, as a scatter.
 # For each dimension we compute the indices_by_dim as [0, f, 2f, 3f, ...] where
 # f is the dilation factor for the dimension, i.e., 1 + interior_padding.
 # Then we compute the cartesian production of the indices (using broadcast
@@ -1165,7 +1165,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
 @classmethod
 def select_and_gather_add(cls, harness):
 return [
-# This JAX primitives is not not exposed directly in the JAX API
+# This JAX primitives is not exposed directly in the JAX API
 # but arises from JVP of `lax.reduce_window` for reducers
 # `lax.max` or `lax.min`. It also arises from second-order
 # VJP of the same. Implemented using XlaReduceWindow.
@@ -559,7 +559,7 @@ def _csr_matmat(data: Array, indices: Array, indptr: Array, B: Array,

 Returns:
 C : array of shape ``(shape[1] if transpose else shape[0], cols)``
-representing the matrix-matrix product product.
+representing the matrix-matrix product.
 """
 return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)

@@ -345,7 +345,7 @@ def _svqb(X):
 tau = jnp.finfo(X.dtype).eps * w[0]
 padded = jnp.maximum(w, tau)

-# Note the the tau == 0 edge case where X was all zeros.
+# Note the tau == 0 edge case where X was all zeros.
 sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)

 # X^T X = V diag(w) V^T, so
@@ -391,7 +391,7 @@ def _project_out(basis, U):
 #
 # Interspersing with orthonormalization isn't directly grounded in the
 # original analysis, but taken from Algorithm 5 of [3]. In practice, due to
-# normalization, I have noticed that that the orthonormalized basis
+# normalization, I have noticed that the orthonormalized basis
 # does not always end up as a subspace of the starting basis in practice.
 # There may be room to refine this procedure further, but the adjustment
 # in the subsequent block handles this edge case well enough for now.
@@ -3189,7 +3189,7 @@ xla::Array<Value> retileToReducedSublanes(
 xla::Array<Value> dst_vreg_array(
 dst_layout.tileArrayShape(value_shape, target_shape));

-// We need to rotate each src tile in each src vreg once so that that they can
+// We need to rotate each src tile in each src vreg once so that they can
 // be merged to form new vregs. If a src vreg contains more than one src tile,
 // it will be rotated once per src tile. Consider (8,512) tensor stored with
 // layout (8,128) in a vreg array of shape (1, 4). Each src vreg
@@ -966,7 +966,7 @@ def retile_to_reduced_sublanes(
 dst_layout.tile_array_shape(value_shape), dtype=object
 )

-# We need to rotate each src tile in each src vreg once so that that they can
+# We need to rotate each src tile in each src vreg once so that they can
 # be merged to form new vregs. If a src vreg contains more than one src tile,
 # it will be rotated once per src tile. Consider (8,512) tensor stored with
 # layout (8,128) in a vreg array of shape (1, 4). Each src vreg
@@ -8083,7 +8083,7 @@ class CustomVJPTest(jtu.JaxTestCase):

 def test_nondiff_arg_tracer_error(self):
 # This is similar to the old (now skipped) test_nondiff_arg_tracer, except
-# we're testing for the error message that that usage pattern now raises.
+# we're testing for the error message that usage pattern now raises.

 @partial(jax.custom_vjp, nondiff_argnums=(0,))
 def f(x, y):
@@ -1445,7 +1445,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
 expected = jnp.array([5, 2, 3, 3])
 self.assertAllClose(ans, expected, check_dtypes=False)

-# test with negative segment ids and without without explicit num_segments
+# test with negative segment ids and without explicit num_segments
 # such as num_segments is defined by the smaller index.
 segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
 ans = ops.segment_sum(data, segment_ids)
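A hedged sketch of what this test appears to exercise (the `data` values below are stand-ins, not taken from the test): when `num_segments` is not given, `jax.ops.segment_sum` infers it from the largest segment id, and ids outside `[0, num_segments)`, including the negative ones here, do not contribute to any segment.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.arange(8)                                   # hypothetical data
segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])

# num_segments is inferred from the largest id (5), i.e. 6 segments; the
# negative ids are treated as out of range and dropped, per the test above.
out = ops.segment_sum(data, segment_ids)
print(out)  # expected under that reading: [0 0 0 3 3 9]
```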
@@ -593,7 +593,7 @@ class KeyArrayTest(jtu.JaxTestCase):
 # lax_tests.py as an example. If you add a test here (e.g. testing
 # lowering of a key-dtyped shaped array), consider whether it
 # might also be a more general test of opaque element types. If
-# so, add a corresponding test to to CustomElementTypesTest as well.
+# so, add a corresponding test to CustomElementTypesTest as well.

 def assertKeysEqual(self, key1, key2):
 self.assertEqual(key1.dtype, key2.dtype)