mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Roll forward #22076
Reverts a386af446caeb0eb908e9c72418b91b55408a573 PiperOrigin-RevId: 651118051
This commit is contained in:
parent
4f394828e1
commit
b578b869d2
@ -57,6 +57,9 @@ def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
|
||||
A pytree matching in_tree where the leaves now all contain the data from the
|
||||
first host.
|
||||
"""
|
||||
if jax.process_count() == 1:
|
||||
return jax.tree.map(np.asarray, in_tree)
|
||||
|
||||
if is_source is None:
|
||||
is_source = jax.process_index() == 0
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user