Add PyArrayResultHandler, which behaves like
functools.partial(jax.arrays.ArrayImpl) with the added benefit
that the new PyExecuteResults type can explode directly into
ArrayImpls if passed to explode_with_handlers().

Note that this also helps with deprecating PyBuffer, as the fast path
does not need to call the PyBuffer constructor.

PiperOrigin-RevId: 512788757
Author: Parker Schuh, 2023-02-27 18:26:12 -08:00 (committed by jax authors)
commit eef3e69c61, parent 586fe8d552
2 changed files with 57 additions and 15 deletions
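For intuition: the handler this commit adds is semantically close to the partial application the message describes. A minimal Python sketch, assuming the ArrayImpl constructor signature used in the diff below (`make_array_result_handler` is a hypothetical name for illustration; the real handler is the C++ PyArrayResultHandler reached through `xc.array_result_handler`):

    import functools
    from jax._src.array import ArrayImpl

    def make_array_result_handler(aval, sharding, committed):
      # Like functools.partial(ArrayImpl, ...): fixes everything except the
      # per-device buffers, so handler(bufs) builds one ArrayImpl directly.
      return functools.partial(ArrayImpl, aval, sharding,
                               committed=committed, _skip_checks=True)

Doing this in C++ lets PyExecuteResults hand its outputs straight to such handlers without round-tripping through Python.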

File 1 of 2:

@@ -28,6 +28,7 @@ from jax._src import dtypes
 from jax._src.config import config
 from jax._src.util import prod, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src import api
 from jax._src.typing import ArrayLike
 from jax.interpreters import mlir
@@ -667,6 +668,10 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
   if core.is_opaque_dtype(global_aval.dtype):
     return global_aval.dtype._rules.global_sharded_result_handler(
         global_aval, out_sharding, committed, is_out_sharding_from_xla)
+  if xla_extension_version >= 131:
+    return xc.array_result_handler(
+        global_aval, out_sharding, committed=committed, _skip_checks=True
+    )
   return lambda bufs: ArrayImpl(global_aval, out_sharding, bufs,
                                 committed=committed, _skip_checks=True)
 pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
@@ -681,6 +686,10 @@ def _array_local_result_handler(aval, sharding, indices):
   if core.is_opaque_dtype(aval.dtype):
     return aval.dtype._rules.local_sharded_result_handler(
         aval, sharding, indices)
+  if xla_extension_version >= 131:
+    return xc.array_result_handler(
+        aval, sharding, committed=True, _skip_checks=True
+    )
   return lambda bufs: ArrayImpl(aval, sharding, bufs, committed=True,
                                 _skip_checks=True)
 pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
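Both branches of the version gate above return a callable with the same contract, so callers are unaffected by which path is taken. A hedged usage sketch (`bufs`, `global_aval`, and `out_sharding` are stand-ins for values a caller would already have):

    # Fast path (xla_extension_version >= 131) and Python fallback agree:
    # handler(bufs) -> ArrayImpl, where bufs are per-device buffers.
    handler = _array_global_result_handler(
        global_aval, out_sharding, committed=True,
        is_out_sharding_from_xla=False)
    arr = handler(bufs)  # one sharded array assembled from the buffers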

File 2 of 2:

@@ -74,6 +74,7 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
@@ -2103,38 +2104,70 @@ class ExecuteReplicated:
     self.has_host_callbacks = has_host_callbacks
     self.kept_var_idx = kept_var_idx
 
-  def _call_with_tokens(self, input_bufs):
+  def _add_tokens_to_inputs(self, input_bufs):
     if self.ordered_effects:
       device, = self._local_devices
       tokens = [list(dispatch.runtime_tokens.get_token(eff, device))
                 for eff in self.ordered_effects]
       input_bufs = [*tokens, *input_bufs]
-    num_output_tokens = len(self.ordered_effects)
-    out_bufs, sharded_token = (
-        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
-            input_bufs))
-    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
+    return input_bufs
+
+  def _handle_token_bufs(self, token_bufs, sharded_token):
     for i, device in enumerate(self._local_devices):
       dispatch.runtime_tokens.set_output_runtime_token(
           device, sharded_token.get_token(i))
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
       dispatch.runtime_tokens.update_token(eff, token_buf)
+
+  def _call_with_tokens(self, input_bufs):
+    input_bufs = self._add_tokens_to_inputs(input_bufs)
+    out_bufs, sharded_token = (
+        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
+            input_bufs
+        )
+    )
+    num_output_tokens = len(self.ordered_effects)
+    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
+    self._handle_token_bufs(token_bufs, sharded_token)
     return out_bufs
 
   @profiler.annotate_function
   def __call__(self, *args):
     args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
     input_bufs = self.in_handler(args)
-    if (self.ordered_effects or self.has_unordered_effects or
-        self.has_host_callbacks):
-      out_bufs = self._call_with_tokens(input_bufs)
+    if xla_extension_version >= 131:
+      if (self.ordered_effects or self.has_unordered_effects
+          or self.has_host_callbacks):
+        input_bufs = self._add_tokens_to_inputs(input_bufs)
+        results = self.xla_executable.execute_sharded(
+            input_bufs, with_tokens=True
+        )
+        self._handle_token_bufs(
+            results.disassemble_prefix_into_single_device_arrays(
+                len(self.ordered_effects)
+            ),
+            results.consume_token(),
+        )
+      else:
+        results = self.xla_executable.execute_sharded(input_bufs)
+      if dispatch.needs_check_special():
+        out_arrays = results.disassemble_into_single_device_arrays()
+        for arrays in out_arrays:
+          dispatch.check_special(self.name, arrays)
+        return self.out_handler(out_arrays)
+      return results.consume_with_handlers(self.out_handler.handlers)
     else:
-      out_bufs = self.xla_executable.execute_sharded_on_local_devices(
-          input_bufs)
-    if dispatch.needs_check_special():
-      for bufs in out_bufs:
-        dispatch.check_special(self.name, bufs)
-    return self.out_handler(out_bufs)
+      if (self.ordered_effects or self.has_unordered_effects
+          or self.has_host_callbacks):
+        out_bufs = self._call_with_tokens(input_bufs)
+      else:
+        out_bufs = self.xla_executable.execute_sharded_on_local_devices(
+            input_bufs
+        )
+      if dispatch.needs_check_special():
+        for bufs in out_bufs:
+          dispatch.check_special(self.name, bufs)
+      return self.out_handler(out_bufs)
 
 
 xla_pmap_p = core.MapPrimitive('xla_pmap')
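Stripped of the version gate, the shape of the new __call__ fast path is roughly the following (a sketch, not the verbatim implementation; `executable`, `input_bufs`, `out_handler`, and `needs_check_special` stand in for the attributes and calls used above):

    results = executable.execute_sharded(input_bufs)  # a PyExecuteResults
    if needs_check_special:
      # Debug path: materialize per-device arrays so NaN/Inf checks can run,
      # then assemble outputs with the Python-side handler.
      out_arrays = results.disassemble_into_single_device_arrays()
      out = out_handler(out_arrays)
    else:
      # Fast path: explode directly into ArrayImpls via the result handlers,
      # never constructing intermediate PyBuffers.
      out = results.consume_with_handlers(out_handler.handlers)

The point of PyExecuteResults is that the second branch stays in C++: outputs go from execution to ArrayImpls in one step.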