Fix token management for ordered side-effects.

Right now, when there are multiple devices, we get an output token from each device, but we only keep the token from `device_0` and replicate it across devices to produce the input tokens for the next function call with ordered side-effects. This is fine on TPU/GPU, where the per-device computations effectively execute in sequence, but on CPU they can run in parallel, so we need to make sure the cross-device dependencies are set correctly.
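
For context, a minimal sketch of the kind of program this affects (not part of this change; it assumes a JAX version that allows ordered effects in multi-device jit computations, forces two host CPU devices via XLA_FLAGS, and uses io_callback with ordered=True as the ordered side effect; the names f and log are made up for illustration):

# Minimal sketch, not part of this change. Assumes two host CPU devices and an
# ordered io_callback as the ordered side effect.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"  # set before importing jax

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def log(x):
  print("side effect:", x)  # runs on the host, expected in program order
  return x

@jax.jit
def f(x):
  # ordered=True makes this an ordered effect, so each dispatch consumes an
  # input runtime token and produces an output runtime token per device.
  return io_callback(log, jax.ShapeDtypeStruct(x.shape, x.dtype), x, ordered=True)

mesh = jax.sharding.Mesh(jax.devices(), ("i",))
spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("i"))
x = jax.device_put(jnp.arange(8.0), spec)

# Two dependent dispatches. On TPU/GPU each device runs its computations in
# issue order, but on CPU the per-device executions can overlap, so every
# device's output token (not just device_0's) must feed the next call's input
# tokens to keep the callbacks ordered.
y = f(x)
z = f(y)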

PiperOrigin-RevId: 623296894
Yue Sheng 2024-04-09 15:24:40 -07:00 committed by jax authors
parent 9809aa1929
commit f1ae6232e9
2 changed files with 25 additions and 3 deletions


@@ -108,7 +108,7 @@ def simple_impl(prim):
 RuntimeToken = Any
 class RuntimeTokenSet(threading.local):
-  """See docstring for effect.py module for the calling convention for tokens."""
+  """See docstring for effects.py module for the calling convention for tokens."""
   # For each ordered effect, the token returned by the last dispatched
   # computation, sharded over the devices in that computation.
@@ -125,6 +125,16 @@ class RuntimeTokenSet(threading.local):
   def get_token_input(self, eff: core.Effect,
                       devices: list[Device]) -> jax.Array:
     tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
+    if isinstance(tok, jax.Array):
+      # The order of devices may change, so we need to reshard if necessary.
+      # TODO(yueshengys): This might still be buggy in a multi-process SPMD
+      # scenario. Revise the logic later. A distributed shutdown barrier inside
+      # the XLA program may be needed.
+      return jax.device_put(tok, jax.sharding.PositionalSharding(devices))
+    # We only use replicated sharding for the first time when the token for the
+    # order effect hasn't been created.
     s = jax.sharding.GSPMDSharding.get_replicated(devices)
     sharded_tok = pxla.shard_args([s], [tok])[0]
     self.current_tokens[eff] = sharded_tok


@@ -1155,13 +1155,25 @@ class ExecuteReplicated:
   def _handle_token_bufs(self, token_bufs, sharded_token):
     # token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
-    # token buffer (as a singleton list).
+    # token buffers.
     # sharded_token: ShardedToken, containing the RuntimeTokens for each device
     for i, device in enumerate(self._local_devices):
       dispatch.runtime_tokens.set_output_runtime_token(
           device, sharded_token.get_token(i))
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
-      dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+      assert len(token_buf) > 0
+      if len(token_buf) == 1:
+        dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+      else:
+        token_devices = []
+        for token in token_buf:
+          assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding)
+          token_devices.append(token.sharding._device_assignment[0])
+        s = sharding_impls.PositionalSharding(token_devices)
+        global_token_array = jax.make_array_from_single_device_arrays(
+            (0,), s, token_buf
+        )
+        dispatch.runtime_tokens.set_token_result(eff, global_token_array)

   @profiler.annotate_function
   def __call__(self, *args):