diff --git a/CHANGELOG.md b/CHANGELOG.md index c1dd67b58..25f757f81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.26 -* Deprecations +* Deprecations & Removals * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. * The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are @@ -16,6 +16,8 @@ Remember to align the itemized text with the first line of an item within a list `spmd_axis_name` argument for expressing SPMD device-parallel computations. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` that cannot be converted to a JAX array now results in an exception. + * The deprecated flag `jax_parallel_functions_output_gda` has been removed. + This flag was long deprecated and did nothing; its use was a no-op. ## jaxlib 0.4.26 diff --git a/jax/_src/config.py b/jax/_src/config.py index 3ac3649ed..e6c2fdde7 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -987,11 +987,6 @@ log_checkpoint_residuals = define_bool_state( 'partially evaluated (e.g. for autodiff), printing what residuals ' 'are saved.')) -parallel_functions_output_gda = define_bool_state( - name='jax_parallel_functions_output_gda', - default=False, - help='If True, pjit will output GDAs.') - pmap_shmap_merge = define_bool_state( name='jax_pmap_shmap_merge', default=False, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b18af885d..dc759333f 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1113,9 +1113,8 @@ class InputsHandler: class ResultsHandler: - # `out_avals` is the `Array` global avals when using pjit or xmap - # with `config.parallel_functions_output_gda=True`. It is the local one - # otherwise, and also when using `pmap`. + # `out_avals` is the `Array` global avals when using pjit or xmap. It is the + # local one when using `pmap`. __slots__ = ("handlers", "out_shardings", "out_avals") def __init__(self, handlers, out_shardings, out_avals):