diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index d62d57ea2..837e1b9bd 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -108,7 +108,7 @@ def simple_impl(prim):
 RuntimeToken = Any
 
 class RuntimeTokenSet(threading.local):
-  """See docstring for effect.py module for the calling convention for tokens."""
+  """See docstring for effects.py module for the calling convention for tokens."""
 
   # For each ordered effect, the token returned by the last dispatched
   # computation, sharded over the devices in that computation.
@@ -125,6 +125,16 @@ class RuntimeTokenSet(threading.local):
   def get_token_input(self, eff: core.Effect,
                       devices: list[Device]) -> jax.Array:
     tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))
+
+    if isinstance(tok, jax.Array):
+      # The order of devices may change, so we need to reshard if necessary.
+      # TODO(yueshengys): This might still be buggy in a multi-process SPMD
+      # scenario. Revise the logic later. A distributed shutdown barrier inside
+      # the XLA program may be needed.
+      return jax.device_put(tok, jax.sharding.PositionalSharding(devices))
+
+    # We only use replicated sharding the first time, when the token for the
+    # ordered effect hasn't been created yet.
     s = jax.sharding.GSPMDSharding.get_replicated(devices)
     sharded_tok = pxla.shard_args([s], [tok])[0]
     self.current_tokens[eff] = sharded_tok
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 6fe95dbbe..3e4fc344f 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1155,13 +1155,25 @@ class ExecuteReplicated:
 
   def _handle_token_bufs(self, token_bufs, sharded_token):
     # token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
-    # token buffer (as a singleton list).
+    # token buffers.
     # sharded_token: ShardedToken, containing the RuntimeTokens for each device
     for i, device in enumerate(self._local_devices):
       dispatch.runtime_tokens.set_output_runtime_token(
           device, sharded_token.get_token(i))
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
-      dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+      assert len(token_buf) > 0
+      if len(token_buf) == 1:
+        dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
+      else:
+        token_devices = []
+        for token in token_buf:
+          assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding)
+          token_devices.append(token.sharding._device_assignment[0])
+        s = sharding_impls.PositionalSharding(token_devices)
+        global_token_array = jax.make_array_from_single_device_arrays(
+            (0,), s, token_buf
+        )
+        dispatch.runtime_tokens.set_token_result(eff, global_token_array)
 
   @profiler.annotate_function
   def __call__(self, *args):
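
For reference, below is a minimal, hypothetical repro sketch (not part of the patch) of the situation this change targets: an ordered effect inside a jitted computation that spans several devices, so the effect token comes back as one buffer per device rather than a singleton list. The mesh layout, axis name, and the use of jax.debug.print(..., ordered=True) are illustrative assumptions, not taken from the diff.

# Sketch only: exercises an ordered effect on a multi-device computation.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()                  # assumes more than one local device
mesh = Mesh(devices, axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

@jax.jit
def f(a):
  # ordered=True introduces an ordered effect, so the runtime threads a token
  # through every computation dispatched for this effect.
  jax.debug.print("sum = {s}", s=a.sum(), ordered=True)
  return a * 2

x = jax.device_put(jnp.arange(4.0 * len(devices)), sharding)
y = f(x)  # with multiple devices, the returned token buffers are per-device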