Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
Add PyArrayResultHandler which behaves like functools.partial(jax.arrays.ArrayImpl), with the added benefit that the new PyExecuteResults type can explode directly into ArrayImpls if passed to explode_with_handlers().

Note that this also helps with deprecating PyBuffer, as the fastpath does not need to call the PyBuffer constructor.

PiperOrigin-RevId: 512788757
This commit is contained in:
parent 586fe8d552
commit eef3e69c61
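For intuition, the idea in the commit message can be sketched in a few lines of plain Python. This is only an illustration: FakeArray, make_result_handler, and consume_with_handlers are hypothetical stand-ins, not the actual jaxlib C++ types or their signatures.

# Illustrative sketch only: a result handler behaves like a partial application
# of the array constructor that still awaits the per-output device buffers, and
# a bundle of execution results can be converted by applying one handler per
# output. None of these names are the real jaxlib API.
import functools

class FakeArray:
  """Stand-in for jax.Array / ArrayImpl: an aval, a sharding, and device buffers."""
  def __init__(self, aval, sharding, bufs, committed):
    self.aval, self.sharding, self.bufs, self.committed = aval, sharding, bufs, committed

def make_result_handler(aval, sharding, committed):
  # Behaves like functools.partial(FakeArray, aval, sharding, committed=...):
  # calling the handler with just the buffers yields the finished array.
  return functools.partial(FakeArray, aval, sharding, committed=committed)

def consume_with_handlers(per_output_buffers, handlers):
  # Apply one handler per executable output to get one array per output.
  return [handler(bufs) for handler, bufs in zip(handlers, per_output_buffers)]

# Usage: two outputs, each handler already knows its aval and sharding.
handlers = [make_result_handler("f32[8]", "sharding0", committed=True),
            make_result_handler("i32[2,2]", "sharding1", committed=True)]
arrays = consume_with_handlers([["buf_a"], ["buf_b0", "buf_b1"]], handlers)
assert arrays[0].aval == "f32[8]" and arrays[1].bufs == ["buf_b0", "buf_b1"]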
@@ -28,6 +28,7 @@ from jax._src import dtypes
 from jax._src.config import config
 from jax._src.util import prod, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src import api
 from jax._src.typing import ArrayLike
 from jax.interpreters import mlir
@@ -667,6 +668,10 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
   if core.is_opaque_dtype(global_aval.dtype):
     return global_aval.dtype._rules.global_sharded_result_handler(
         global_aval, out_sharding, committed, is_out_sharding_from_xla)
+  if xla_extension_version >= 131:
+    return xc.array_result_handler(
+        global_aval, out_sharding, committed=committed, _skip_checks=True
+    )
   return lambda bufs: ArrayImpl(global_aval, out_sharding, bufs,
                                 committed=committed, _skip_checks=True)
 pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
@@ -681,6 +686,10 @@ def _array_local_result_handler(aval, sharding, indices):
   if core.is_opaque_dtype(aval.dtype):
     return aval.dtype._rules.local_sharded_result_handler(
         aval, sharding, indices)
+  if xla_extension_version >= 131:
+    return xc.array_result_handler(
+        aval, sharding, committed=True, _skip_checks=True
+    )
   return lambda bufs: ArrayImpl(aval, sharding, bufs, committed=True,
                                 _skip_checks=True)
 pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
@@ -74,6 +74,7 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
@@ -2103,38 +2104,70 @@ class ExecuteReplicated:
     self.has_host_callbacks = has_host_callbacks
     self.kept_var_idx = kept_var_idx

-  def _call_with_tokens(self, input_bufs):
+  def _add_tokens_to_inputs(self, input_bufs):
     if self.ordered_effects:
       device, = self._local_devices
       tokens = [list(dispatch.runtime_tokens.get_token(eff, device))
                 for eff in self.ordered_effects]
       input_bufs = [*tokens, *input_bufs]
-    num_output_tokens = len(self.ordered_effects)
-    out_bufs, sharded_token = (
-        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
-            input_bufs))
-    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
+    return input_bufs
+
+  def _handle_token_bufs(self, token_bufs, sharded_token):
     for i, device in enumerate(self._local_devices):
       dispatch.runtime_tokens.set_output_runtime_token(
           device, sharded_token.get_token(i))
     for eff, token_buf in zip(self.ordered_effects, token_bufs):
       dispatch.runtime_tokens.update_token(eff, token_buf)
+
+  def _call_with_tokens(self, input_bufs):
+    input_bufs = self._add_tokens_to_inputs(input_bufs)
+    out_bufs, sharded_token = (
+        self.xla_executable.execute_sharded_on_local_devices_with_tokens(
+            input_bufs
+        )
+    )
+    num_output_tokens = len(self.ordered_effects)
+    token_bufs, out_bufs = util.split_list(out_bufs, [num_output_tokens])
+    self._handle_token_bufs(token_bufs, sharded_token)
     return out_bufs

   @profiler.annotate_function
   def __call__(self, *args):
     args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
     input_bufs = self.in_handler(args)
-    if (self.ordered_effects or self.has_unordered_effects or
-        self.has_host_callbacks):
-      out_bufs = self._call_with_tokens(input_bufs)
+    if xla_extension_version >= 131:
+      if (self.ordered_effects or self.has_unordered_effects
+          or self.has_host_callbacks):
+        input_bufs = self._add_tokens_to_inputs(input_bufs)
+        results = self.xla_executable.execute_sharded(
+            input_bufs, with_tokens=True
+        )
+        self._handle_token_bufs(
+            results.disassemble_prefix_into_single_device_arrays(
+                len(self.ordered_effects)
+            ),
+            results.consume_token(),
+        )
+      else:
+        results = self.xla_executable.execute_sharded(input_bufs)
+      if dispatch.needs_check_special():
+        out_arrays = results.disassemble_into_single_device_arrays()
+        for arrays in out_arrays:
+          dispatch.check_special(self.name, arrays)
+        return self.out_handler(out_arrays)
+      return results.consume_with_handlers(self.out_handler.handlers)
     else:
-      out_bufs = self.xla_executable.execute_sharded_on_local_devices(
-          input_bufs)
-    if dispatch.needs_check_special():
-      for bufs in out_bufs:
-        dispatch.check_special(self.name, bufs)
-    return self.out_handler(out_bufs)
+      if (self.ordered_effects or self.has_unordered_effects
+          or self.has_host_callbacks):
+        out_bufs = self._call_with_tokens(input_bufs)
+      else:
+        out_bufs = self.xla_executable.execute_sharded_on_local_devices(
+            input_bufs
+        )
+      if dispatch.needs_check_special():
+        for bufs in out_bufs:
+          dispatch.check_special(self.name, bufs)
+      return self.out_handler(out_bufs)


 xla_pmap_p = core.MapPrimitive('xla_pmap')