mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Lift `lambda x: x` to the top level so that we don't recompile on every invocation of `process_allgather`.
PiperOrigin-RevId: 480155482
This commit is contained in:
parent
90e9abe278
commit
752c3ffcd9
@ -82,11 +82,17 @@ def sync_global_devices(name: str):
|
||||
assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
|
||||
|
||||
|
||||
# Identity function is at the top level so that `process_allgather` doesn't
|
||||
# recompile on every invocation.
|
||||
def _identity_fn(x):
|
||||
return x
|
||||
|
||||
|
||||
def _handle_array_process_allgather(inp, tiled):
|
||||
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
|
||||
reps = sharding.OpShardingSharding(inp.sharding._device_assignment,
|
||||
sharding._get_replicated_op_sharding())
|
||||
out = pjit(lambda x: x, out_axis_resources=reps)(inp)
|
||||
out = pjit(_identity_fn, out_axis_resources=reps)(inp)
|
||||
else:
|
||||
# All inputs here will be fully addressable.
|
||||
devices = np.array(jax.devices()).reshape(jax.process_count(),
|
||||
@ -107,7 +113,7 @@ def _handle_array_process_allgather(inp, tiled):
|
||||
global_arr = array.make_array_from_single_device_arrays(
|
||||
global_aval.shape, s, bufs)
|
||||
with global_mesh:
|
||||
out = pjit(lambda x: x, out_axis_resources=None)(global_arr)
|
||||
out = pjit(_identity_fn, out_axis_resources=None)(global_arr)
|
||||
|
||||
return np.asarray(out.addressable_data(0))
|
||||
|
||||
@ -153,7 +159,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
|
||||
inp = np.expand_dims(inp, axis=0)
|
||||
|
||||
with global_mesh:
|
||||
out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
|
||||
out = pjit(_identity_fn, in_axis_resources=in_axis_resources,
|
||||
out_axis_resources=None)(inp)
|
||||
return np.asarray(out.addressable_data(0))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user