Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)

Fix duplicate word occurrences

Commit 4da56dcdd7, parent e82807297b
@@ -117,7 +117,7 @@
 "id": "iZgTmx5pFd6z"
 },
 "source": [
-"But `pmap` and `vmap` differ in in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
+"But `pmap` and `vmap` differ in how those values are computed: where `vmap` vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), `pmap` instead replicates the function and executes each replica on its own XLA device in parallel."
 ]
 },
 {

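The corrected sentence is the crux of the `pmap`-vs-`vmap` comparison. As a rough illustration of that difference (the function and shapes below are invented, not taken from the notebook):

```python
import jax
import jax.numpy as jnp

def matvec(m, v):
    return m @ v  # one matrix-vector product

m = jnp.ones((4, 4))

# vmap: adds a batch dimension to each primitive, so the batched call
# effectively becomes a single matrix-matrix multiply.
vs = jnp.ones((8, 4))
batched = jax.vmap(matvec, in_axes=(None, 0))(m, vs)

# pmap: replicates matvec and runs one copy per XLA device in parallel;
# the leading axis must match the number of available devices.
n = jax.device_count()
shards = jnp.ones((n, 4))
parallel = jax.pmap(matvec, in_axes=(None, 0))(m, shards)
```
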
@@ -1938,7 +1938,7 @@
 "source": [
 "With any new primitive, we need to give it transformation rules, starting with\n",
 "its evaluation rule. When we evaluate an application of the `xla_call`\n",
-"primitive, we want to stage out out the computation to XLA. That involves\n",
+"primitive, we want to stage out the computation to XLA. That involves\n",
 "translating the jaxpr to an XLA HLO program, transferring the argument values\n",
 "to the XLA device, executing the XLA program, and transferring back the\n",
 "results. We'll cache the XLA HLO compilation so that for each `jit`ted\n",

@@ -1524,7 +1524,7 @@ xla_call_p = Primitive('xla_call')

 With any new primitive, we need to give it transformation rules, starting with
 its evaluation rule. When we evaluate an application of the `xla_call`
-primitive, we want to stage out out the computation to XLA. That involves
+primitive, we want to stage out the computation to XLA. That involves
 translating the jaxpr to an XLA HLO program, transferring the argument values
 to the XLA device, executing the XLA program, and transferring back the
 results. We'll cache the XLA HLO compilation so that for each `jit`ted

@@ -1517,7 +1517,7 @@ xla_call_p = Primitive('xla_call')

 # With any new primitive, we need to give it transformation rules, starting with
 # its evaluation rule. When we evaluate an application of the `xla_call`
-# primitive, we want to stage out out the computation to XLA. That involves
+# primitive, we want to stage out the computation to XLA. That involves
 # translating the jaxpr to an XLA HLO program, transferring the argument values
 # to the XLA device, executing the XLA program, and transferring back the
 # results. We'll cache the XLA HLO compilation so that for each `jit`ted

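All three hunks above patch the same Autodidax passage in its notebook, Markdown, and Python renderings. The caching behaviour it describes is also easy to observe with stock `jax.jit`; a small illustrative sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def affine(x):
    return 2.0 * x + 1.0

x = jnp.arange(4.0)
affine(x)                 # first call: trace, lower to HLO, compile, then execute
affine(x + 10.0)          # same shape/dtype signature: reuses the cached executable
affine(jnp.arange(8.0))   # new input shape: triggers a fresh compilation
```
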
@@ -87,7 +87,7 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
 These flags tune when to combine multiple small
 `AllGather`/`ReduceScatter`/`AllReduce` into one big
 `AllGather`/`ReduceScatter`/`AllReduce` to reduce time spent on cross-device
-communication. For example, for the the `AllGather`/`ReduceScatter` thresholds
+communication. For example, for the `AllGather`/`ReduceScatter` thresholds
 on a Transformer-based workload, consider tuning them high enough so as to
 combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By
 default, the `combine_threshold_bytes` is set to 256.

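These thresholds are ordinary XLA GPU flags, typically raised via the `XLA_FLAGS` environment variable before JAX is imported. The byte values below are placeholders to tune per workload, not values recommended by this page:

```python
import os

# Placeholder values; raise them until at least one Transformer layer's weight
# AllGather/ReduceScatter is combined into a single collective.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_all_gather_combine_threshold_bytes=134217728",
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728",
    "--xla_gpu_all_reduce_combine_threshold_bytes=134217728",
])

import jax  # import after setting XLA_FLAGS so the flags take effect
```
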
@@ -74,7 +74,7 @@ before `"hello"`. The reordering of the print side-effects breaks the illusion
 of a single-threaded execution model.

 Another example of where side-effects can "reveal" out-of-order execution is
-when we we compile JAX programs. Consider the following JAX code:
+when we compile JAX programs. Consider the following JAX code:
 ```python
 @jax.jit
 def f(x):

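The out-of-order execution this paragraph describes is easy to reproduce: a Python side effect inside a `jit`ted function runs at trace time, not on every call. A minimal illustration (not the snippet from the document):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing f")  # Python side effect: runs while JAX traces/compiles f
    return x + 1

f(jnp.arange(3))  # prints "tracing f" once, during compilation
f(jnp.arange(3))  # cached executable: no print, exposing the reordering
```
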
@@ -36,7 +36,7 @@ def ravel_pytree(pytree):
 A pair where the first element is a 1D array representing the flattened and
 concatenated leaf values, with dtype determined by promoting the dtypes of
 leaf values, and the second element is a callable for unflattening a 1D
-vector of the same length back to a pytree of of the same structure as the
+vector of the same length back to a pytree of the same structure as the
 input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as
 a convention a 1D empty array of dtype float32 is returned in the first
 component of the output.

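For orientation, `ravel_pytree` is exported from `jax.flatten_util`; a short sketch of the pair it returns (the pytree contents are invented):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unflatten = ravel_pytree(params)
print(flat.shape)            # (9,): leaves concatenated into one 1D array
restored = unflatten(flat)   # same structure as the input pytree
```
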
@@ -723,7 +723,7 @@ def _eigh_cpu_gpu_lowering(
 # Therefore, we cannot yet support dynamic non-batch dimensions.
 if not is_constant_shape(operand_aval.shape[-2:]):
 raise NotImplementedError(
-"Shape polymorphism for for native lowering for eigh is implemented "
+"Shape polymorphism for native lowering for eigh is implemented "
 f"only for the batch dimensions: {operand_aval.shape}")

 if not (subset_by_index is None or subset_by_index == (0, n)):

@@ -760,7 +760,7 @@ def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index):
 if not is_constant_dim(m):
 # TODO: maybe we can relax the check below for shape polymorphism?
 raise NotImplementedError(
-"Shape polymorphism for for native lowering for eigh is implemented "
+"Shape polymorphism for native lowering for eigh is implemented "
 f"only for the batch dimensions: {x.shape}")
 if m <= termination_size and (
 subset_by_index is None or subset_by_index == (0, n)

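Both error messages describe the same restriction: under shape polymorphism, only the batch dimensions of the `eigh` operand may be symbolic, while the trailing matrix dimensions must stay static. For reference, an ordinary batched call (shapes invented):

```python
import jax.numpy as jnp

# A batch of 8 symmetric 4x4 matrices; only the leading batch dimension could
# be made symbolic under shape polymorphism, the (4, 4) part must be static.
x = jnp.eye(4) + jnp.zeros((8, 4, 4))
eigvals, eigvecs = jnp.linalg.eigh(x)
print(eigvals.shape, eigvecs.shape)   # (8, 4) (8, 4, 4)
```
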
@@ -288,7 +288,7 @@ def pshuffle(x, axis_name, perm):
 x: array(s) with a mapped axis named ``axis_name``.
 axis_name: hashable Python object used to name a pmapped axis (see the
 :func:`jax.pmap` documentation for more details).
-perm: list of of ints encoding sources for the permutation to be applied to
+perm: list of ints encoding sources for the permutation to be applied to
 the axis named ``axis_name``, so that the output at axis index i
 comes from the input at axis index perm[i]. Every integer in [0, N) should
 be included exactly once for axis size N.

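A minimal sketch of the semantics the docstring spells out, run under `pmap` (the axis name and permutation are invented; the permutation simply rotates the device axis):

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
perm = [(i + 1) % n for i in range(n)]   # output i comes from input perm[i]

def rotate(x):
    return lax.pshuffle(x, axis_name="devices", perm=perm)

print(jax.pmap(rotate, axis_name="devices")(jnp.arange(n)))
```
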
@@ -1028,7 +1028,7 @@ def dirichlet(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Dirichlet random values with given shape and float dtype.

-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:

 .. math::
 f(\{x_i\}; \{\alpha_i\}) = \propto \prod_{i=1}^k x_i^{\alpha_i - 1}

@@ -1099,7 +1099,7 @@ def exponential(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Exponential random values with given shape and float dtype.

-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:

 .. math::
 f(x) = e^{-x}

@@ -1267,7 +1267,7 @@ def gamma(key: KeyArrayLike,
 dtype: DTypeLikeFloat = float) -> Array:
 r"""Sample Gamma random values with given shape and float dtype.

-The values are distributed according the probability density function:
+The values are distributed according to the probability density function:

 .. math::
 f(x;a) \propto x^{a - 1} e^{-x}

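The three docstrings belong to `jax.random.dirichlet`, `jax.random.exponential`, and `jax.random.gamma`; a quick usage sketch with arbitrary keys and shapes:

```python
import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

d = jax.random.dirichlet(k1, alpha=jnp.ones(3), shape=(4,))  # shape (4, 3)
e = jax.random.exponential(k2, shape=(4,))                   # unit-rate samples
g = jax.random.gamma(k3, a=2.0, shape=(4,))                  # shape parameter a
```
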
@@ -135,7 +135,7 @@ def with_jax_dtype_defaults(func, use_defaults=True):
 Args:
 use_defaults : whether to convert any given output to the default dtype. May be
 a single boolean, in which case it specifies the conversion for all outputs,
-or may be a a pytree with the same structure as the function output.
+or may be a pytree with the same structure as the function output.
 """
 @functools.wraps(func)
 def wrapped(*args, **kwargs):

@@ -384,7 +384,7 @@ class custom_partitioning:
 -1.6937828 +0.8402481j 15.999859 -4.0156755j]]

 Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays.
-However, in this case, the HLO of ``my_fft`` does show a a dynamic-slice, since the last dimension
+However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension
 is the dimension along which FFTs are calculated and needs to be replicated on all devices before
 the computation can be done.

@@ -1304,7 +1304,7 @@ def _instantiate_zeros(tan, arg):
 tan: the tangent.
 arg: the argument for which we need to instantiate the tangent

-Returns: tan if is is not ad.Zero, otherwise a 0 array of appropriate type
+Returns: tan if it is not ad.Zero, otherwise a 0 array of appropriate type
 and shape
 """
 if type(tan) is not ad.Zero:

@@ -47,7 +47,7 @@ def convert_and_save_model(
 """Convert a JAX function and saves a SavedModel.

 This is an example, we do not promise backwards compatibility for this code.
-For serious uses, please copy and and expand it as needed (see note at the top
+For serious uses, please copy and expand it as needed (see note at the top
 of the module).

 Use this function if you have a trained ML model that has both a prediction

@@ -459,7 +459,7 @@ tf_impl_no_xla[lax.dot_general_p] = _dot_general
 def _interior_padding(operand, padding_value, padding_config, operand_shape):
 # Used only when enable_xla=False
 # Applies only the interior padding from the padding_config.
-# We do this somewhat inefficiently, as as a scatter.
+# We do this somewhat inefficiently, as a scatter.
 # For each dimension we compute the indices_by_dim as [0, f, 2f, 3f, ...] where
 # f is the dilation factor for the dimension, i.e., 1 + interior_padding.
 # Then we compute the cartesian production of the indices (using broadcast

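Interior padding inserts values between neighbouring elements; in public JAX it is the third entry of each `lax.pad` padding-config tuple. A tiny illustration with invented values:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
# padding_config entries are (low, high, interior) per dimension;
# interior=2 places two padding values between adjacent elements.
print(lax.pad(x, 0.0, [(0, 0, 2)]))   # [1. 0. 0. 2. 0. 0. 3.]
```
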
@@ -1165,7 +1165,7 @@ class Jax2TfLimitation(test_harnesses.Limitation):
 @classmethod
 def select_and_gather_add(cls, harness):
 return [
-# This JAX primitives is not not exposed directly in the JAX API
+# This JAX primitives is not exposed directly in the JAX API
 # but arises from JVP of `lax.reduce_window` for reducers
 # `lax.max` or `lax.min`. It also arises from second-order
 # VJP of the same. Implemented using XlaReduceWindow.

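The primitive in question shows up as soon as a max- or min-style `lax.reduce_window` is differentiated; a minimal sketch (window shape chosen arbitrarily):

```python
import jax
import jax.numpy as jnp
from jax import lax

def maxpool(x):
    # 1D max pooling over non-overlapping windows of size 2.
    return lax.reduce_window(x, -jnp.inf, lax.max, (2,), (2,), "VALID")

x = jnp.arange(6.0)
# The JVP of a max/min reduce_window goes through the internal
# select_and_gather_add primitive mentioned in the comment above.
primals, tangents = jax.jvp(maxpool, (x,), (jnp.ones_like(x),))
print(primals, tangents)
```
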
@@ -559,7 +559,7 @@ def _csr_matmat(data: Array, indices: Array, indptr: Array, B: Array,

 Returns:
 C : array of shape ``(shape[1] if transpose else shape[0], cols)``
-representing the matrix-matrix product product.
+representing the matrix-matrix product.
 """
 return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)

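`_csr_matmat` is the low-level binding behind sparse-times-dense products. At the user level the same kind of product can be written with `jax.experimental.sparse`; the sketch below uses the BCOO format purely as an illustration, since it has the broadest operator support:

```python
import jax.numpy as jnp
from jax.experimental import sparse

A = sparse.BCOO.fromdense(jnp.array([[1.0, 0.0], [0.0, 2.0]]))
B = jnp.ones((2, 3))
print(A @ B)   # sparse @ dense matrix-matrix product
```
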
@@ -345,7 +345,7 @@ def _svqb(X):
 tau = jnp.finfo(X.dtype).eps * w[0]
 padded = jnp.maximum(w, tau)

-# Note the the tau == 0 edge case where X was all zeros.
+# Note the tau == 0 edge case where X was all zeros.
 sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)

 # X^T X = V diag(w) V^T, so

@@ -391,7 +391,7 @@ def _project_out(basis, U):
 #
 # Interspersing with orthonormalization isn't directly grounded in the
 # original analysis, but taken from Algorithm 5 of [3]. In practice, due to
-# normalization, I have noticed that that the orthonormalized basis
+# normalization, I have noticed that the orthonormalized basis
 # does not always end up as a subspace of the starting basis in practice.
 # There may be room to refine this procedure further, but the adjustment
 # in the subsequent block handles this edge case well enough for now.

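Both hunks sit inside the LOBPCG solver, whose SVQB-style step orthonormalizes a block of vectors via the eigendecomposition of its Gram matrix. A simplified, self-contained sketch of that idea (the real `_svqb` also rescales columns first and re-runs on degenerate blocks):

```python
import jax.numpy as jnp

def svqb_sketch(X):
    # Gram-matrix eigendecomposition: X^T X = V diag(w) V^T.
    w, V = jnp.linalg.eigh(X.T @ X)
    tau = jnp.finfo(X.dtype).eps * w[-1]       # threshold set by the largest eigenvalue
    padded = jnp.maximum(w, tau)
    # tau == 0 only when X is entirely zero; avoid 0 ** -0.5 in that edge case.
    sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)
    return X @ (V * sqrted)                    # columns come out (nearly) orthonormal

X = jnp.array([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]])
Q = svqb_sketch(X)
print(jnp.round(Q.T @ Q, 3))                   # approximately the 2x2 identity
```
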
@@ -3189,7 +3189,7 @@ xla::Array<Value> retileToReducedSublanes(
 xla::Array<Value> dst_vreg_array(
 dst_layout.tileArrayShape(value_shape, target_shape));

-// We need to rotate each src tile in each src vreg once so that that they can
+// We need to rotate each src tile in each src vreg once so that they can
 // be merged to form new vregs. If a src vreg contains more than one src tile,
 // it will be rotated once per src tile. Consider (8,512) tensor stored with
 // layout (8,128) in a vreg array of shape (1, 4). Each src vreg

@@ -966,7 +966,7 @@ def retile_to_reduced_sublanes(
 dst_layout.tile_array_shape(value_shape), dtype=object
 )

-# We need to rotate each src tile in each src vreg once so that that they can
+# We need to rotate each src tile in each src vreg once so that they can
 # be merged to form new vregs. If a src vreg contains more than one src tile,
 # it will be rotated once per src tile. Consider (8,512) tensor stored with
 # layout (8,128) in a vreg array of shape (1, 4). Each src vreg

@@ -8083,7 +8083,7 @@ class CustomVJPTest(jtu.JaxTestCase):

 def test_nondiff_arg_tracer_error(self):
 # This is similar to the old (now skipped) test_nondiff_arg_tracer, except
-# we're testing for the error message that that usage pattern now raises.
+# we're testing for the error message that usage pattern now raises.

 @partial(jax.custom_vjp, nondiff_argnums=(0,))
 def f(x, y):

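For readers unfamiliar with the pattern this test exercises, here is a minimal well-behaved `custom_vjp` with `nondiff_argnums`; the test above checks the error raised when such an argument receives a tracer, whereas this sketch shows only the ordinary usage:

```python
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def scale(c, x):           # c is treated as non-differentiable / static
    return c * x

def scale_fwd(c, x):
    return scale(c, x), None

def scale_bwd(c, _, g):
    return (c * g,)        # cotangent only for the differentiable argument x

scale.defvjp(scale_fwd, scale_bwd)
print(jax.grad(scale, argnums=1)(3.0, 2.0))   # 3.0
```
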
@@ -1445,7 +1445,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
 expected = jnp.array([5, 2, 3, 3])
 self.assertAllClose(ans, expected, check_dtypes=False)

-# test with negative segment ids and without without explicit num_segments
+# test with negative segment ids and without explicit num_segments
 # such as num_segments is defined by the smaller index.
 segment_ids = jnp.array([3, 3, 3, 4, 5, 5, -7, -6])
 ans = ops.segment_sum(data, segment_ids)

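A stand-alone illustration of the API the test relies on, with invented data: when `num_segments` is omitted, `jax.ops.segment_sum` infers it from the largest id, and ids outside that range (including negative ones) are expected not to contribute:

```python
import jax.numpy as jnp
from jax import ops

data = jnp.arange(8)
segment_ids = jnp.array([0, 0, 1, 1, 2, 2, -1, -2])
# num_segments is inferred from the largest id (here 3 segments); the negative
# ids fall outside [0, num_segments) and are dropped from the sums.
print(ops.segment_sum(data, segment_ids))   # [1 5 9]
```
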
@@ -593,7 +593,7 @@ class KeyArrayTest(jtu.JaxTestCase):
 # lax_tests.py as an example. If you add a test here (e.g. testing
 # lowering of a key-dtyped shaped array), consider whether it
 # might also be a more general test of opaque element types. If
-# so, add a corresponding test to to CustomElementTypesTest as well.
+# so, add a corresponding test to CustomElementTypesTest as well.

 def assertKeysEqual(self, key1, key2):
 self.assertEqual(key1.dtype, key2.dtype)

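The "key-dtyped shaped array" mentioned in the comment is the typed PRNG key returned by `jax.random.key`; a quick look at one (purely illustrative; the exact dtype repr may vary by version):

```python
import jax

key = jax.random.key(0)          # typed key array with an extended element dtype
print(key.shape, key.dtype)      # () key<fry>
subkeys = jax.random.split(key, 4)
print(subkeys.shape)             # (4,)
```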