Support transfer guard in broadcast_one_to_all(). Fixes https://github.com/jax-ml/jax/issues/25325

PiperOrigin-RevId: 703666450
This commit is contained in:
Danijar Hafner 2024-12-06 17:44:52 -08:00 committed by jax authors
parent baedb62b71
commit 861115ad4b

View File

@ -75,7 +75,7 @@ def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
return host_local_array_to_global_array(inp, global_mesh, pspec)
def post_jit(x):
return np.asarray(x.addressable_data(0))
return jax.device_get(x.addressable_data(0))
in_tree = jax.tree.map(pre_jit, in_tree)
out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding(