Merge some loops in device_put since it's trivial to do so

PiperOrigin-RevId: 626546322
This commit is contained in:
Yash Katariya 2024-04-19 20:59:05 -07:00 committed by jax authors
parent 0943eb385b
commit 1837b436d7

View File

@ -2439,7 +2439,7 @@ def make_jaxpr(fun: Callable,
def _infer_src_sharding(src, x) -> Sharding | None:
if src is not None:
return src
return src # type: ignore
if isinstance(x, array.ArrayImpl):
return x.sharding
elif isinstance(x, core.Tracer):
@ -2493,21 +2493,20 @@ def device_put(
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
(src is None or
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
for leaf in tree_leaves(x):
_check_sharding(shaped_abstractify(leaf), s=device)
return tree_map(
lambda y: dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y)), x)
def _map(y):
_check_sharding(shaped_abstractify(y), s=device)
return dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y))
return tree_map(_map, x)
x_flat, treedef = tree_flatten(x)
device_flat = flatten_axes("device_put device", treedef, device)
src_flat = flatten_axes("device_put source", treedef, src)
for x_leaf, device_leaf in zip(x_flat, device_flat):
_check_sharding(shaped_abstractify(x_leaf), device_leaf)
out_flat = [
dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
for xf, d, s in zip(x_flat, device_flat, src_flat)
]
out_flat = []
for xf, d, s in zip(x_flat, device_flat, src_flat):
_check_sharding(shaped_abstractify(xf), d)
out_flat.append(dispatch.device_put_p.bind(
xf, device=d, src=_infer_src_sharding(s, xf)))
return tree_unflatten(treedef, out_flat)