From eef3e69c61597fa691a329c876725ca436525b08 Mon Sep 17 00:00:00 2001
From: Parker Schuh
Date: Mon, 27 Feb 2023 18:26:12 -0800
Subject: [PATCH] Add PyArrayResultHandler, which behaves like
 functools.partial(jax.arrays.ArrayImpl) with the added benefit that the new
 PyExecuteResults type can explode directly into ArrayImpls if passed to
 consume_with_handlers().

Note that this also helps with deprecating PyBuffer, as the fast path
does not need to call the PyBuffer constructor.

PiperOrigin-RevId: 512788757
---
 jax/_src/array.py             |  9 +++++
 jax/_src/interpreters/pxla.py | 63 ++++++++++++++++++++++++++---------
 2 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/jax/_src/array.py b/jax/_src/array.py
index f7d247cd7..371f608f7 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -28,6 +28,7 @@ from jax._src import dtypes
 from jax._src.config import config
 from jax._src.util import prod, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src import api
 from jax._src.typing import ArrayLike
 from jax.interpreters import mlir
@@ -667,6 +668,10 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
   if core.is_opaque_dtype(global_aval.dtype):
     return global_aval.dtype._rules.global_sharded_result_handler(
         global_aval, out_sharding, committed, is_out_sharding_from_xla)
+  if xla_extension_version >= 131:
+    return xc.array_result_handler(
+        global_aval, out_sharding, committed=committed, _skip_checks=True
+    )
   return lambda bufs: ArrayImpl(global_aval, out_sharding, bufs,
                                 committed=committed, _skip_checks=True)
 pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
@@ -681,6 +686,10 @@ def _array_local_result_handler(aval, sharding, indices):
   if core.is_opaque_dtype(aval.dtype):
     return aval.dtype._rules.local_sharded_result_handler(
         aval, sharding, indices)
+  if xla_extension_version >= 131:
+    return xc.array_result_handler(
+        aval, sharding, committed=True, _skip_checks=True
+    )
   return lambda bufs: ArrayImpl(aval, sharding, bufs, committed=True,
                                 _skip_checks=True)
 pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 6d380b34a..cd5f9f5ea 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -74,6 +74,7 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
@@ -2103,38 +2104,70 @@ class ExecuteReplicated:
     self.has_host_callbacks = has_host_callbacks
     self.kept_var_idx = kept_var_idx
 
-  def _call_with_tokens(self, input_bufs):
+  def _add_tokens_to_inputs(self, input_bufs):
     if self.ordered_effects:
       device, = self._local_devices
       tokens = [list(dispatch.runtime_tokens.get_token(eff, device))
                 for eff in self.ordered_effects]
       input_bufs = [*tokens, *input_bufs]
-    num_output_tokens = len(self.ordered_effects)
-    out_bufs, sharded_token = (
-        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
-            input_bufs))
-    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
+    return input_bufs
+
+  def _handle_token_bufs(self, token_bufs, sharded_token):
     for i, device in enumerate(self._local_devices):
       dispatch.runtime_tokens.set_output_runtime_token(
           device, sharded_token.get_token(i))
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
       dispatch.runtime_tokens.update_token(eff, token_buf)
+
+  def _call_with_tokens(self, input_bufs):
+    input_bufs = self._add_tokens_to_inputs(input_bufs)
+    out_bufs, sharded_token = (
+        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
+            input_bufs
+        )
+    )
+    num_output_tokens = len(self.ordered_effects)
+    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
+    self._handle_token_bufs(token_bufs, sharded_token)
     return out_bufs
 
   @profiler.annotate_function
   def __call__(self, *args):
     args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
     input_bufs = self.in_handler(args)
-    if (self.ordered_effects or self.has_unordered_effects or
-        self.has_host_callbacks):
-      out_bufs = self._call_with_tokens(input_bufs)
+    if xla_extension_version >= 131:
+      if (self.ordered_effects or self.has_unordered_effects
+          or self.has_host_callbacks):
+        input_bufs = self._add_tokens_to_inputs(input_bufs)
+        results = self.xla_executable.execute_sharded(
+            input_bufs, with_tokens=True
+        )
+        self._handle_token_bufs(
+            results.disassemble_prefix_into_single_device_arrays(
+                len(self.ordered_effects)
+            ),
+            results.consume_token(),
+        )
+      else:
+        results = self.xla_executable.execute_sharded(input_bufs)
+      if dispatch.needs_check_special():
+        out_arrays = results.disassemble_into_single_device_arrays()
+        for arrays in out_arrays:
+          dispatch.check_special(self.name, arrays)
+        return self.out_handler(out_arrays)
+      return results.consume_with_handlers(self.out_handler.handlers)
     else:
-      out_bufs = self.xla_executable.execute_sharded_on_local_devices(
-          input_bufs)
-    if dispatch.needs_check_special():
-      for bufs in out_bufs:
-        dispatch.check_special(self.name, bufs)
-    return self.out_handler(out_bufs)
+      if (self.ordered_effects or self.has_unordered_effects
+          or self.has_host_callbacks):
+        out_bufs = self._call_with_tokens(input_bufs)
+      else:
+        out_bufs = self.xla_executable.execute_sharded_on_local_devices(
+            input_bufs
+        )
+      if dispatch.needs_check_special():
+        for bufs in out_bufs:
+          dispatch.check_special(self.name, bufs)
+      return self.out_handler(out_bufs)
 
 
 xla_pmap_p = core.MapPrimitive('xla_pmap')
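
Illustration (not part of the patch): per the commit message, the C++
PyArrayResultHandler returned by xc.array_result_handler behaves like
functools.partial over ArrayImpl: everything except the output buffers is
bound once when the handler is built, so the dispatch fast path only has to
supply `bufs`. Below is a minimal pure-Python sketch of the handler that the
`xla_extension_version >= 131` branches replace; `make_array_handler` is a
hypothetical name introduced here for illustration only.

    import functools
    from jax._src.array import ArrayImpl

    def make_array_handler(aval, sharding, committed):
      # Hypothetical sketch: bind aval/sharding/committed ahead of time,
      # as the C++ handler does. Calling the result with `bufs` constructs
      # the ArrayImpl directly, with validation checks skipped.
      return functools.partial(
          ArrayImpl, aval, sharding,
          committed=committed, _skip_checks=True)

    # handler(bufs) is then equivalent to the lambda on the old fast path:
    #   lambda bufs: ArrayImpl(aval, sharding, bufs,
    #                          committed=committed, _skip_checks=True)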
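Illustration (not part of the patch): the PyExecuteResults object returned by
execute_sharded supports the two consumption modes this patch relies on in
ExecuteReplicated.__call__. `ExecuteResultsModel` below is a hypothetical
pure-Python model of that contract, written here as a sketch; the real type
is implemented in C++ in the XLA extension.

    class ExecuteResultsModel:
      def __init__(self, per_output_shards):
        # One list of single-device arrays (shards) per output. When run
        # with with_tokens=True, token outputs come first, and
        # disassemble_prefix_into_single_device_arrays(n) would peel off
        # the leading n of them for _handle_token_bufs.
        self._outputs = per_output_shards
        self._consumed = False

      def disassemble_into_single_device_arrays(self):
        # Slow path: return raw shard lists so the caller can run
        # dispatch.check_special before assembling the final arrays.
        assert not self._consumed, "results may only be consumed once"
        self._consumed = True
        return self._outputs

      def consume_with_handlers(self, handlers):
        # Fast path: apply one result handler per output (e.g. the
        # PyArrayResultHandlers in self.out_handler.handlers), going
        # straight to ArrayImpls without an intermediate buffer list.
        assert not self._consumed, "results may only be consumed once"
        self._consumed = True
        return [handler(shards)
                for handler, shards in zip(handlers, self._outputs)]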