mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove --jax_parallel_functions_output_gda.
PiperOrigin-RevId: 616898032
This commit is contained in:
parent
56451d1f56
commit
ee2631e4da
@ -8,7 +8,7 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
|
|
||||||
## jax 0.4.26
|
## jax 0.4.26
|
||||||
|
|
||||||
* Deprecations
|
* Deprecations & Removals
|
||||||
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
|
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
|
||||||
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
|
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
|
||||||
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
|
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
|
||||||
@ -16,6 +16,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
|
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
|
||||||
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
|
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
|
||||||
that cannot be converted to a JAX array now results in an exception.
|
that cannot be converted to a JAX array now results in an exception.
|
||||||
|
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
|
||||||
|
This flag was long deprecated and did nothing; its use was a no-op.
|
||||||
|
|
||||||
## jaxlib 0.4.26
|
## jaxlib 0.4.26
|
||||||
|
|
||||||
|
@ -987,11 +987,6 @@ log_checkpoint_residuals = define_bool_state(
|
|||||||
'partially evaluated (e.g. for autodiff), printing what residuals '
|
'partially evaluated (e.g. for autodiff), printing what residuals '
|
||||||
'are saved.'))
|
'are saved.'))
|
||||||
|
|
||||||
parallel_functions_output_gda = define_bool_state(
|
|
||||||
name='jax_parallel_functions_output_gda',
|
|
||||||
default=False,
|
|
||||||
help='If True, pjit will output GDAs.')
|
|
||||||
|
|
||||||
pmap_shmap_merge = define_bool_state(
|
pmap_shmap_merge = define_bool_state(
|
||||||
name='jax_pmap_shmap_merge',
|
name='jax_pmap_shmap_merge',
|
||||||
default=False,
|
default=False,
|
||||||
|
@ -1113,9 +1113,8 @@ class InputsHandler:
|
|||||||
|
|
||||||
|
|
||||||
class ResultsHandler:
|
class ResultsHandler:
|
||||||
# `out_avals` is the `Array` global avals when using pjit or xmap
|
# `out_avals` is the `Array` global avals when using pjit or xmap. It is the
|
||||||
# with `config.parallel_functions_output_gda=True`. It is the local one
|
# local one when using `pmap`.
|
||||||
# otherwise, and also when using `pmap`.
|
|
||||||
__slots__ = ("handlers", "out_shardings", "out_avals")
|
__slots__ = ("handlers", "out_shardings", "out_avals")
|
||||||
|
|
||||||
def __init__(self, handlers, out_shardings, out_avals):
|
def __init__(self, handlers, out_shardings, out_avals):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user