mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 04:46:06 +00:00)

Remove dispatch.result_handlers since they are not used.

PiperOrigin-RevId: 517456171

commit 6d0189e810 (parent 706549a270)
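For context before the diff: `dispatch.result_handlers` and the related registries implemented type-keyed dispatch from abstract values (avals) to handler callables, and third-party avals extended JAX by writing into them (see the test-file hunks at the bottom). A condensed, runnable sketch of that pattern, with a hypothetical `MyAval` standing in for a real aval class:

# Standalone sketch of the type-keyed registry pattern this commit deletes
# from dispatch.py; MyAval is a hypothetical stand-in, not JAX API.
from typing import Any, Callable, Dict, Type

result_handlers: Dict[Type, Callable[..., Any]] = {}

class MyAval:
  pass

# A registered entry maps (device, aval) to a closure that boxes raw buffers.
result_handlers[MyAval] = lambda device, aval: (lambda env, *bufs: bufs)

def aval_to_result_handler(device, aval):
  try:
    return result_handlers[type(aval)](device, aval)  # dispatch on aval type
  except KeyError as err:
    raise TypeError(f"No result handler for type: {type(aval)}") from err

handler = aval_to_result_handler(None, MyAval())  # looks up by type(aval)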
jax/_src/dispatch.py

@@ -20,9 +20,8 @@ import contextlib
 from functools import partial
 import itertools
 import time
-from typing import (
-    Any, Callable, Dict, Iterator, Optional, Protocol,
-    Sequence, Set, Tuple, List, Type, Union, NamedTuple)
+from typing import (Any, Callable, Dict, Iterator, Optional,
+                    Sequence, Set, Tuple, List, Union, NamedTuple)
 import logging
 import os
 import re
@@ -416,180 +415,6 @@ def initial_style_primitive_replicas(params):
   return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)
 
 
-# Argument and result handlers
-
-num_buffers_handlers: Dict[Type[core.AbstractValue],
-                           Callable[[core.AbstractValue], int]] = {}
-
-def aval_to_num_buffers(aval: core.AbstractValue) -> int:
-  """Returns the number of buffers in the runtime representation of `aval`.
-
-  In general this may differ from the number of buffers in the compiler-IR
-  representation of the same value.
-  """
-  try:
-    return num_buffers_handlers[type(aval)](aval)
-  except KeyError as err:
-    raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err
-
-num_buffers_handlers[core.AbstractToken] = lambda _: 1
-num_buffers_handlers[core.ShapedArray] = lambda _: 1
-num_buffers_handlers[core.DShapedArray] = lambda _: 1
-num_buffers_handlers[core.ConcreteArray] = lambda _: 1
-
-
-def _input_handler(backend: Backend,
-                   in_type: Optional[pe.InputType],
-                   out_type: Optional[pe.OutputType],
-                   ) -> Optional[Callable]:
-  if in_type is None:
-    assert out_type is None
-    return None
-  in_avals, which_explicit = util.unzip2(in_type)
-  # Check whether we actually need an input_handler.
-  needs_implicit = which_explicit and not all(which_explicit)
-  needs_out_handling = any(type(d) is core.InDBIdx for a, _ in out_type or []
-                           if type(a) is core.DShapedArray for d in a.shape)
-
-  if not needs_implicit and not needs_out_handling:
-    return None
-  assert config.jax_dynamic_shapes
-
-  # Precompute how to grab implicit inputs from explicit inputs' axis sizes.
-  which_explicit = which_explicit or (True,) * len(in_avals)
-  implicit_idxs = {i for i, ex in enumerate(which_explicit) if not ex}
-  implicit_args_from_axes: List[Tuple[int, int, int]] = []
-  for arg_idx, aval in enumerate(in_avals):
-    if isinstance(aval, core.DShapedArray):
-      for axis_idx, d in enumerate(aval.shape):
-        if isinstance(d, core.DBIdx) and d.val in implicit_idxs:
-          implicit_args_from_axes.append((d.val, arg_idx, axis_idx))
-  assert {i for i, _, _ in implicit_args_from_axes} == implicit_idxs
-
-  # Precompute which input values are needed for output types.
-  inputs_needed_for_out_types = out_type and [
-      d.val for aval, _ in out_type if type(aval) is core.DShapedArray  # type: ignore
-      for d in aval.shape if type(d) is core.InDBIdx]
-
-  def elaborate(explicit_args: Sequence[Any]) -> Tuple[Tuple, Optional[Tuple]]:
-    if needs_implicit:
-      # Build full argument list, leaving Nones for implicit arguments.
-      explicit_args_ = iter(explicit_args)
-      args = [next(explicit_args_) if ex else None for ex in which_explicit]
-      assert next(explicit_args_, None) is None
-      # Populate implicit arguments.
-      for i, j, k in implicit_args_from_axes:
-        if args[i] is None:
-          args[i] = args[j].shape[k]  # type: ignore
-        else:
-          if args[i] != args[j].shape[k]:
-            raise Exception("inconsistent argument axis sizes for type")
-    else:
-      args = list(explicit_args)
-
-    if needs_out_handling:
-      # Make a list of inputs needed by output types, leaving unneeded as None.
-      out_type_env = [None] * len(args)
-      for i in inputs_needed_for_out_types or []:
-        out_type_env[i] = args[i]
-    else:
-      out_type_env = None  # type: ignore
-
-    return tuple(args), out_type_env and tuple(out_type_env)  # type: ignore
-  return elaborate
-
-def _result_handler(backend: Backend,
-                    sticky_device: Optional[Device],
-                    out_type: pe.OutputType,
-                    ) -> Callable:
-  out_avals, kept_outputs = util.unzip2(out_type)
-  handlers = map(partial(aval_to_result_handler, sticky_device), out_avals)
-  dyn_outs = any(type(aval) is core.DShapedArray and
-                 any(type(d) in (core.InDBIdx, core.OutDBIdx) for d in aval.shape)
-                 for aval in out_avals)
-  if not dyn_outs:
-    return SimpleResultHandler(handlers)
-  assert config.jax_dynamic_shapes
-
-  def result_handler(input_env, lists_of_bufs):
-    results = []
-    for handler, bufs in unsafe_zip(handlers, lists_of_bufs):
-      results.append(handler((input_env, results), *bufs))
-    return [r for r, keep in unsafe_zip(results, kept_outputs) if keep]
-  return result_handler
-
-class SimpleResultHandler:
-  handlers: Sequence[ResultHandler]
-  def __init__(self, handlers): self.handlers = handlers
-  def __iter__(self): return iter(self.handlers)
-  def __len__(self): return len(self.handlers)
-  def __call__(self, env, lists_of_bufs):
-    return tuple(h(env, *bs) for h, bs in zip(self.handlers, lists_of_bufs))
-
-
-def maybe_create_array_from_da(buf, aval, device):
-  return array.ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
-                         committed=(device is not None), _skip_checks=True)
-
-
-if MYPY:
-  ResultHandler = Any
-else:
-  class ResultHandler(Protocol):
-    def __call__(self, env: Optional[Sequence[Any]], *args: xc.Buffer) -> Any:
-      """Boxes raw buffers into their user-facing representation."""
-
-def aval_to_result_handler(sticky_device: Optional[Device],
-                           aval: core.AbstractValue) -> ResultHandler:
-  try:
-    return result_handlers[type(aval)](sticky_device, aval)
-  except KeyError as err:
-    raise TypeError(f"No result handler for type: {type(aval)}") from err
-
-def array_result_handler(sticky_device: Optional[Device],
-                         aval: core.ShapedArray):
-  if not core.is_opaque_dtype(aval.dtype) and aval.dtype == dtypes.float0:
-    return lambda _, __: np.zeros(aval.shape, dtypes.float0)
-  aval = core.raise_to_shaped(aval)
-  if core.is_opaque_dtype(aval.dtype):
-    return aval.dtype._rules.result_handler(sticky_device, aval)
-  handler = lambda _, b: maybe_create_array_from_da(b, aval, sticky_device)
-  handler.args = aval, sticky_device  # for C++ dispatch path in api.py
-  return handler
-
-def dynamic_array_result_handler(sticky_device: Optional[Device],
-                                 aval: core.DShapedArray):
-  if not core.is_opaque_dtype(aval.dtype) and aval.dtype == dtypes.float0:
-    return lambda _: np.zeros(aval.shape, dtypes.float0)  # type: ignore
-  else:
-    return partial(_dynamic_array_result_handler, sticky_device, aval)
-
-def _dynamic_array_result_handler(sticky_device, aval, env, buf):
-  in_env, out_env = env or (None, None)
-  shape = [in_env[d.val] if type(d) is core.InDBIdx else
-           out_env[d.val] if type(d) is core.OutDBIdx else d
-           for d in aval.shape]
-  if all(type(d) is int for d in shape) and type(aval.dtype) is not core.bint:
-    aval = core.ShapedArray(tuple(shape), aval.dtype)
-    return maybe_create_array_from_da(buf, aval, sticky_device)
-  else:
-    pad_shape = [d.dtype.bound if _is_bint_axis_size(d) else d for d in shape]
-    buf_dtype = (aval.dtype if not core.is_opaque_dtype(aval.dtype) else
-                 aval.dtype._rules.physical_avals(aval)[0])
-    buf_aval = core.ShapedArray(tuple(pad_shape), buf_dtype, aval.weak_type)
-    data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
-    return core.DArray(aval.update(shape=tuple(shape)), data)
-
-
-result_handlers: Dict[
-    Type[core.AbstractValue],
-    Callable[[Optional[Device], Any], ResultHandler]] = {}
-result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
-result_handlers[core.ShapedArray] = array_result_handler
-result_handlers[core.DShapedArray] = dynamic_array_result_handler
-result_handlers[core.ConcreteArray] = array_result_handler
-
-
 def needs_check_special():
   return config.jax_debug_infs or config.jax_debug_nans
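The deleted `_dynamic_array_result_handler` resolves a result's symbolic dimensions against environments built from the inputs (`InDBIdx`) and earlier outputs (`OutDBIdx`). A standalone sketch of just that resolution step, with toy stand-ins for the `jax.core` index types:

# Standalone sketch of the shape-resolution step in the deleted
# _dynamic_array_result_handler; these dataclasses are toy stand-ins
# for core.InDBIdx / core.OutDBIdx, not JAX API.
from dataclasses import dataclass

@dataclass
class InDBIdx:   # dimension taken from slot `val` of the input environment
  val: int

@dataclass
class OutDBIdx:  # dimension taken from slot `val` of the output environment
  val: int

def resolve_shape(shape, in_env, out_env):
  # Mirror of the list comprehension in the removed handler: symbolic
  # dimensions are looked up; concrete ints pass through unchanged.
  return [in_env[d.val] if type(d) is InDBIdx else
          out_env[d.val] if type(d) is OutDBIdx else d
          for d in shape]

assert resolve_shape([InDBIdx(0), 4, OutDBIdx(1)],
                     in_env=[3], out_env=[None, 5]) == [3, 4, 5]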
jax/_src/dlpack.py

@@ -13,12 +13,12 @@
 # limitations under the License.
 
 from jax import numpy as jnp
 from jax._src import core
 from jax._src import device_array
-from jax._src import dispatch
 from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
+from jax._src.lib import xla_extension_version
 
 SUPPORTED_DTYPES = frozenset({
     jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16,
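`xla_extension_version` is jaxlib's monotonically increasing API version; importing it lets the pure-Python side branch on what the installed C++ extension supports, as the hunks below do. A minimal sketch of the gate:

# Sketch of the jaxlib version gate used in this file; the integer bumps
# whenever the xla_extension C++ API changes, so newer JAX Python code can
# stay compatible with an older installed jaxlib.
from jax._src.lib import xla_extension_version

if xla_extension_version >= 140:
  use_new_path = True   # jax.Array-only DLPack entry points are available
else:
  use_new_path = False  # fall back to the older per-buffer code path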
@@ -40,16 +40,24 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
   undefined behavior if the DLPack consumer writes to a buffer that JAX
   owns.
   """
-  if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
-    raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, got {}"
-                    .format(type(x)))
-  if isinstance(x, array.ArrayImpl):
-    assert len(x._arrays) == 1
-    buf = x._arrays[0]
+  if xla_extension_version >= 140:
+    if not isinstance(x, array.ArrayImpl):
+      raise TypeError("Argument to to_dlpack must be a jax.Array, "
+                      f"got {type(x)}")
+    assert len(x.devices()) == 1
+    return xla_client._xla.buffer_to_dlpack_managed_tensor(
+        x.addressable_data(0), take_ownership=take_ownership)  # type: ignore
   else:
-    buf = x.device_buffer
-  return xla_client._xla.buffer_to_dlpack_managed_tensor(
-      buf, take_ownership=take_ownership)
+    if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
+      raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, "
+                      f"got {type(x)}")
+    if isinstance(x, array.ArrayImpl):
+      assert len(x._arrays) == 1
+      buf = x._arrays[0]
+    else:
+      buf = x.device_buffer
+    return xla_client._xla.buffer_to_dlpack_managed_tensor(
+        buf, take_ownership=take_ownership)
 
 
 def from_dlpack(dlpack):
   """Returns a ``DeviceArray`` representation of a DLPack tensor.
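A round trip through the two entry points touched here, as the public API reads after this change; a usage sketch, assuming a jaxlib of this commit's vintage:

# Usage sketch: JAX array -> DLPack capsule -> JAX array. With the default
# take_ownership=False, JAX keeps ownership and `x` stays valid; the capsule
# is a zero-copy view that from_dlpack consumes.
import jax.numpy as jnp
from jax import dlpack as jax_dlpack

x = jnp.arange(4.0)
capsule = jax_dlpack.to_dlpack(x)    # PyCapsule wrapping the device buffer
y = jax_dlpack.from_dlpack(capsule)  # back to a jax.Array, zero-copy
assert (y == x).all()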
@@ -72,13 +80,14 @@ def from_dlpack(dlpack):
   except RuntimeError:
     gpu_backend = None
 
-  buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
-      dlpack, cpu_backend, gpu_backend)
-  if isinstance(buf, array.ArrayImpl):
-    aval = buf.aval
+  if xla_extension_version >= 140:
+    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+        dlpack, cpu_backend, gpu_backend))
   else:
-    xla_shape = buf.xla_shape()
-    assert not xla_shape.is_tuple()
-    aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
-  return jnp.asarray(  # asarray ensures dtype canonicalization
-      dispatch.maybe_create_array_from_da(buf, aval, buf.device()))
+    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
+        dlpack, cpu_backend, gpu_backend)
+    if isinstance(buf, array.ArrayImpl):
+      return jnp.asarray(buf)  # asarray ensures dtype canonicalization
+    else:
+      return jnp.asarray(array._single_device_array_from_buf(
+          buf, committed=buf.device() is not None))
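The more common direction is consuming a capsule produced by another framework; a sketch assuming PyTorch is installed (PyTorch is not part of this diff):

# Interop sketch: a PyTorch tensor crosses into JAX without a copy (on CPU,
# or between matching CUDA devices). torch.utils.dlpack produces the capsule
# that jax.dlpack.from_dlpack consumes.
import torch
from torch.utils import dlpack as torch_dlpack
from jax import dlpack as jax_dlpack

t = torch.arange(4, dtype=torch.float32)
x = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(t))  # zero-copy handoff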
tests/custom_object_test.py

@@ -24,8 +24,6 @@ from jax.interpreters import mlir
 from jax.interpreters import xla
 
 from jax._src import core
-from jax._src import device_array
-from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
@@ -112,12 +110,6 @@ class AbstractSparseArray(core.ShapedArray):
 class ConcreteSparseArray(AbstractSparseArray):
   pass
 
-def sparse_array_result_handler(device, aval):
-  def build_sparse_array(_, data_buf, indices_buf):
-    data = device_array.make_device_array(aval.data_aval, device, data_buf)
-    indices = device_array.make_device_array(aval.indices_aval, device, indices_buf)
-    return SparseArray(aval, data, indices)
-  return build_sparse_array
 
 def sparse_array_shape_handler(a):
   return (
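The deleted test helper shows the callable shape a result handler had: a closure over `(device, aval)` that boxes raw per-output buffers into one user-facing value. A minimal runnable stand-in:

# Minimal stand-in for the removed sparse_array_result_handler; SparseArray
# here is a toy NamedTuple, not the test's real class. A result handler is a
# closure that maps raw device buffers to one user-facing object.
from typing import Any, NamedTuple

class SparseArray(NamedTuple):
  aval: Any
  data: Any
  indices: Any

def sparse_array_result_handler(device, aval):
  def build_sparse_array(_, data_buf, indices_buf):
    # The leading env argument is unused (_); the two buffers become fields.
    return SparseArray(aval, data_buf, indices_buf)
  return build_sparse_array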
@@ -130,8 +122,6 @@ core.pytype_aval_mappings[SparseArray] = lambda x: x.aval
 core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
 xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
 xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
-dispatch.result_handlers[AbstractSparseArray] = sparse_array_result_handler
-dispatch.num_buffers_handlers[AbstractSparseArray] = lambda _: 2
 xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
 
 def sparse_array_mlir_type_handler(a):
@@ -267,8 +257,6 @@ core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty()
 core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
 xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
 xla.canonicalize_dtype_handlers[Empty] = lambda x: x
-dispatch.result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
-dispatch.num_buffers_handlers[AbstractEmpty] = lambda _: 0
 xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
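Aside: the registries deleted in this commit are hand-rolled type-keyed single dispatch. For out-of-tree code that needs the same shape of mechanism, Python's `functools.singledispatch` expresses it directly (illustrative names, not JAX API):

# General-Python equivalent of the removed registry pattern, shown only as
# an idiom; MySparseAval is a hypothetical stand-in for AbstractSparseArray.
from functools import singledispatch

@singledispatch
def aval_to_result_handler(aval):
  raise TypeError(f"No result handler for type: {type(aval)}")

class MySparseAval:
  pass

@aval_to_result_handler.register
def _(aval: MySparseAval):
  # Two buffers per value, mirroring the num_buffers = 2 registration above.
  return lambda data_buf, indices_buf: (data_buf, indices_buf)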