Lift `lambda x: x` to a top-level function (`_identity_fn`) so that we don't recompile on every invocation of `process_allgather`.

PiperOrigin-RevId: 480155482
This commit is contained in:
Yash Katariya 2022-10-10 12:50:41 -07:00 committed by jax authors
parent 90e9abe278
commit 752c3ffcd9

View File

@ -82,11 +82,17 @@ def sync_global_devices(name: str):
assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
# Identity function is at the top level so that `process_allgather` doesn't
# recompile on every invocation.
def _identity_fn(x):
return x
def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
reps = sharding.OpShardingSharding(inp.sharding._device_assignment,
sharding._get_replicated_op_sharding())
out = pjit(lambda x: x, out_axis_resources=reps)(inp)
out = pjit(_identity_fn, out_axis_resources=reps)(inp)
else:
# All inputs here will be fully addressable.
devices = np.array(jax.devices()).reshape(jax.process_count(),
@ -107,7 +113,7 @@ def _handle_array_process_allgather(inp, tiled):
global_arr = array.make_array_from_single_device_arrays(
global_aval.shape, s, bufs)
with global_mesh:
out = pjit(lambda x: x, out_axis_resources=None)(global_arr)
out = pjit(_identity_fn, out_axis_resources=None)(global_arr)
return np.asarray(out.addressable_data(0))
@ -153,7 +159,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
inp = np.expand_dims(inp, axis=0)
with global_mesh:
out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
out = pjit(_identity_fn, in_axis_resources=in_axis_resources,
out_axis_resources=None)(inp)
return np.asarray(out.addressable_data(0))