Remove --jax_parallel_functions_output_gda.

PiperOrigin-RevId: 616898032
This commit is contained in:
Peter Hawkins 2024-03-18 11:41:17 -07:00 committed by jax authors
parent 56451d1f56
commit ee2631e4da
3 changed files with 5 additions and 9 deletions

View File

@ -8,7 +8,7 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.26
* Deprecations
* Deprecations & Removals
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
@ -16,6 +16,8 @@ Remember to align the itemized text with the first line of an item within a list
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
that cannot be converted to a JAX array now results in an exception.
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
This flag was long deprecated and did nothing; its use was a no-op.
## jaxlib 0.4.26

View File

@ -987,11 +987,6 @@ log_checkpoint_residuals = define_bool_state(
'partially evaluated (e.g. for autodiff), printing what residuals '
'are saved.'))
parallel_functions_output_gda = define_bool_state(
name='jax_parallel_functions_output_gda',
default=False,
help='If True, pjit will output GDAs.')
pmap_shmap_merge = define_bool_state(
name='jax_pmap_shmap_merge',
default=False,

View File

@ -1113,9 +1113,8 @@ class InputsHandler:
class ResultsHandler:
# `out_avals` is the `Array` global avals when using pjit or xmap
# with `config.parallel_functions_output_gda=True`. It is the local one
# otherwise, and also when using `pmap`.
# `out_avals` is the `Array` global avals when using pjit or xmap. It is the
# local one when using `pmap`.
__slots__ = ("handlers", "out_shardings", "out_avals")
def __init__(self, handlers, out_shardings, out_avals):