mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Support transfer guard in broadcast_one_to_all(). Fixes https://github.com/jax-ml/jax/issues/25325
PiperOrigin-RevId: 703666450
This commit is contained in:
parent
baedb62b71
commit
861115ad4b
@ -75,7 +75,7 @@ def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
|
||||
return host_local_array_to_global_array(inp, global_mesh, pspec)
|
||||
|
||||
def post_jit(x):
|
||||
return np.asarray(x.addressable_data(0))
|
||||
return jax.device_get(x.addressable_data(0))
|
||||
|
||||
in_tree = jax.tree.map(pre_jit, in_tree)
|
||||
out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding(
|
||||
|
Loading…
x
Reference in New Issue
Block a user