Roll forward #22076

Reverts a386af446caeb0eb908e9c72418b91b55408a573

PiperOrigin-RevId: 651118051
This commit is contained in:
Junwhan Ahn 2024-07-10 12:41:12 -07:00 committed by jax authors
parent 4f394828e1
commit b578b869d2

View File

@ -57,6 +57,9 @@ def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
A pytree matching in_tree where the leaves now all contain the data from the
first host.
"""
if jax.process_count() == 1:
return jax.tree.map(np.asarray, in_tree)
if is_source is None:
is_source = jax.process_index() == 0