Remove dispatch.result_handlers since they are not used.

PiperOrigin-RevId: 517456171
This commit is contained in:
Yash Katariya 2023-03-17 11:01:36 -07:00 committed by jax authors
parent 706549a270
commit 6d0189e810
3 changed files with 31 additions and 209 deletions

View File

@ -20,9 +20,8 @@ import contextlib
from functools import partial
import itertools
import time
from typing import (
Any, Callable, Dict, Iterator, Optional, Protocol,
Sequence, Set, Tuple, List, Type, Union, NamedTuple)
from typing import (Any, Callable, Dict, Iterator, Optional,
Sequence, Set, Tuple, List, Union, NamedTuple)
import logging
import os
import re
@ -416,180 +415,6 @@ def initial_style_primitive_replicas(params):
return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)
# Argument and result handlers
num_buffers_handlers: Dict[Type[core.AbstractValue],
Callable[[core.AbstractValue], int]] = {}
def aval_to_num_buffers(aval: core.AbstractValue) -> int:
"""Returns the number of buffers in the runtime representation of `aval`.
In general this may differ from the number of buffers in the compiler-IR
representation of the same value.
"""
try:
return num_buffers_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err
num_buffers_handlers[core.AbstractToken] = lambda _: 1
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.DShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
def _input_handler(backend: Backend,
in_type: Optional[pe.InputType],
out_type: Optional[pe.OutputType],
) -> Optional[Callable]:
if in_type is None:
assert out_type is None
return None
in_avals, which_explicit = util.unzip2(in_type)
# Check whether we actually need an input_handler.
needs_implicit = which_explicit and not all(which_explicit)
needs_out_handling = any(type(d) is core.InDBIdx for a, _ in out_type or []
if type(a) is core.DShapedArray for d in a.shape)
if not needs_implicit and not needs_out_handling:
return None
assert config.jax_dynamic_shapes
# Precompute how to grab implicit inputs from explicit inputs' axis sizes.
which_explicit = which_explicit or (True,) * len(in_avals)
implicit_idxs = {i for i, ex in enumerate(which_explicit) if not ex}
implicit_args_from_axes: List[Tuple[int, int, int]] = []
for arg_idx, aval in enumerate(in_avals):
if isinstance(aval, core.DShapedArray):
for axis_idx, d in enumerate(aval.shape):
if isinstance(d, core.DBIdx) and d.val in implicit_idxs:
implicit_args_from_axes.append((d.val, arg_idx, axis_idx))
assert {i for i, _, _ in implicit_args_from_axes} == implicit_idxs
# Precompute which input values are needed for output types.
inputs_needed_for_out_types = out_type and [
d.val for aval, _ in out_type if type(aval) is core.DShapedArray # type: ignore
for d in aval.shape if type(d) is core.InDBIdx]
def elaborate(explicit_args: Sequence[Any]) -> Tuple[Tuple, Optional[Tuple]]:
if needs_implicit:
# Build full argument list, leaving Nones for implicit arguments.
explicit_args_ = iter(explicit_args)
args = [next(explicit_args_) if ex else None for ex in which_explicit]
assert next(explicit_args_, None) is None
# Populate implicit arguments.
for i, j, k in implicit_args_from_axes:
if args[i] is None:
args[i] = args[j].shape[k] # type: ignore
else:
if args[i] != args[j].shape[k]:
raise Exception("inconsistent argument axis sizes for type")
else:
args = list(explicit_args)
if needs_out_handling:
# Make a list of inputs needed by output types, leaving unneeded as None.
out_type_env = [None] * len(args)
for i in inputs_needed_for_out_types or []:
out_type_env[i] = args[i]
else:
out_type_env = None # type: ignore
return tuple(args), out_type_env and tuple(out_type_env) # type: ignore
return elaborate
def _result_handler(backend: Backend,
sticky_device: Optional[Device],
out_type: pe.OutputType,
) -> Callable:
out_avals, kept_outputs = util.unzip2(out_type)
handlers = map(partial(aval_to_result_handler, sticky_device), out_avals)
dyn_outs = any(type(aval) is core.DShapedArray and
any(type(d) in (core.InDBIdx, core.OutDBIdx) for d in aval.shape)
for aval in out_avals)
if not dyn_outs:
return SimpleResultHandler(handlers)
assert config.jax_dynamic_shapes
def result_handler(input_env, lists_of_bufs):
results = []
for handler, bufs in unsafe_zip(handlers, lists_of_bufs):
results.append(handler((input_env, results), *bufs))
return [r for r, keep in unsafe_zip(results, kept_outputs) if keep]
return result_handler
class SimpleResultHandler:
handlers: Sequence[ResultHandler]
def __init__(self, handlers): self.handlers = handlers
def __iter__(self): return iter(self.handlers)
def __len__(self): return len(self.handlers)
def __call__(self, env, lists_of_bufs):
return tuple(h(env, *bs) for h, bs in zip(self.handlers, lists_of_bufs))
def maybe_create_array_from_da(buf, aval, device):
return array.ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
committed=(device is not None), _skip_checks=True)
if MYPY:
ResultHandler = Any
else:
class ResultHandler(Protocol):
def __call__(self, env: Optional[Sequence[Any]], *args: xc.Buffer) -> Any:
"""Boxes raw buffers into their user-facing representation."""
def aval_to_result_handler(sticky_device: Optional[Device],
aval: core.AbstractValue) -> ResultHandler:
try:
return result_handlers[type(aval)](sticky_device, aval)
except KeyError as err:
raise TypeError(f"No result handler for type: {type(aval)}") from err
def array_result_handler(sticky_device: Optional[Device],
aval: core.ShapedArray):
if not core.is_opaque_dtype(aval.dtype) and aval.dtype == dtypes.float0:
return lambda _, __: np.zeros(aval.shape, dtypes.float0)
aval = core.raise_to_shaped(aval)
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.result_handler(sticky_device, aval)
handler = lambda _, b: maybe_create_array_from_da(b, aval, sticky_device)
handler.args = aval, sticky_device # for C++ dispatch path in api.py
return handler
def dynamic_array_result_handler(sticky_device: Optional[Device],
aval: core.DShapedArray):
if not core.is_opaque_dtype(aval.dtype) and aval.dtype == dtypes.float0:
return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore
else:
return partial(_dynamic_array_result_handler, sticky_device, aval)
def _dynamic_array_result_handler(sticky_device, aval, env, buf):
in_env, out_env = env or (None, None)
shape = [in_env[d.val] if type(d) is core.InDBIdx else
out_env[d.val] if type(d) is core.OutDBIdx else d
for d in aval.shape]
if all(type(d) is int for d in shape) and type(aval.dtype) is not core.bint:
aval = core.ShapedArray(tuple(shape), aval.dtype)
return maybe_create_array_from_da(buf, aval, sticky_device)
else:
pad_shape = [d.dtype.bound if _is_bint_axis_size(d) else d for d in shape]
buf_dtype = (aval.dtype if not core.is_opaque_dtype(aval.dtype) else
aval.dtype._rules.physical_avals(aval)[0])
buf_aval = core.ShapedArray(tuple(pad_shape), buf_dtype, aval.weak_type)
data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
return core.DArray(aval.update(shape=tuple(shape)), data)
result_handlers: Dict[
Type[core.AbstractValue],
Callable[[Optional[Device], Any], ResultHandler]] = {}
result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
def needs_check_special():
return config.jax_debug_infs or config.jax_debug_nans

View File

@ -13,12 +13,12 @@
# limitations under the License.
from jax import numpy as jnp
from jax._src import core
from jax._src import device_array
from jax._src import dispatch
from jax._src import array
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
SUPPORTED_DTYPES = frozenset({
jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16,
@ -40,16 +40,24 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
undefined behavior if the DLPack consumer writes to a buffer that JAX
owns.
"""
if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, got {}"
.format(type(x)))
if isinstance(x, array.ArrayImpl):
assert len(x._arrays) == 1
buf = x._arrays[0]
if xla_extension_version >= 140:
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
f"got {type(x)}")
assert len(x.devices()) == 1
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), take_ownership=take_ownership) # type: ignore
else:
buf = x.device_buffer
return xla_client._xla.buffer_to_dlpack_managed_tensor(
buf, take_ownership=take_ownership)
if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, "
f"got {type(x)}")
if isinstance(x, array.ArrayImpl):
assert len(x._arrays) == 1
buf = x._arrays[0]
else:
buf = x.device_buffer
return xla_client._xla.buffer_to_dlpack_managed_tensor(
buf, take_ownership=take_ownership)
def from_dlpack(dlpack):
"""Returns a ``DeviceArray`` representation of a DLPack tensor.
@ -72,13 +80,14 @@ def from_dlpack(dlpack):
except RuntimeError:
gpu_backend = None
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)
if isinstance(buf, array.ArrayImpl):
aval = buf.aval
if xla_extension_version >= 140:
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend))
else:
xla_shape = buf.xla_shape()
assert not xla_shape.is_tuple()
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
return jnp.asarray( # asarray ensures dtype canonicalization
dispatch.maybe_create_array_from_da(buf, aval, buf.device()))
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)
if isinstance(buf, array.ArrayImpl):
return jnp.asarray(buf) # asarray ensures dtype canonicalization
else:
return jnp.asarray(array._single_device_array_from_buf(
buf, committed=buf.device() is not None))

View File

@ -24,8 +24,6 @@ from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import core
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -112,12 +110,6 @@ class AbstractSparseArray(core.ShapedArray):
class ConcreteSparseArray(AbstractSparseArray):
pass
def sparse_array_result_handler(device, aval):
def build_sparse_array(_, data_buf, indices_buf):
data = device_array.make_device_array(aval.data_aval, device, data_buf)
indices = device_array.make_device_array(aval.indices_aval, device, indices_buf)
return SparseArray(aval, data, indices)
return build_sparse_array
def sparse_array_shape_handler(a):
return (
@ -130,8 +122,6 @@ core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
dispatch.result_handlers[AbstractSparseArray] = sparse_array_result_handler
dispatch.num_buffers_handlers[AbstractSparseArray] = lambda _: 2
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
def sparse_array_mlir_type_handler(a):
@ -267,8 +257,6 @@ core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty()
core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
dispatch.result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
dispatch.num_buffers_handlers[AbstractEmpty] = lambda _: 0
xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()