diff --git a/CHANGELOG.md b/CHANGELOG.md index ab4ea6fc3..324acf74f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,7 +57,7 @@ Remember to align the itemized text with the first line of an item within a list * The previously-deprecated imports `jax.interpreters.ad.config` and `jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config` and `jax.extend.source_info_util` instead. - * JAX export does not support anymore older serialization version. Version 9 + * JAX export does not support older serialization versions anymore. Version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). @@ -141,8 +141,8 @@ Remember to align the itemized text with the first line of an item within a list cannot interact, e.g., in arithmetic operations. Scopes are introduced by {func}`jax.experimental.jax2tf.convert`, {func}`jax.experimental.export.symbolic_shape`, {func}`jax.experimental.export.symbolic_args_specs`. - The scope of a symbolic expression `e` can be read with `e.scope` and passed in - to the above functions to direct them to construct symbolic expressions in + The scope of a symbolic expression `e` can be read with `e.scope` and passed + into the above functions to direct them to construct symbolic expressions in a given scope. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. * simplified and faster equality comparisons, where we consider two symbolic dimensions @@ -323,7 +323,7 @@ Remember to align the itemized text with the first line of an item within a list * Bug fixes * Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the - cache is placed on a network filesystem such as GCS. + cache is placed on a network file system such as GCS. * The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built. diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 0193c7203..9ee151baf 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -69,7 +69,7 @@ _no_operand_sentinel = object() @api_boundary def switch(index, branches: Sequence[Callable], *operands, operand=_no_operand_sentinel): - """Apply exactly one of ``branches`` given by ``index``. + """Apply exactly one of the ``branches`` given by ``index``. If ``index`` is out of bounds, it is clamped to within bounds. diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 734037415..4d55907f6 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -49,7 +49,7 @@ def _split_root_args(args, const_lengths): @api_boundary def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False): - """Differentiably solve for a roots of a function. + """Differentiably solve for the roots of a function. This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 48689bbb0..bb03f1e5a 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -232,7 +232,7 @@ class GatherDimensionNumbers(NamedTuple): in the output of the gather. Must be a tuple of integers in ascending order. start_index_map: for each dimension in `start_indices`, gives the - corresponding dimension in `operand` that is to be sliced. Must be a + corresponding dimension in the `operand` that is to be sliced. Must be a tuple of integers with size equal to `start_indices.shape[-1]`. Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is @@ -261,8 +261,8 @@ class GatherScatterMode(enum.Enum): will be discarded. PROMISE_IN_BOUNDS: The user promises that indices are in bounds. No additional checking will be - performed. In practice, with the current XLA implementation this means - that, out-of-bounds gathers will be clamped but out-of-bounds scatters will + performed. In practice, with the current XLA implementation this means + that out-of-bounds gathers will be clamped but out-of-bounds scatters will be discarded. Gradients will not be correct if indices are out-of-bounds. """ CLIP = enum.auto()