mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge most of the MLIR JIT dispatch logic into the common primitive and JIT computation path.
Change the representation of both units and tokens at the runtime level to be a single buffer with shape pred[0]. While the MLIR lowering is happy to have a non 1:1 mapping between avals and IR values, the XLA lowering is not, so until we remove the XLA lowering it's easiest just to keep the mapping 1:1. PiperOrigin-RevId: 412957231
This commit is contained in:
parent
5084a12034
commit
12512cc96a
@ -18,6 +18,7 @@ from functools import partial
|
||||
import itertools
|
||||
from typing import (
|
||||
Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union)
|
||||
from typing_extensions import Protocol
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
@ -28,6 +29,7 @@ from jax import linear_util as lu
|
||||
import jax.interpreters.ad as ad
|
||||
import jax.interpreters.batching as batching
|
||||
import jax.interpreters.masking as masking
|
||||
import jax.interpreters.mlir as mlir
|
||||
import jax.interpreters.xla as xla
|
||||
import jax.interpreters.partial_eval as pe
|
||||
from jax.errors import UnexpectedTracerError
|
||||
@ -42,6 +44,7 @@ from jax._src import traceback_util
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
MYPY = False # Are we currently type checking with mypy?
|
||||
|
||||
xe = xc._xla
|
||||
|
||||
@ -71,9 +74,6 @@ def arg_spec(x: Any) -> ArgSpec:
|
||||
|
||||
def apply_primitive(prim, *args, **params):
|
||||
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
|
||||
if config.jax_enable_mlir:
|
||||
import jax.interpreters.mlir
|
||||
return jax.interpreters.mlir.apply_primitive(prim, *args, **params)
|
||||
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
|
||||
**params)
|
||||
return compiled_fun(*args)
|
||||
@ -123,12 +123,6 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
|
||||
|
||||
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
|
||||
donated_invars, inline):
|
||||
if config.jax_enable_mlir:
|
||||
import jax.interpreters.mlir
|
||||
return jax.interpreters.mlir._xla_call_impl_mlir(
|
||||
fun, *args, device=device, backend=backend, name=name,
|
||||
donated_invars=donated_invars, inline=inline)
|
||||
|
||||
del inline # Only used at tracing time
|
||||
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
|
||||
*unsafe_map(arg_spec, args))
|
||||
@ -220,31 +214,48 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
|
||||
"extra data movement anyway, so maybe you don't want it after all).")
|
||||
|
||||
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
|
||||
# pass long arg lists as tuple for TPU
|
||||
tuple_args = len(abstract_args) > 100
|
||||
axis_env = xla.AxisEnv(nreps, (), ())
|
||||
name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit'))
|
||||
module: Any
|
||||
if config.jax_enable_mlir:
|
||||
# TODO(b/203122001): implement buffer donation.
|
||||
assert not any(donated_invars), donated_invars
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
core.ClosedJaxpr(jaxpr, consts), backend.platform, axis_env, name_stack)
|
||||
else:
|
||||
# XLA HLO lowering path
|
||||
c = xc.XlaBuilder(f"jit_{fun.__name__}")
|
||||
xla_consts = xla._xla_consts(c, consts)
|
||||
xla_args, donated_invars = xla._xla_callable_args(
|
||||
c, abstract_args, tuple_args, donated_invars=donated_invars)
|
||||
platform = backend.platform
|
||||
ctx = xla.TranslationContext(c, platform, axis_env, name_stack)
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
|
||||
# Replace tokens with a dummy array value, because the runtime cannot
|
||||
# handle token arguments.
|
||||
out_aval_lens = [len(xla.aval_to_xla_shapes(a)) for a in out_avals]
|
||||
out_nodes = util.flatten(
|
||||
[[xla._make_token_return_value(c)] if a is core.abstract_token
|
||||
else v
|
||||
for a, v in zip(out_avals, util.unflatten(out_nodes, out_aval_lens))])
|
||||
|
||||
c = xc.XlaBuilder(f"jit_{fun.__name__}")
|
||||
xla_consts = xla._xla_consts(c, consts)
|
||||
xla_args, donated_invars = xla._xla_callable_args(c, abstract_args, tuple_args,
|
||||
donated_invars=donated_invars)
|
||||
platform = backend.platform
|
||||
ctx = xla.TranslationContext(c, platform, xla.AxisEnv(nreps, (), ()),
|
||||
xla.extend_name_stack(xla.wrap_name(name, 'jit')))
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
|
||||
# There is a non-zero cost to building an output tuple, particularly on TPU.
|
||||
# Avoid it if the output arity is 1.
|
||||
output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
|
||||
if platform in ("gpu", "tpu"):
|
||||
donated_invars = xla.set_up_aliases(
|
||||
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
|
||||
if any(donated_invars):
|
||||
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
|
||||
unused_donations = [str(c.GetShape(a))
|
||||
for a, d in zip(xla_args, donated_invars) if d]
|
||||
warnings.warn("Some donated buffers were not usable: {}".format(
|
||||
", ".join(unused_donations)))
|
||||
built = c.build(output)
|
||||
# There is a non-zero cost to building an output tuple, particularly on TPU.
|
||||
# Avoid it if the output arity is 1.
|
||||
output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
|
||||
if platform in ("gpu", "tpu"):
|
||||
donated_invars = xla.set_up_aliases(
|
||||
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
|
||||
if any(donated_invars):
|
||||
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
|
||||
unused_donations = [str(c.GetShape(a))
|
||||
for a, d in zip(xla_args, donated_invars) if d]
|
||||
warnings.warn("Some donated buffers were not usable: {}".format(
|
||||
", ".join(unused_donations)))
|
||||
module = c.build(output)
|
||||
return XlaComputation(
|
||||
name, built, False, donated_invars, nreps, device, backend, tuple_args,
|
||||
name, module, False, donated_invars, nreps, device, backend, tuple_args,
|
||||
abstract_args, out_avals, kept_var_idx)
|
||||
|
||||
|
||||
@ -369,27 +380,58 @@ def _xla_callable_device(nreps, backend, device, arg_devices):
|
||||
assert False # Unreachable given the error check in _xla_callable
|
||||
|
||||
|
||||
# Result handlers
|
||||
# Argument and result handlers
|
||||
|
||||
def aval_to_result_handler(device: Optional[Device],
|
||||
aval: core.AbstractValue) -> Callable:
|
||||
num_buffers_handlers: Dict[Type[core.AbstractValue],
|
||||
Callable[[core.AbstractValue], int]] = {}
|
||||
|
||||
def aval_to_num_buffers(aval: core.AbstractValue) -> int:
|
||||
"""Returns the number of buffers in the runtime representation of `aval`.
|
||||
|
||||
In general this may differ from the number of buffers in the compiler-IR
|
||||
representation of the same value.
|
||||
"""
|
||||
try:
|
||||
return xla_result_handlers[type(aval)](device, aval)
|
||||
return num_buffers_handlers[type(aval)](aval)
|
||||
except KeyError as err:
|
||||
raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err
|
||||
raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err
|
||||
|
||||
def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
|
||||
# TODO(phawkins): use zero buffers to represent a unit.
|
||||
num_buffers_handlers[core.AbstractUnit] = lambda _: 1
|
||||
num_buffers_handlers[core.AbstractToken] = lambda _: 1
|
||||
num_buffers_handlers[core.ShapedArray] = lambda _: 1
|
||||
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
|
||||
|
||||
|
||||
if MYPY:
|
||||
ResultHandler = Any
|
||||
else:
|
||||
class ResultHandler(Protocol):
|
||||
def __call__(self, *args: xla.Buffer) -> Any:
|
||||
"""Boxes raw buffers into their user-facing representation."""
|
||||
|
||||
def aval_to_result_handler(sticky_device: Optional[Device],
|
||||
aval: core.AbstractValue) -> ResultHandler:
|
||||
try:
|
||||
return result_handlers[type(aval)](sticky_device, aval)
|
||||
except KeyError as err:
|
||||
raise TypeError(f"No result handler for type: {type(aval)}") from err
|
||||
|
||||
def array_result_handler(sticky_device: Optional[Device],
|
||||
aval: core.ShapedArray):
|
||||
if aval.dtype is dtypes.float0:
|
||||
return lambda _: np.zeros(aval.shape, dtypes.float0)
|
||||
return partial(device_array.make_device_array, core.raise_to_shaped(aval), device)
|
||||
return partial(device_array.make_device_array, core.raise_to_shaped(aval),
|
||||
sticky_device)
|
||||
|
||||
|
||||
xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
|
||||
core.AbstractUnit: lambda _, __: lambda _: core.unit,
|
||||
core.ShapedArray: array_result_handler,
|
||||
core.ConcreteArray: array_result_handler,
|
||||
}
|
||||
xla_result_handlers[core.AbstractToken] = lambda _, __: lambda _: core.token
|
||||
result_handlers: Dict[
|
||||
Type[core.AbstractValue],
|
||||
Callable[[Optional[Device], Any], ResultHandler]] = {}
|
||||
result_handlers[core.AbstractUnit] = lambda _, __: lambda _: core.unit
|
||||
result_handlers[core.AbstractToken] = lambda _, __: lambda _: core.token
|
||||
result_handlers[core.ShapedArray] = array_result_handler
|
||||
result_handlers[core.ConcreteArray] = array_result_handler
|
||||
|
||||
|
||||
def needs_check_special():
|
||||
@ -410,32 +452,26 @@ def _check_special(name, xla_shape, buf):
|
||||
|
||||
|
||||
def _execute_compiled(name: str, compiled: XlaExecutable,
|
||||
output_buffer_counts: Optional[Sequence[int]], handlers,
|
||||
kept_var_idx, *args):
|
||||
output_buffer_counts: Optional[Sequence[int]],
|
||||
result_handlers, kept_var_idx, *args):
|
||||
device, = compiled.local_devices()
|
||||
input_bufs = list(
|
||||
itertools.chain.from_iterable(
|
||||
device_put(x, device)
|
||||
for i, x in enumerate(args)
|
||||
if x is not core.token and i in kept_var_idx))
|
||||
input_bufs = util.flatten(
|
||||
device_put(x, device) for i, x in enumerate(args) if i in kept_var_idx)
|
||||
out_bufs = compiled.execute(input_bufs)
|
||||
check_special(name, out_bufs)
|
||||
if output_buffer_counts is None:
|
||||
return (handlers[0](*out_bufs),)
|
||||
return (result_handlers[0](*out_bufs),)
|
||||
return tuple(
|
||||
handler(*bs) for handler, bs in
|
||||
unsafe_zip(handlers, util.unflatten(out_bufs, output_buffer_counts)))
|
||||
unsafe_zip(result_handlers, util.unflatten(out_bufs, output_buffer_counts)))
|
||||
|
||||
|
||||
def _execute_replicated(name: str, compiled: XlaExecutable,
|
||||
output_buffer_counts: Optional[Sequence[int]], handlers,
|
||||
kept_var_idx, *args):
|
||||
output_buffer_counts: Optional[Sequence[int]],
|
||||
result_handlers, kept_var_idx, *args):
|
||||
input_bufs = [
|
||||
list(
|
||||
itertools.chain.from_iterable(
|
||||
device_put(x, device)
|
||||
for i, x in enumerate(args)
|
||||
if x is not core.token and i in kept_var_idx))
|
||||
util.flatten(
|
||||
device_put(x, device) for i, x in enumerate(args) if i in kept_var_idx)
|
||||
for device in compiled.local_devices()
|
||||
]
|
||||
out_bufs = [
|
||||
@ -444,10 +480,10 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
|
||||
]
|
||||
check_special(name, out_bufs)
|
||||
if output_buffer_counts is None:
|
||||
return (handlers[0](*out_bufs),)
|
||||
return (result_handlers[0](*out_bufs),)
|
||||
return tuple(
|
||||
handler(*bs) for handler, bs in
|
||||
unsafe_zip(handlers, util.unflatten(out_bufs, output_buffer_counts)))
|
||||
unsafe_zip(result_handlers, util.unflatten(out_bufs, output_buffer_counts)))
|
||||
|
||||
|
||||
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
|
||||
@ -532,21 +568,23 @@ class XlaCompiledComputation:
|
||||
name: str,
|
||||
xla_computation,
|
||||
nreps: int,
|
||||
device,
|
||||
device: Optional[Device],
|
||||
backend,
|
||||
tuple_args: bool,
|
||||
in_avals,
|
||||
out_avals,
|
||||
kept_var_idx) -> 'XlaCompiledComputation':
|
||||
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
|
||||
sticky_device = device
|
||||
result_handlers = map(partial(aval_to_result_handler, sticky_device),
|
||||
out_avals)
|
||||
options = xb.get_compile_options(
|
||||
num_replicas=nreps,
|
||||
num_partitions=1,
|
||||
device_assignment=(device.id,) if device else None)
|
||||
device_assignment=(sticky_device.id,) if sticky_device else None)
|
||||
options.parameter_is_tupled_arguments = tuple_args
|
||||
compiled = compile_or_get_cached(backend, xla_computation, options)
|
||||
buffer_counts = (None if len(out_avals) == 1 else
|
||||
[len(xla.aval_to_xla_shapes(aval)) for aval in out_avals])
|
||||
[aval_to_num_buffers(aval) for aval in out_avals])
|
||||
execute = _execute_compiled if nreps == 1 else _execute_replicated
|
||||
unsafe_call = partial(execute, name, compiled, buffer_counts,
|
||||
result_handlers, kept_var_idx)
|
||||
@ -610,17 +648,21 @@ def _device_put_scalar(x, device):
|
||||
|
||||
def _device_put_unit(_, device):
|
||||
backend = xb.get_device_backend(device)
|
||||
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
|
||||
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype(np.bool_)),
|
||||
device),)
|
||||
|
||||
def _device_put_token(_, device):
|
||||
backend = xb.get_device_backend(device)
|
||||
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype(np.bool_)),
|
||||
device),)
|
||||
|
||||
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
||||
|
||||
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
|
||||
core.Unit: _device_put_unit
|
||||
}
|
||||
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {}
|
||||
device_put_handlers.update((t, _device_put_array) for t in array_types)
|
||||
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
|
||||
device_put_handlers[core.Token] = lambda x, _: (x,)
|
||||
device_put_handlers[core.Unit] = _device_put_unit
|
||||
device_put_handlers[core.Token] = _device_put_token
|
||||
|
||||
|
||||
def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
|
||||
@ -670,9 +712,6 @@ ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
|
||||
masking.defvectorized(device_put_p)
|
||||
batching.defvectorized(device_put_p)
|
||||
|
||||
# TODO(phawkins): remove mlir->dispatch dependency and move this to the top.
|
||||
import jax.interpreters.mlir as mlir
|
||||
|
||||
def _device_put_lowering(ctx, avals_in, avals_out, x, *, device):
|
||||
return [x]
|
||||
|
||||
|
@ -54,7 +54,7 @@ complex_: type = np.complex128
|
||||
_default_types = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
|
||||
|
||||
# Trivial vectorspace datatype needed for tangent values of int/bool primals
|
||||
float0 = np.dtype([('float0', np.void, 0)])
|
||||
float0: np.dtype = np.dtype([('float0', np.void, 0)])
|
||||
|
||||
_dtype_to_32bit_dtype = {
|
||||
np.dtype('int64'): np.dtype('int32'),
|
||||
|
@ -19,30 +19,24 @@ import collections
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import io
|
||||
import itertools
|
||||
import typing
|
||||
from typing import (Any, Callable, Dict, Optional, Sequence, Type, Union, Tuple)
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence, Type, Union,
|
||||
Tuple)
|
||||
from typing_extensions import Protocol
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src.config import config
|
||||
from jax._src import ad_util
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import builtin
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib.mlir.dialects import std
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src import source_info_util
|
||||
import jax._src.util as util
|
||||
from jax.errors import UnexpectedTracerError
|
||||
import jax.interpreters.ad as ad
|
||||
import jax.interpreters.partial_eval as pe
|
||||
import jax.interpreters.xla as xla
|
||||
@ -315,10 +309,33 @@ def flatten_lowering_ir_args(
|
||||
) -> Sequence[Sequence[ir.Value]]:
|
||||
return util.flatten(map(_wrap_singleton_ir_values, xs))
|
||||
|
||||
def lower_jaxpr_to_module(jaxpr: core.ClosedJaxpr, platform: str,
|
||||
axis_env: xla.AxisEnv, name_stack: str) -> str:
|
||||
"""Lowers a top-level jaxpr to an MHLO module.
|
||||
|
||||
Handles the quirks of the argument/return value passing conventions of the
|
||||
runtime."""
|
||||
ctx = LoweringContext(platform, axis_env, name_stack)
|
||||
if platform == "iree":
|
||||
ctx = ctx.replace(tuple_results=False)
|
||||
with ctx.context, ir.Location.unknown(ctx.context):
|
||||
# TODO(phawkins): represent units with zero buffers at the runtime level.
|
||||
lower_jaxpr_to_fun(
|
||||
ctx, "main", jaxpr, public=True, replace_units_with_dummy=True,
|
||||
replace_tokens_with_dummy=True)
|
||||
|
||||
ctx.module.operation.verify()
|
||||
output = io.StringIO()
|
||||
ctx.module.operation.print(file=output, #enable_debug_info=True,
|
||||
print_generic_op_form=False)
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
|
||||
jaxpr: core.ClosedJaxpr, *,
|
||||
public: bool = False,
|
||||
prune_tokens: bool = False) -> str:
|
||||
replace_units_with_dummy: bool = False,
|
||||
replace_tokens_with_dummy: bool = False) -> str:
|
||||
"""Lowers jaxpr and its callees to an IR function.
|
||||
|
||||
Assumes that an MLIR context, location, and insertion point are set.
|
||||
@ -329,22 +346,21 @@ def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
|
||||
so it is ok to use the same name multiple times.
|
||||
jaxpr: the jaxpr to lower.
|
||||
public: if true, the function's visibility is set to "public".
|
||||
prune_tokens: if true, tokens are pruned from the arguments and return
|
||||
values.
|
||||
replace_units_with_dummy: if true, unit arguments/return values are
|
||||
replaced with bool arrays of size [0].
|
||||
replace_tokens_with_dummy: if true, token arguments/return values are
|
||||
replaced with bool arrays of size [0].
|
||||
Returns the name of the function.
|
||||
"""
|
||||
# print(jaxpr.jaxpr)
|
||||
if prune_tokens:
|
||||
pruned_in_avals = [aval for aval in jaxpr.in_avals
|
||||
if aval is not core.abstract_token]
|
||||
pruned_out_avals = [aval for aval in jaxpr.out_avals
|
||||
if aval is not core.abstract_token]
|
||||
else:
|
||||
pruned_in_avals = jaxpr.in_avals
|
||||
pruned_out_avals = jaxpr.out_avals
|
||||
def aval_to_types(aval):
|
||||
if replace_units_with_dummy and aval is core.abstract_unit:
|
||||
aval = core.ShapedArray((), np.dtype(np.bool_))
|
||||
elif replace_tokens_with_dummy and aval is core.abstract_token:
|
||||
aval = core.ShapedArray((), np.dtype(np.bool_))
|
||||
return aval_to_ir_types(aval)
|
||||
|
||||
input_types = map(aval_to_ir_types, pruned_in_avals)
|
||||
output_types = map(aval_to_ir_types, pruned_out_avals)
|
||||
input_types = map(aval_to_types, jaxpr.in_avals)
|
||||
output_types = map(aval_to_types, jaxpr.out_avals)
|
||||
flat_input_types = util.flatten(input_types)
|
||||
flat_output_types = util.flatten(output_types)
|
||||
if ctx.tuple_results:
|
||||
@ -361,27 +377,28 @@ def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
|
||||
with ir.InsertionPoint(entry_block):
|
||||
unflattened_args = util.unflatten(entry_block.arguments,
|
||||
map(len, input_types))
|
||||
# If we pruned tokens out of the parameter list, create a new token and add
|
||||
# it here.
|
||||
if prune_tokens and len(pruned_in_avals) != len(jaxpr.in_avals):
|
||||
token = mhlo.CreateTokenOp(mhlo.TokenType.get()).results
|
||||
arg_iter = iter(unflattened_args)
|
||||
unflattened_args = [
|
||||
token if aval is core.abstract_token else next(arg_iter)
|
||||
for aval in jaxpr.in_avals
|
||||
]
|
||||
done = object()
|
||||
assert next(arg_iter, done) is done
|
||||
|
||||
args: List[List[ir.Value]] = []
|
||||
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
|
||||
if replace_units_with_dummy and aval is core.abstract_unit:
|
||||
args.append([])
|
||||
elif replace_tokens_with_dummy and aval is core.abstract_token:
|
||||
args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
|
||||
else:
|
||||
args.append(arg)
|
||||
callee_name_stack = xla.extend_name_stack(ctx.name_stack,
|
||||
xla.wrap_name(name, 'jit'))
|
||||
out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
|
||||
jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
|
||||
*unflattened_args)
|
||||
if prune_tokens:
|
||||
out_vals = [v for v, aval in zip(out_vals, jaxpr.out_avals)
|
||||
if aval is not core.abstract_token]
|
||||
flat_outputs = util.flatten(out_vals)
|
||||
*args)
|
||||
outs = []
|
||||
for aval, out in zip(jaxpr.out_avals, out_vals):
|
||||
if replace_units_with_dummy and aval is core.abstract_unit:
|
||||
outs.append(ir_constants(np.zeros((), np.bool_)))
|
||||
elif replace_tokens_with_dummy and aval is core.abstract_token:
|
||||
outs.append(ir_constants(np.zeros((), np.bool_)))
|
||||
else:
|
||||
outs.append(out)
|
||||
flat_outputs = util.flatten(outs)
|
||||
if ctx.tuple_results:
|
||||
std.ReturnOp([mhlo.TupleOp(output_tuple_type, flat_outputs).result])
|
||||
else:
|
||||
@ -543,345 +560,6 @@ register_lowering(ad_util.stop_gradient_p,
|
||||
lambda ctx, avals_in, avals_out, x: [x])
|
||||
|
||||
|
||||
# # Computation dispatch
|
||||
|
||||
_aval_to_num_buffers: Dict[Type[core.AbstractValue],
|
||||
Callable[[core.AbstractValue], int]] = {}
|
||||
|
||||
def aval_to_num_buffers(aval: core.AbstractValue) -> int:
|
||||
"""Returns the number of buffers in the runtime representation of `aval`.
|
||||
|
||||
Note: the compile-time representation may have more buffers! This is a small
|
||||
hack to deal with tokens that have no true runtime representation but have an
|
||||
IR type.
|
||||
|
||||
Must match the arity of the result of `aval_to_ir_types`."""
|
||||
try:
|
||||
return _aval_to_num_buffers[type(aval)](aval)
|
||||
except KeyError as err:
|
||||
raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err
|
||||
|
||||
_aval_to_num_buffers[core.AbstractUnit] = lambda _: 0
|
||||
_aval_to_num_buffers[core.AbstractToken] = lambda _: 0
|
||||
_aval_to_num_buffers[core.ShapedArray] = lambda _: 1
|
||||
_aval_to_num_buffers[core.ConcreteArray] = lambda _: 1
|
||||
|
||||
|
||||
class ArgHandler(Protocol):
|
||||
def __call__(self, device: xla.Device, arg: Any) -> Sequence[xla.Buffer]:
|
||||
"""A argument handler unboxes raw buffers into their Python representation."""
|
||||
|
||||
_aval_to_arg_handler: Dict[
|
||||
Type[core.AbstractValue], Callable[[Any], ArgHandler]] = {}
|
||||
|
||||
def aval_to_arg_handler(aval: core.AbstractValue) -> ArgHandler:
|
||||
try:
|
||||
return _aval_to_arg_handler[type(aval)](aval)
|
||||
except KeyError as err:
|
||||
raise TypeError(f"No arg_handler for type: {type(aval)}") from err
|
||||
|
||||
def _array_arg_handler(aval: core.ShapedArray) -> ArgHandler:
|
||||
return lambda device, val: xla.device_put(val, device)
|
||||
|
||||
_aval_to_arg_handler[core.AbstractUnit] = lambda _: lambda _device, _: ()
|
||||
_aval_to_arg_handler[core.AbstractToken] = lambda _: lambda _device, _: ()
|
||||
_aval_to_arg_handler[core.ShapedArray] = _array_arg_handler
|
||||
_aval_to_arg_handler[core.ConcreteArray] = _array_arg_handler
|
||||
|
||||
|
||||
if not MYPY:
|
||||
class ResultHandler(Protocol):
|
||||
def __call__(self, device: xla.Device, *args: xla.Buffer) -> Any:
|
||||
"""A result handler boxes raw buffers into their Python representation.
|
||||
|
||||
Inverse of ArgHandler."""
|
||||
else:
|
||||
ResultHandler = Any
|
||||
|
||||
_aval_to_result_handler: Dict[
|
||||
Type[core.AbstractValue], Callable[[Any], ResultHandler]] = {}
|
||||
|
||||
def aval_to_result_handler(aval: core.AbstractValue) -> ResultHandler:
|
||||
try:
|
||||
return _aval_to_result_handler[type(aval)](aval)
|
||||
except KeyError as err:
|
||||
raise TypeError(f"No result_handler for type: {type(aval)}") from err
|
||||
|
||||
def _array_result_handler(aval: core.ShapedArray) -> ResultHandler:
|
||||
if aval.dtype is dtypes.float0:
|
||||
return lambda _device, _: np.zeros(aval.shape, dtypes.float0)
|
||||
aval = core.raise_to_shaped(aval)
|
||||
return lambda device, buffer: xla.make_device_array(aval, device, buffer)
|
||||
|
||||
_aval_to_result_handler[core.AbstractUnit] = lambda _: lambda _: core.unit
|
||||
_aval_to_result_handler[core.AbstractToken] = lambda _: lambda _: core.token
|
||||
_aval_to_result_handler[core.ShapedArray] = _array_result_handler
|
||||
_aval_to_result_handler[core.ConcreteArray] = _array_result_handler
|
||||
|
||||
|
||||
def _execute_compiled(name: str, compiled: xla.XlaExecutable,
|
||||
device: xla.Device, buffer_counts,
|
||||
arg_handlers, result_handlers, kept_var_idx, *args):
|
||||
input_bufs = util.flatten(
|
||||
arg_handler(device, x) for arg_handler, x in
|
||||
unsafe_zip(arg_handlers,
|
||||
(x for i, x in enumerate(args) if i in kept_var_idx)))
|
||||
out_bufs = compiled.execute(input_bufs)
|
||||
dispatch.check_special(name, out_bufs)
|
||||
return [handler(device, *bs) for handler, bs in
|
||||
zip(result_handlers, util.unflatten(out_bufs, buffer_counts))]
|
||||
|
||||
|
||||
def _execute_replicated(name: str, compiled: xla.XlaExecutable,
|
||||
device: xla.Device,
|
||||
buffer_counts, arg_handlers, result_handlers,
|
||||
kept_var_idx, *args):
|
||||
input_bufs = [
|
||||
util.flatten(
|
||||
arg_handler(device, x) for arg_handler, x in
|
||||
unsafe_zip(arg_handlers,
|
||||
(x for i, x in enumerate(args) if i in kept_var_idx)))
|
||||
for device in compiled.local_devices()
|
||||
]
|
||||
out_bufs = [
|
||||
buf[0] for buf in compiled.execute_sharded_on_local_devices(
|
||||
list(zip(*input_bufs)))
|
||||
]
|
||||
dispatch.check_special(name, out_bufs)
|
||||
return [handler(device, *bs) for handler, bs in
|
||||
zip(result_handlers, util.unflatten(out_bufs, buffer_counts))]
|
||||
|
||||
|
||||
def _execute_trivial(jaxpr, device: Optional[xla.Device], consts, buffer_counts,
|
||||
result_handlers, kept_var_idx, *args):
|
||||
env = {core.unitvar: core.unit}
|
||||
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
|
||||
map(env.setdefault, jaxpr.invars, pruned_args)
|
||||
map(env.setdefault, jaxpr.constvars, consts)
|
||||
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
|
||||
for v in jaxpr.outvars]
|
||||
return [dispatch.device_put_p.bind(x, device=device) for x in outs]
|
||||
|
||||
|
||||
class XlaCompiledComputation:
|
||||
def __init__(self, xla_executable, unsafe_call):
|
||||
self._xla_executable = xla_executable
|
||||
self.unsafe_call = unsafe_call
|
||||
|
||||
@staticmethod
|
||||
def from_xla_computation(
|
||||
name: str,
|
||||
xla_computation,
|
||||
nreps: int,
|
||||
device,
|
||||
backend,
|
||||
tuple_args: bool,
|
||||
avals_in: Sequence[core.AbstractValue],
|
||||
avals_out: Sequence[core.AbstractValue],
|
||||
kept_var_idx) -> 'XlaCompiledComputation':
|
||||
arg_handlers = map(aval_to_arg_handler, avals_in)
|
||||
result_handlers = map(aval_to_result_handler, avals_out)
|
||||
options = xb.get_compile_options(
|
||||
num_replicas=nreps,
|
||||
num_partitions=1,
|
||||
device_assignment=(device.id,) if device else None)
|
||||
options.parameter_is_tupled_arguments = tuple_args
|
||||
compiled = dispatch.compile_or_get_cached(backend, xla_computation, options)
|
||||
buffer_counts = [aval_to_num_buffers(aval) for aval in avals_out]
|
||||
if nreps == 1:
|
||||
return XlaCompiledComputation(compiled, partial(
|
||||
_execute_compiled, name, compiled, device, buffer_counts,
|
||||
arg_handlers, result_handlers, kept_var_idx))
|
||||
else:
|
||||
return XlaCompiledComputation(compiled, partial(
|
||||
_execute_replicated, name, compiled, device, buffer_counts,
|
||||
arg_handlers, result_handlers, kept_var_idx))
|
||||
|
||||
def is_trivial(self):
|
||||
return self._xla_executable == None
|
||||
|
||||
def xla_executable(self):
|
||||
if self.is_trivial():
|
||||
raise ValueError('A trivial compiled computation has no XLA executable')
|
||||
return self._xla_executable
|
||||
|
||||
@staticmethod
|
||||
def from_trivial_jaxpr(jaxpr, consts, device, avals_in, avals_out,
|
||||
kept_var_idx) -> 'XlaCompiledComputation':
|
||||
result_handlers = map(aval_to_result_handler, avals_out)
|
||||
return XlaCompiledComputation(None, partial(
|
||||
_execute_trivial, jaxpr, device, consts, avals_out,
|
||||
result_handlers, kept_var_idx))
|
||||
|
||||
def call(self, *args):
|
||||
# TODO(apaszke,frostig): Check that args are compatible with input avals!
|
||||
return self.unsafe_call(*args)
|
||||
|
||||
def __call__(self, *args):
|
||||
return self.call(*args)
|
||||
|
||||
|
||||
class XlaComputation:
|
||||
name: str
|
||||
_is_trivial: bool
|
||||
_executable: Optional['XlaCompiledComputation']
|
||||
|
||||
def __init__(self, name: str, hlo, is_trivial: bool, *compile_args):
|
||||
self.name = name
|
||||
self._hlo = hlo
|
||||
self._is_trivial = is_trivial
|
||||
self._executable = None
|
||||
self.compile_args = compile_args
|
||||
|
||||
def is_trivial(self):
|
||||
return self._is_trivial
|
||||
|
||||
def hlo(self):
|
||||
if self.is_trivial():
|
||||
raise ValueError('A trivial computation has no HLO')
|
||||
return self._hlo
|
||||
|
||||
def compile(self) -> 'XlaCompiledComputation':
|
||||
if self._executable is None:
|
||||
if self.is_trivial():
|
||||
self._executable = XlaCompiledComputation.from_trivial_jaxpr(
|
||||
*self.compile_args)
|
||||
else:
|
||||
self._executable = XlaCompiledComputation.from_xla_computation(
|
||||
self.name, self.hlo(), *self.compile_args)
|
||||
return self._executable
|
||||
|
||||
|
||||
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs):
|
||||
return lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs).compile().unsafe_call
|
||||
|
||||
_xla_callable = lu.cache(_xla_callable_uncached)
|
||||
|
||||
# TODO(phawkins): refactor this code to share more with the xla.py version.
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs):
  """Trace ``fun`` to a jaxpr and lower it to MLIR text, without compiling.

  Args:
    fun: the wrapped function to trace.
    device: optional target device; may not be combined with ``backend``.
    backend: optional backend; may not be combined with ``device``.
    name: name stored on the returned ``XlaComputation``.
    donated_invars: per-argument donation flags. After unused inputs are
      pruned none may be set, since buffer donation is not implemented here.
    *arg_specs: per-argument specs, unzipped into input avals and devices.

  Returns:
    An ``XlaComputation``. It is marked trivial (and carries no module) when
    the traced jaxpr has no equations; otherwise it holds the printed MLIR
    module and the arguments needed to compile it.

  Raises:
    ValueError: if both ``device`` and ``backend`` are given, or if more
      replicas are required than devices are available.
    UnexpectedTracerError: if a traced constant is a tracer.
    NotImplementedError: for a multi-process run whose jaxpr needs more than
      one replica or contains a pmap.
  """
  both_given = device is not None and backend is not None
  if both_given:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  avals_in, arg_devices = util.unzip2(arg_specs)
  debug_info = pe.debug_info_final(fun, "jit")
  jaxpr, avals_out, consts = pe.trace_to_jaxpr_final(fun, avals_in, debug_info)
  if any(isinstance(const, core.Tracer) for const in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")

  # Drop constants and arguments the jaxpr never reads, and keep the
  # per-argument metadata in step with what survives.
  jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
  consts = [const for idx, const in enumerate(consts) if idx in kept_const_idx]
  kept_specs = (spec for idx, spec in enumerate(arg_specs)
                if idx in kept_var_idx)
  avals_in, arg_devices = util.unzip2(kept_specs)
  donated_invars = [donated for idx, donated in enumerate(donated_invars)
                    if idx in kept_var_idx]
  map(dispatch.prefetch, itertools.chain(consts, dispatch.jaxpr_literals(jaxpr)))
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

  nreps = dispatch.jaxpr_replicas(jaxpr)
  device = dispatch._xla_callable_device(nreps, backend, device, arg_devices)
  if device:
    backend = xb.get_device_backend(device)
  else:
    backend = xb.get_backend(backend)

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to evaluate their arguments.
  if not jaxpr.eqns:
    return XlaComputation(name, None, True, jaxpr, consts, device, avals_in,
                          avals_out, kept_var_idx)

  if not dispatch._on_exit:
    if config.jax_log_compiles:
      level = logging.WARNING
    else:
      level = logging.DEBUG
    logging.log(level, "Compiling %s (%s) for args %s.",
                fun.__name__, id(fun), avals_in)

  if nreps > 1:
    warnings.warn(f"The jitted function {fun.__name__} includes a pmap. Using "
         "jit-of-pmap can lead to inefficient data movement, as the outer jit "
         "does not preserve sharded data representations and instead collects "
         "input and output arrays onto a single device. "
         "Consider removing the outer jit unless you know what you're doing. "
         "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation that requires {nreps} replicas, but only "
        f"{xb.device_count(backend)} XLA devices are available")

  if xb.process_count() > 1 and (nreps > 1 or dispatch.jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  ctx = LoweringContext(backend.platform, xla.AxisEnv(nreps, (), ()), "")
  if backend.runtime_type == "iree":
    tuple_args = False
    ctx = ctx.replace(tuple_results=False)
  else:
    tuple_args = len(avals_in) > 100  # pass long arg lists as tuple for TPU

  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  with ctx.context, ir.Location.unknown(ctx.context):
    # XLA doesn't have a runtime representation of tokens, so we prune them out
    # of the arguments and return values of the top-level function. This is fine
    # since the purpose of tokens is to preserve ordering inside compiled
    # functions.
    lower_jaxpr_to_fun(ctx, "main", closed_jaxpr,
                       public=True, prune_tokens=True)

  assert not any(donated_invars), donated_invars
  # TODO(b/203122001): implement buffer donation.
  # if backend.platform in ("gpu", "tpu"):
  #   donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
  # if any(donated_invars):
  #   # TODO(tomhennigan): At call time we should mark these buffers as deleted.
  #   unused_donations = [str(c.GetShape(a))
  #                       for a, d in zip(xla_args, donated_invars) if d]
  #   warn("Some donated buffers were not usable: {}".format(", ".join(unused_donations)))
  ctx.module.operation.verify()
  # Render the module into an in-memory text buffer.
  text_buffer = io.StringIO()
  ctx.module.operation.print(file=text_buffer, #enable_debug_info=True,
                             print_generic_op_form=False)
  module_text = text_buffer.getvalue()
  # print("MLIR module to be compiled:")
  # print(module_text)
  return XlaComputation(
      name, module_text, False, nreps, device, backend, tuple_args, avals_in,
      avals_out, kept_var_idx)
|
||||
|
||||
|
||||
def _xla_call_impl_mlir(fun: lu.WrappedFun, *args, device, backend, name,
                        donated_invars, inline):
  """Look up the callable for ``fun`` via ``_xla_callable`` and run it on ``args``.

  One spec is derived per argument with ``dispatch.arg_spec`` and passed to
  ``_xla_callable`` along with the other parameters.
  """
  del inline  # Only used at tracing time
  specs = unsafe_map(dispatch.arg_spec, args)
  compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
                               *specs)
  return compiled_fun(*args)
|
||||
|
||||
|
||||
@util.cache()
def _xla_primitive_callable(prim, *arg_specs: dispatch.ArgSpec, **params):
  """Compile a single primitive into a callable for the given arg specs.

  The primitive is wrapped in a function that always returns a sequence, so
  it can go through ``_xla_callable_uncached``. For a primitive without
  ``multiple_results`` the returned callable unwraps the single output.
  """
  _avals, arg_devices = util.unzip2(arg_specs)
  # Nothing is donated when running a lone primitive.
  donated_invars = (False,) * len(arg_specs)
  device = dispatch._device_from_arg_devices(arg_devices)

  def prim_fun(*args):
    out = prim.bind(*args, **params)
    return out if prim.multiple_results else (out,)

  compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
                                    prim.name, donated_invars, *arg_specs)
  if prim.multiple_results:
    return compiled
  return lambda *args, **kw: compiled(*args, **kw)[0]
|
||||
|
||||
def apply_primitive(prim, *args, **params):
  """Impl rule: compile the single primitive ``prim`` and run it on ``args``."""
  specs = unsafe_map(dispatch.arg_spec, args)
  compiled_fun = _xla_primitive_callable(prim, *specs, **params)
  return compiled_fun(*args)
|
||||
|
||||
|
||||
# MLIR lowerings for lax primitives
|
||||
|
||||
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
|
||||
|
@ -356,6 +356,14 @@ def primitive_subcomputation(platform: str, prim: core.Primitive,
|
||||
# sharding annotation set) and replicated.
|
||||
_replicated_param = object()
|
||||
|
||||
def _token_param_shape():
  """Return the stand-in shape used for token arguments at the top level.

  The stand-in is a boolean array shape with an empty dimension list.
  """
  bool_dtype = np.dtype(np.bool_)
  dims = []
  return xc.Shape.array_shape(bool_dtype, dims)
|
||||
|
||||
def _make_token_return_value(c):
  """Return the stand-in value used for token results at the top level.

  The stand-in is a zero-dimensional boolean zero constant built on ``c``.
  """
  placeholder = np.zeros((), dtype=np.dtype(np.bool_))
  return xops.Constant(c, placeholder)
|
||||
|
||||
def _xla_callable_args(
|
||||
c, avals, tuple_args, *,
|
||||
replicated=None,
|
||||
@ -375,8 +383,8 @@ def _xla_callable_args(
|
||||
parts = [_replicated_param if part is None else part
|
||||
for part in partitions]
|
||||
counts = it.count()
|
||||
xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto)
|
||||
if not (filter_tokens and a is abstract_token) else xops.CreateToken(c)
|
||||
xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto,
|
||||
filter_tokens)
|
||||
for (a, r, p) in safe_zip(avals, replicated, parts)
|
||||
for xla_shape in aval_to_xla_shapes(a)]
|
||||
if donated_invars is not None:
|
||||
@ -395,25 +403,33 @@ def _xla_callable_args(
|
||||
else:
|
||||
tuple_parts = tuple(partitions)
|
||||
tuple_shape = xc.Shape.tuple_shape(
|
||||
[shape for a in avals for shape in aval_to_xla_shapes(a)
|
||||
if not (filter_tokens and a is abstract_token)])
|
||||
tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts, partitions_proto)
|
||||
xla_inputs = iter(xla_destructure(c, tuple_param))
|
||||
xla_args = [next(xla_inputs) if not (filter_tokens and a is abstract_token)
|
||||
else xops.CreateToken(c) for a in avals]
|
||||
assert next(xla_inputs, None) is None
|
||||
[shape if not (filter_tokens and a is abstract_token)
|
||||
else _token_param_shape()
|
||||
for a in avals for shape in aval_to_xla_shapes(a)])
|
||||
tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts,
|
||||
partitions_proto, filter_tokens)
|
||||
xla_args = [v if not (filter_tokens and a is abstract_token)
|
||||
else xops.CreateToken(c)
|
||||
for a, v in zip(avals, xla_destructure(c, tuple_param))]
|
||||
return xla_args, donated_invars
|
||||
|
||||
def _xla_param(builder, param_num, xla_shape, replicated, partitions,
               parts_proto, filter_tokens=False):
  """Create parameter ``param_num`` on ``builder``, optionally sharded.

  Args:
    builder: the computation builder the parameter is added to.
    param_num: index of the parameter.
    xla_shape: shape of the parameter.
    replicated: forwarded to ``xb.parameter`` as ``replicated=``.
    partitions: ``None`` for no sharding annotation, the ``_replicated_param``
      sentinel for an annotation with ``None`` sharding, or the partitioning
      to annotate with.
    parts_proto: if truthy, ``xb.with_sharding_proto`` is used for the
      annotation; otherwise ``xb.with_sharding``.
    filter_tokens: if true and ``xla_shape`` is a token, the parameter is
      declared with ``_token_param_shape()`` and a freshly created token is
      returned in its place. Defaults to ``False`` so callers that pass only
      the first six arguments keep working.

  Returns:
    The parameter op, or a new token when a token parameter was filtered.
  """
  is_token = xla_shape.is_token()
  substitute_token = filter_tokens and is_token
  if substitute_token:
    # Declare the parameter with the stand-in shape instead of a token shape.
    xla_shape = _token_param_shape()
  make_param = partial(xb.parameter, builder, param_num, xla_shape,
                       replicated=replicated)
  with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
  if partitions is None:
    out = make_param()
  elif partitions is _replicated_param:
    out = with_sharding(builder, None, make_param)
  else:
    out = with_sharding(builder, partitions, make_param)
  if substitute_token:
    # The parameter above is still created (it occupies the slot), but the
    # rest of the computation is handed a real token.
    out = xops.CreateToken(builder)
  return out
|
||||
|
||||
|
||||
### compiling jaxprs
|
||||
|
@ -132,7 +132,8 @@ core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
|
||||
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
|
||||
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
|
||||
dispatch.device_put_handlers[SparseArray] = sparse_array_device_put_handler
|
||||
dispatch.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
|
||||
dispatch.result_handlers[AbstractSparseArray] = sparse_array_result_handler
|
||||
dispatch.num_buffers_handlers[AbstractSparseArray] = lambda _: 2
|
||||
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
|
||||
xla.register_constant_handler(SparseArray, sparse_array_constant_handler)
|
||||
|
||||
@ -260,7 +261,8 @@ core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
|
||||
xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
|
||||
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
|
||||
dispatch.device_put_handlers[Empty] = lambda _, __: ()
|
||||
dispatch.xla_result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
|
||||
dispatch.result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
|
||||
dispatch.num_buffers_handlers[AbstractEmpty] = lambda _: 0
|
||||
xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user