mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge some loops in device_put since it's trivial to do so
PiperOrigin-RevId: 626546322
This commit is contained in:
parent
0943eb385b
commit
1837b436d7
@ -2439,7 +2439,7 @@ def make_jaxpr(fun: Callable,
|
||||
|
||||
def _infer_src_sharding(src, x) -> Sharding | None:
|
||||
if src is not None:
|
||||
return src
|
||||
return src # type: ignore
|
||||
if isinstance(x, array.ArrayImpl):
|
||||
return x.sharding
|
||||
elif isinstance(x, core.Tracer):
|
||||
@ -2493,21 +2493,20 @@ def device_put(
|
||||
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
|
||||
(src is None or
|
||||
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
|
||||
for leaf in tree_leaves(x):
|
||||
_check_sharding(shaped_abstractify(leaf), s=device)
|
||||
return tree_map(
|
||||
lambda y: dispatch.device_put_p.bind(
|
||||
y, device=device, src=_infer_src_sharding(src, y)), x)
|
||||
def _map(y):
|
||||
_check_sharding(shaped_abstractify(y), s=device)
|
||||
return dispatch.device_put_p.bind(
|
||||
y, device=device, src=_infer_src_sharding(src, y))
|
||||
return tree_map(_map, x)
|
||||
|
||||
x_flat, treedef = tree_flatten(x)
|
||||
device_flat = flatten_axes("device_put device", treedef, device)
|
||||
src_flat = flatten_axes("device_put source", treedef, src)
|
||||
for x_leaf, device_leaf in zip(x_flat, device_flat):
|
||||
_check_sharding(shaped_abstractify(x_leaf), device_leaf)
|
||||
out_flat = [
|
||||
dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
|
||||
for xf, d, s in zip(x_flat, device_flat, src_flat)
|
||||
]
|
||||
out_flat = []
|
||||
for xf, d, s in zip(x_flat, device_flat, src_flat):
|
||||
_check_sharding(shaped_abstractify(xf), d)
|
||||
out_flat.append(dispatch.device_put_p.bind(
|
||||
xf, device=d, src=_infer_src_sharding(s, xf)))
|
||||
return tree_unflatten(treedef, out_flat)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user