mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Fix token management for ordered side-effects.
Right now, when there are multiple devices, we get an output token from each device, but we only keep the token from `device_0` and replicate it across devices to form the input tokens for the next function call with ordered side-effects. This is fine on TPU/GPU, where the computations are essentially executed in sequence, but on CPU they can run in parallel, so we need to make sure the dependency is set correctly.

PiperOrigin-RevId: 623296894
This commit is contained in:
parent 9809aa1929
commit f1ae6232e9
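To make the commit message concrete, here is a minimal illustrative sketch, not part of this commit; the function name `step` and the use of `jax.debug.print` are chosen only to demonstrate an ordered side-effect. With multiple devices, each dispatch returns an output token per device, and the next dispatch's input token must depend on all of them so that CPU executions cannot overlap.

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
  # ordered=True marks this as an ordered effect, so it participates in the
  # runtime token chain managed by RuntimeTokenSet.
  jax.debug.print("step saw {x}", x=x, ordered=True)
  return x + 1

y = step(jnp.arange(4.0))  # dispatch 1: produces an output token per device
z = step(y)                # dispatch 2: must consume those tokens as its input token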
@@ -108,7 +108,7 @@ def simple_impl(prim):
 RuntimeToken = Any
 
 class RuntimeTokenSet(threading.local):
-  """See docstring for effect.py module for the calling convention for tokens."""
+  """See docstring for effects.py module for the calling convention for tokens."""
 
   # For each ordered effect, the token returned by the last dispatched
   # computation, sharded over the devices in that computation.
@@ -125,6 +125,16 @@ class RuntimeTokenSet(threading.local):
   def get_token_input(self, eff: core.Effect,
                       devices: list[Device]) -> jax.Array:
     tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
+
+    if isinstance(tok, jax.Array):
+      # The order of devices may change, so we need to reshard if necessary.
+      # TODO(yueshengys): This might still be buggy in a multi-process SPMD
+      # scenario. Revise the logic later. A distributed shutdown barrier inside
+      # the XLA program may be needed.
+      return jax.device_put(tok, jax.sharding.PositionalSharding(devices))
+
+    # We only use replicated sharding for the first time when the token for the
+    # order effect hasn't been created.
     s = jax.sharding.GSPMDSharding.get_replicated(devices)
     sharded_tok = pxla.shard_args([s], [tok])[0]
     self.current_tokens[eff] = sharded_tok
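As a usage note for the branch added above, a hedged sketch (not from the diff) of resharding an existing array with jax.device_put and a PositionalSharding built from a new device order, which is the pattern get_token_input now applies to a previously created token; the array contents and the name x_resharded are illustrative only.

import jax
import numpy as np

devices = jax.devices()
x = jax.device_put(np.arange(2 * len(devices), dtype=np.float32),
                   jax.sharding.PositionalSharding(devices))
# If the next computation uses the same devices in a different order, move the
# data onto that ordering before feeding it in.
reordered = list(reversed(devices))
x_resharded = jax.device_put(x, jax.sharding.PositionalSharding(reordered))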
@@ -1155,13 +1155,25 @@ class ExecuteReplicated:
 
   def _handle_token_bufs(self, token_bufs, sharded_token):
     # token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
-    # token buffer (as a singleton list).
+    # token buffers.
     # sharded_token: ShardedToken, containing the RuntimeTokens for each device
     for i, device in enumerate(self._local_devices):
       dispatch.runtime_tokens.set_output_runtime_token(
           device, sharded_token.get_token(i))
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
-      dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+      assert len(token_buf) > 0
+      if len(token_buf) == 1:
+        dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+      else:
+        token_devices = []
+        for token in token_buf:
+          assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding)
+          token_devices.append(token.sharding._device_assignment[0])
+        s = sharding_impls.PositionalSharding(token_devices)
+        global_token_array = jax.make_array_from_single_device_arrays(
+            (0,), s, token_buf
+        )
+        dispatch.runtime_tokens.set_token_result(eff, global_token_array)
 
   @profiler.annotate_function
   def __call__(self, *args):
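For reference, a hedged sketch (not part of the commit) of the pattern the new multi-token branch relies on: assembling per-device single-device arrays into one global jax.Array with jax.make_array_from_single_device_arrays. The shapes and values here are illustrative; the code above does this with shape-(0,) token buffers.

import jax
import numpy as np

devices = jax.local_devices()
# One single-device shard per device, placed in the sharding's device order.
shards = [jax.device_put(np.arange(2.0) + 2 * i, d)
          for i, d in enumerate(devices)]
s = jax.sharding.PositionalSharding(devices)
global_arr = jax.make_array_from_single_device_arrays(
    (2 * len(devices),), s, shards)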