Merge most of the MLIR JIT dispatch logic into the common primitive and JIT computation path.

Change the representation of both units and tokens at the runtime level to be a single buffer of shape pred[0]. While the MLIR lowering is happy to have a non-1:1 mapping between avals and IR values, the XLA lowering is not, so until we remove the XLA lowering it is easiest to keep the mapping 1:1.

PiperOrigin-RevId: 412957231
Peter Hawkins 2021-11-29 12:39:19 -08:00 committed by jax authors
parent 5084a12034
commit 12512cc96a
5 changed files with 204 additions and 469 deletions
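
For orientation, a minimal sketch (not code from this commit) of the runtime representation the message describes: a unit or token becomes exactly one rank-0 boolean buffer, so every aval maps 1:1 to runtime buffers.

import numpy as np

# Sketch only: the dummy value that stands in for a unit or a token at the
# runtime level after this change. It is a single rank-0 boolean buffer
# (pred[] in XLA shape notation), so each aval contributes exactly one buffer.
dummy = np.zeros((), dtype=np.bool_)
assert dummy.shape == () and dummy.dtype == np.bool_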

View File

@@ -18,6 +18,7 @@ from functools import partial
import itertools
from typing import (
Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union)
from typing_extensions import Protocol
import warnings
from absl import logging
@@ -28,6 +29,7 @@ from jax import linear_util as lu
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.masking as masking
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
import jax.interpreters.partial_eval as pe
from jax.errors import UnexpectedTracerError
@@ -42,6 +44,7 @@ from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
MYPY = False # Are we currently type checking with mypy?
xe = xc._xla
@@ -71,9 +74,6 @@ def arg_spec(x: Any) -> ArgSpec:
def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
if config.jax_enable_mlir:
import jax.interpreters.mlir
return jax.interpreters.mlir.apply_primitive(prim, *args, **params)
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
**params)
return compiled_fun(*args)
@@ -123,12 +123,6 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
if config.jax_enable_mlir:
import jax.interpreters.mlir
return jax.interpreters.mlir._xla_call_impl_mlir(
fun, *args, device=device, backend=backend, name=name,
donated_invars=donated_invars, inline=inline)
del inline # Only used at tracing time
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
@@ -220,31 +214,48 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")
tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU
# pass long arg lists as tuple for TPU
tuple_args = len(abstract_args) > 100
axis_env = xla.AxisEnv(nreps, (), ())
name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit'))
module: Any
if config.jax_enable_mlir:
# TODO(b/203122001): implement buffer donation.
assert not any(donated_invars), donated_invars
module = mlir.lower_jaxpr_to_module(
core.ClosedJaxpr(jaxpr, consts), backend.platform, axis_env, name_stack)
else:
# XLA HLO lowering path
c = xc.XlaBuilder(f"jit_{fun.__name__}")
xla_consts = xla._xla_consts(c, consts)
xla_args, donated_invars = xla._xla_callable_args(
c, abstract_args, tuple_args, donated_invars=donated_invars)
platform = backend.platform
ctx = xla.TranslationContext(c, platform, axis_env, name_stack)
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
# Replace tokens with a dummy array value, because the runtime cannot
# handle token arguments.
out_aval_lens = [len(xla.aval_to_xla_shapes(a)) for a in out_avals]
out_nodes = util.flatten(
[[xla._make_token_return_value(c)] if a is core.abstract_token
else v
for a, v in zip(out_avals, util.unflatten(out_nodes, out_aval_lens))])
c = xc.XlaBuilder(f"jit_{fun.__name__}")
xla_consts = xla._xla_consts(c, consts)
xla_args, donated_invars = xla._xla_callable_args(c, abstract_args, tuple_args,
donated_invars=donated_invars)
platform = backend.platform
ctx = xla.TranslationContext(c, platform, xla.AxisEnv(nreps, (), ()),
xla.extend_name_stack(xla.wrap_name(name, 'jit')))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
# There is a non-zero cost to building an output tuple, particularly on TPU.
# Avoid it if the output arity is 1.
output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
if platform in ("gpu", "tpu"):
donated_invars = xla.set_up_aliases(
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
if any(donated_invars):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(c.GetShape(a))
for a, d in zip(xla_args, donated_invars) if d]
warnings.warn("Some donated buffers were not usable: {}".format(
", ".join(unused_donations)))
built = c.build(output)
# There is a non-zero cost to building an output tuple, particularly on TPU.
# Avoid it if the output arity is 1.
output = out_nodes[0] if len(out_nodes) == 1 else xc.ops.Tuple(c, out_nodes)
if platform in ("gpu", "tpu"):
donated_invars = xla.set_up_aliases(
c, xla_args, c.GetShape(output), donated_invars, tuple_args)
if any(donated_invars):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(c.GetShape(a))
for a, d in zip(xla_args, donated_invars) if d]
warnings.warn("Some donated buffers were not usable: {}".format(
", ".join(unused_donations)))
module = c.build(output)
return XlaComputation(
name, built, False, donated_invars, nreps, device, backend, tuple_args,
name, module, False, donated_invars, nreps, device, backend, tuple_args,
abstract_args, out_avals, kept_var_idx)
@@ -369,27 +380,58 @@ def _xla_callable_device(nreps, backend, device, arg_devices):
assert False # Unreachable given the error check in _xla_callable
# Result handlers
# Argument and result handlers
def aval_to_result_handler(device: Optional[Device],
aval: core.AbstractValue) -> Callable:
num_buffers_handlers: Dict[Type[core.AbstractValue],
Callable[[core.AbstractValue], int]] = {}
def aval_to_num_buffers(aval: core.AbstractValue) -> int:
"""Returns the number of buffers in the runtime representation of `aval`.
In general this may differ from the number of buffers in the compiler-IR
representation of the same value.
"""
try:
return xla_result_handlers[type(aval)](device, aval)
return num_buffers_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err
raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err
def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
# TODO(phawkins): use zero buffers to represent a unit.
num_buffers_handlers[core.AbstractUnit] = lambda _: 1
num_buffers_handlers[core.AbstractToken] = lambda _: 1
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
if MYPY:
ResultHandler = Any
else:
class ResultHandler(Protocol):
def __call__(self, *args: xla.Buffer) -> Any:
"""Boxes raw buffers into their user-facing representation."""
def aval_to_result_handler(sticky_device: Optional[Device],
aval: core.AbstractValue) -> ResultHandler:
try:
return result_handlers[type(aval)](sticky_device, aval)
except KeyError as err:
raise TypeError(f"No result handler for type: {type(aval)}") from err
def array_result_handler(sticky_device: Optional[Device],
aval: core.ShapedArray):
if aval.dtype is dtypes.float0:
return lambda _: np.zeros(aval.shape, dtypes.float0)
return partial(device_array.make_device_array, core.raise_to_shaped(aval), device)
return partial(device_array.make_device_array, core.raise_to_shaped(aval),
sticky_device)
xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
core.AbstractUnit: lambda _, __: lambda _: core.unit,
core.ShapedArray: array_result_handler,
core.ConcreteArray: array_result_handler,
}
xla_result_handlers[core.AbstractToken] = lambda _, __: lambda _: core.token
result_handlers: Dict[
Type[core.AbstractValue],
Callable[[Optional[Device], Any], ResultHandler]] = {}
result_handlers[core.AbstractUnit] = lambda _, __: lambda _: core.unit
result_handlers[core.AbstractToken] = lambda _, __: lambda _: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
def needs_check_special():
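
Together, aval_to_num_buffers and the result handlers define how a flat list of output buffers is regrouped and boxed. A self-contained sketch of the regrouping step, with a stand-in for util.unflatten (assumed semantics, not the JAX implementation):

from typing import Any, List, Sequence

def unflatten(xs: Sequence[Any], ns: Sequence[int]) -> List[List[Any]]:
  # Stand-in for jax.util.unflatten: split a flat list into consecutive
  # groups of the given sizes.
  out, i = [], 0
  for n in ns:
    out.append(list(xs[i:i + n]))
    i += n
  assert i == len(xs), "buffer counts must exactly cover the buffer list"
  return out

# Three results with per-aval buffer counts [1, 2, 1] regroup four raw buffers:
assert unflatten(["b0", "b1", "b2", "b3"], [1, 2, 1]) == (
    [["b0"], ["b1", "b2"], ["b3"]])
# Each group is then boxed by its result handler: handler(*bufs).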
@@ -410,32 +452,26 @@ def _check_special(name, xla_shape, buf):
def _execute_compiled(name: str, compiled: XlaExecutable,
output_buffer_counts: Optional[Sequence[int]], handlers,
kept_var_idx, *args):
output_buffer_counts: Optional[Sequence[int]],
result_handlers, kept_var_idx, *args):
device, = compiled.local_devices()
input_bufs = list(
itertools.chain.from_iterable(
device_put(x, device)
for i, x in enumerate(args)
if x is not core.token and i in kept_var_idx))
input_bufs = util.flatten(
device_put(x, device) for i, x in enumerate(args) if i in kept_var_idx)
out_bufs = compiled.execute(input_bufs)
check_special(name, out_bufs)
if output_buffer_counts is None:
return (handlers[0](*out_bufs),)
return (result_handlers[0](*out_bufs),)
return tuple(
handler(*bs) for handler, bs in
unsafe_zip(handlers, util.unflatten(out_bufs, output_buffer_counts)))
unsafe_zip(result_handlers, util.unflatten(out_bufs, output_buffer_counts)))
def _execute_replicated(name: str, compiled: XlaExecutable,
output_buffer_counts: Optional[Sequence[int]], handlers,
kept_var_idx, *args):
output_buffer_counts: Optional[Sequence[int]],
result_handlers, kept_var_idx, *args):
input_bufs = [
list(
itertools.chain.from_iterable(
device_put(x, device)
for i, x in enumerate(args)
if x is not core.token and i in kept_var_idx))
util.flatten(
device_put(x, device) for i, x in enumerate(args) if i in kept_var_idx)
for device in compiled.local_devices()
]
out_bufs = [
@@ -444,10 +480,10 @@ def _execute_replicated(name: str, compiled: XlaExecutable,
]
check_special(name, out_bufs)
if output_buffer_counts is None:
return (handlers[0](*out_bufs),)
return (result_handlers[0](*out_bufs),)
return tuple(
handler(*bs) for handler, bs in
unsafe_zip(handlers, util.unflatten(out_bufs, output_buffer_counts)))
unsafe_zip(result_handlers, util.unflatten(out_bufs, output_buffer_counts)))
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
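
On the input side, device_put returns a tuple of one or more buffers per argument; after this change tokens also yield a (dummy) buffer, so the executors above no longer special-case core.token. A hedged stand-in for the filter-and-flatten pattern (fake_device_put is invented for illustration):

from itertools import chain
from typing import Any, Sequence, Set, Tuple

def fake_device_put(x: Any, device: Any) -> Tuple[Any, ...]:
  # Stand-in: the real device_put returns one or more runtime buffers per
  # argument (including a dummy pred[] buffer for tokens after this change).
  return (f"buf({x!r})",)

def gather_input_bufs(args: Sequence[Any], kept_var_idx: Set[int], device: Any):
  # Mirrors _execute_compiled: drop pruned arguments, then flatten the
  # per-argument buffer tuples into one list for the executable.
  return list(chain.from_iterable(
      fake_device_put(x, device) for i, x in enumerate(args)
      if i in kept_var_idx))

assert gather_input_bufs([10, 20, 30], {0, 2}, None) == ["buf(10)", "buf(30)"]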
@@ -532,21 +568,23 @@ class XlaCompiledComputation:
name: str,
xla_computation,
nreps: int,
device,
device: Optional[Device],
backend,
tuple_args: bool,
in_avals,
out_avals,
kept_var_idx) -> 'XlaCompiledComputation':
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
sticky_device = device
result_handlers = map(partial(aval_to_result_handler, sticky_device),
out_avals)
options = xb.get_compile_options(
num_replicas=nreps,
num_partitions=1,
device_assignment=(device.id,) if device else None)
device_assignment=(sticky_device.id,) if sticky_device else None)
options.parameter_is_tupled_arguments = tuple_args
compiled = compile_or_get_cached(backend, xla_computation, options)
buffer_counts = (None if len(out_avals) == 1 else
[len(xla.aval_to_xla_shapes(aval)) for aval in out_avals])
[aval_to_num_buffers(aval) for aval in out_avals])
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, buffer_counts,
result_handlers, kept_var_idx)
@@ -610,17 +648,21 @@ def _device_put_scalar(x, device):
def _device_put_unit(_, device):
backend = xb.get_device_backend(device)
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype(np.bool_)),
device),)
def _device_put_token(_, device):
backend = xb.get_device_backend(device)
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype(np.bool_)),
device),)
_scalar_types = dtypes.python_scalar_dtypes.keys()
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
core.Unit: _device_put_unit
}
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {}
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
device_put_handlers[core.Token] = lambda x, _: (x,)
device_put_handlers[core.Unit] = _device_put_unit
device_put_handlers[core.Token] = _device_put_token
def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
@@ -670,9 +712,6 @@ ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)
batching.defvectorized(device_put_p)
# TODO(phawkins): remove mlir->dispatch dependency and move this to the top.
import jax.interpreters.mlir as mlir
def _device_put_lowering(ctx, avals_in, avals_out, x, *, device):
return [x]

View File

@@ -54,7 +54,7 @@ complex_: type = np.complex128
_default_types = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0 = np.dtype([('float0', np.void, 0)])
float0: np.dtype = np.dtype([('float0', np.void, 0)])
_dtype_to_32bit_dtype = {
np.dtype('int64'): np.dtype('int32'),
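
The annotation added above makes float0 usable as an explicitly typed value. A short example of the dtype itself, the tangent type JAX uses for int/bool primals (e.g. under jax.grad(..., allow_int=True)):

import numpy as np
from jax import dtypes

# float0 is a zero-byte structured dtype: arrays of it carry a shape but no
# data, which is all a tangent for an int/bool primal needs.
z = np.zeros((3,), dtype=dtypes.float0)
assert z.shape == (3,) and z.dtype.itemsize == 0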

View File

@@ -19,30 +19,24 @@ import collections
import dataclasses
from functools import partial
import io
import itertools
import typing
from typing import (Any, Callable, Dict, Optional, Sequence, Type, Union, Tuple)
from typing import (Any, Callable, Dict, List, Optional, Sequence, Type, Union,
Tuple)
from typing_extensions import Protocol
import warnings
from absl import logging
from jax import core
from jax import linear_util as lu
from jax._src.config import config
from jax._src import ad_util
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import builtin
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import std
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src import source_info_util
import jax._src.util as util
from jax.errors import UnexpectedTracerError
import jax.interpreters.ad as ad
import jax.interpreters.partial_eval as pe
import jax.interpreters.xla as xla
@@ -315,10 +309,33 @@ def flatten_lowering_ir_args(
) -> Sequence[Sequence[ir.Value]]:
return util.flatten(map(_wrap_singleton_ir_values, xs))
def lower_jaxpr_to_module(jaxpr: core.ClosedJaxpr, platform: str,
axis_env: xla.AxisEnv, name_stack: str) -> str:
"""Lowers a top-level jaxpr to an MHLO module.
Handles the quirks of the argument/return value passing conventions of the
runtime."""
ctx = LoweringContext(platform, axis_env, name_stack)
if platform == "iree":
ctx = ctx.replace(tuple_results=False)
with ctx.context, ir.Location.unknown(ctx.context):
# TODO(phawkins): represent units with zero buffers at the runtime level.
lower_jaxpr_to_fun(
ctx, "main", jaxpr, public=True, replace_units_with_dummy=True,
replace_tokens_with_dummy=True)
ctx.module.operation.verify()
output = io.StringIO()
ctx.module.operation.print(file=output, #enable_debug_info=True,
print_generic_op_form=False)
return output.getvalue()
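
A hedged sketch of driving this entry point by hand at this revision (internal APIs; the signature here matches the diff above, not necessarily later JAX versions, and AxisEnv fields (nreps, names, sizes) are per the xla module of this era):

import jax
import jax.numpy as jnp
from jax.interpreters import mlir, xla

# Trace a function to a ClosedJaxpr, then lower it to textual MHLO.
closed_jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) + 1.0)(1.0)
mhlo_text = mlir.lower_jaxpr_to_module(
    closed_jaxpr, "cpu", xla.AxisEnv(1, (), ()), "")
print(mhlo_text)  # the "main" function, with dummy unit/token args/results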
def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
jaxpr: core.ClosedJaxpr, *,
public: bool = False,
prune_tokens: bool = False) -> str:
replace_units_with_dummy: bool = False,
replace_tokens_with_dummy: bool = False) -> str:
"""Lowers jaxpr and its callees to an IR function.
Assumes that an MLIR context, location, and insertion point are set.
@@ -329,22 +346,21 @@ def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
so it is ok to use the same name multiple times.
jaxpr: the jaxpr to lower.
public: if true, the function's visibility is set to "public".
prune_tokens: if true, tokens are pruned from the arguments and return
values.
replace_units_with_dummy: if true, unit arguments/return values are
replaced with bool arrays of size [0].
replace_tokens_with_dummy: if true, token arguments/return values are
replaced with bool arrays of size [0].
Returns the name of the function.
"""
# print(jaxpr.jaxpr)
if prune_tokens:
pruned_in_avals = [aval for aval in jaxpr.in_avals
if aval is not core.abstract_token]
pruned_out_avals = [aval for aval in jaxpr.out_avals
if aval is not core.abstract_token]
else:
pruned_in_avals = jaxpr.in_avals
pruned_out_avals = jaxpr.out_avals
def aval_to_types(aval):
if replace_units_with_dummy and aval is core.abstract_unit:
aval = core.ShapedArray((), np.dtype(np.bool_))
elif replace_tokens_with_dummy and aval is core.abstract_token:
aval = core.ShapedArray((), np.dtype(np.bool_))
return aval_to_ir_types(aval)
input_types = map(aval_to_ir_types, pruned_in_avals)
output_types = map(aval_to_ir_types, pruned_out_avals)
input_types = map(aval_to_types, jaxpr.in_avals)
output_types = map(aval_to_types, jaxpr.out_avals)
flat_input_types = util.flatten(input_types)
flat_output_types = util.flatten(output_types)
if ctx.tuple_results:
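
The substitution aval_to_types performs, shown in isolation: unit and token avals are swapped for a scalar-bool aval before conversion to IR types. A toy stand-in (sentinel objects replace core.abstract_unit/core.abstract_token, and strings replace real IR types):

ABSTRACT_UNIT, ABSTRACT_TOKEN = object(), object()

def aval_to_types(aval, *, replace_units=True, replace_tokens=True):
  # Toy model of the logic above: dummies lower to one rank-0 bool IR value.
  if (replace_units and aval is ABSTRACT_UNIT) or (
      replace_tokens and aval is ABSTRACT_TOKEN):
    return ("tensor<i1>",)
  return (f"tensor<{aval}>",)  # stand-in for aval_to_ir_types(aval)

assert aval_to_types(ABSTRACT_UNIT) == ("tensor<i1>",)
assert aval_to_types("f32") == ("tensor<f32>",)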
@@ -361,27 +377,28 @@ def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
with ir.InsertionPoint(entry_block):
unflattened_args = util.unflatten(entry_block.arguments,
map(len, input_types))
# If we pruned tokens out of the parameter list, create a new token and add
# it here.
if prune_tokens and len(pruned_in_avals) != len(jaxpr.in_avals):
token = mhlo.CreateTokenOp(mhlo.TokenType.get()).results
arg_iter = iter(unflattened_args)
unflattened_args = [
token if aval is core.abstract_token else next(arg_iter)
for aval in jaxpr.in_avals
]
done = object()
assert next(arg_iter, done) is done
args: List[List[ir.Value]] = []
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
if replace_units_with_dummy and aval is core.abstract_unit:
args.append([])
elif replace_tokens_with_dummy and aval is core.abstract_token:
args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
else:
args.append(arg)
callee_name_stack = xla.extend_name_stack(ctx.name_stack,
xla.wrap_name(name, 'jit'))
out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
*unflattened_args)
if prune_tokens:
out_vals = [v for v, aval in zip(out_vals, jaxpr.out_avals)
if aval is not core.abstract_token]
flat_outputs = util.flatten(out_vals)
*args)
outs = []
for aval, out in zip(jaxpr.out_avals, out_vals):
if replace_units_with_dummy and aval is core.abstract_unit:
outs.append(ir_constants(np.zeros((), np.bool_)))
elif replace_tokens_with_dummy and aval is core.abstract_token:
outs.append(ir_constants(np.zeros((), np.bool_)))
else:
outs.append(out)
flat_outputs = util.flatten(outs)
if ctx.tuple_results:
std.ReturnOp([mhlo.TupleOp(output_tuple_type, flat_outputs).result])
else:
@@ -543,345 +560,6 @@ register_lowering(ad_util.stop_gradient_p,
lambda ctx, avals_in, avals_out, x: [x])
### Computation dispatch
_aval_to_num_buffers: Dict[Type[core.AbstractValue],
Callable[[core.AbstractValue], int]] = {}
def aval_to_num_buffers(aval: core.AbstractValue) -> int:
"""Returns the number of buffers in the runtime representation of `aval`.
Note: the compile-time representation may have more buffers! This is a small
hack to deal with tokens that have no true runtime representation but have an
IR type.
Must match the arity of the result of `aval_to_ir_types`."""
try:
return _aval_to_num_buffers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err
_aval_to_num_buffers[core.AbstractUnit] = lambda _: 0
_aval_to_num_buffers[core.AbstractToken] = lambda _: 0
_aval_to_num_buffers[core.ShapedArray] = lambda _: 1
_aval_to_num_buffers[core.ConcreteArray] = lambda _: 1
class ArgHandler(Protocol):
def __call__(self, device: xla.Device, arg: Any) -> Sequence[xla.Buffer]:
"""A argument handler unboxes raw buffers into their Python representation."""
_aval_to_arg_handler: Dict[
Type[core.AbstractValue], Callable[[Any], ArgHandler]] = {}
def aval_to_arg_handler(aval: core.AbstractValue) -> ArgHandler:
try:
return _aval_to_arg_handler[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No arg_handler for type: {type(aval)}") from err
def _array_arg_handler(aval: core.ShapedArray) -> ArgHandler:
return lambda device, val: xla.device_put(val, device)
_aval_to_arg_handler[core.AbstractUnit] = lambda _: lambda _device, _: ()
_aval_to_arg_handler[core.AbstractToken] = lambda _: lambda _device, _: ()
_aval_to_arg_handler[core.ShapedArray] = _array_arg_handler
_aval_to_arg_handler[core.ConcreteArray] = _array_arg_handler
if not MYPY:
class ResultHandler(Protocol):
def __call__(self, device: xla.Device, *args: xla.Buffer) -> Any:
"""A result handler boxes raw buffers into their Python representation.
Inverse of ArgHandler."""
else:
ResultHandler = Any
_aval_to_result_handler: Dict[
Type[core.AbstractValue], Callable[[Any], ResultHandler]] = {}
def aval_to_result_handler(aval: core.AbstractValue) -> ResultHandler:
try:
return _aval_to_result_handler[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No result_handler for type: {type(aval)}") from err
def _array_result_handler(aval: core.ShapedArray) -> ResultHandler:
if aval.dtype is dtypes.float0:
return lambda _device, _: np.zeros(aval.shape, dtypes.float0)
aval = core.raise_to_shaped(aval)
return lambda device, buffer: xla.make_device_array(aval, device, buffer)
_aval_to_result_handler[core.AbstractUnit] = lambda _: lambda _: core.unit
_aval_to_result_handler[core.AbstractToken] = lambda _: lambda _: core.token
_aval_to_result_handler[core.ShapedArray] = _array_result_handler
_aval_to_result_handler[core.ConcreteArray] = _array_result_handler
def _execute_compiled(name: str, compiled: xla.XlaExecutable,
device: xla.Device, buffer_counts,
arg_handlers, result_handlers, kept_var_idx, *args):
input_bufs = util.flatten(
arg_handler(device, x) for arg_handler, x in
unsafe_zip(arg_handlers,
(x for i, x in enumerate(args) if i in kept_var_idx)))
out_bufs = compiled.execute(input_bufs)
dispatch.check_special(name, out_bufs)
return [handler(device, *bs) for handler, bs in
zip(result_handlers, util.unflatten(out_bufs, buffer_counts))]
def _execute_replicated(name: str, compiled: xla.XlaExecutable,
device: xla.Device,
buffer_counts, arg_handlers, result_handlers,
kept_var_idx, *args):
input_bufs = [
util.flatten(
arg_handler(device, x) for arg_handler, x in
unsafe_zip(arg_handlers,
(x for i, x in enumerate(args) if i in kept_var_idx)))
for device in compiled.local_devices()
]
out_bufs = [
buf[0] for buf in compiled.execute_sharded_on_local_devices(
list(zip(*input_bufs)))
]
dispatch.check_special(name, out_bufs)
return [handler(device, *bs) for handler, bs in
zip(result_handlers, util.unflatten(out_bufs, buffer_counts))]
def _execute_trivial(jaxpr, device: Optional[xla.Device], consts, buffer_counts,
result_handlers, kept_var_idx, *args):
env = {core.unitvar: core.unit}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
for v in jaxpr.outvars]
return [dispatch.device_put_p.bind(x, device=device) for x in outs]
class XlaCompiledComputation:
def __init__(self, xla_executable, unsafe_call):
self._xla_executable = xla_executable
self.unsafe_call = unsafe_call
@staticmethod
def from_xla_computation(
name: str,
xla_computation,
nreps: int,
device,
backend,
tuple_args: bool,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
kept_var_idx) -> 'XlaCompiledComputation':
arg_handlers = map(aval_to_arg_handler, avals_in)
result_handlers = map(aval_to_result_handler, avals_out)
options = xb.get_compile_options(
num_replicas=nreps,
num_partitions=1,
device_assignment=(device.id,) if device else None)
options.parameter_is_tupled_arguments = tuple_args
compiled = dispatch.compile_or_get_cached(backend, xla_computation, options)
buffer_counts = [aval_to_num_buffers(aval) for aval in avals_out]
if nreps == 1:
return XlaCompiledComputation(compiled, partial(
_execute_compiled, name, compiled, device, buffer_counts,
arg_handlers, result_handlers, kept_var_idx))
else:
return XlaCompiledComputation(compiled, partial(
_execute_replicated, name, compiled, device, buffer_counts,
arg_handlers, result_handlers, kept_var_idx))
def is_trivial(self):
return self._xla_executable is None
def xla_executable(self):
if self.is_trivial():
raise ValueError('A trivial compiled computation has no XLA executable')
return self._xla_executable
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, device, avals_in, avals_out,
kept_var_idx) -> 'XlaCompiledComputation':
result_handlers = map(aval_to_result_handler, avals_out)
return XlaCompiledComputation(None, partial(
_execute_trivial, jaxpr, device, consts, avals_out,
result_handlers, kept_var_idx))
def call(self, *args):
# TODO(apaszke,frostig): Check that args are compatible with input avals!
return self.unsafe_call(*args)
def __call__(self, *args):
return self.call(*args)
class XlaComputation:
name: str
_is_trivial: bool
_executable: Optional['XlaCompiledComputation']
def __init__(self, name: str, hlo, is_trivial: bool, *compile_args):
self.name = name
self._hlo = hlo
self._is_trivial = is_trivial
self._executable = None
self.compile_args = compile_args
def is_trivial(self):
return self._is_trivial
def hlo(self):
if self.is_trivial():
raise ValueError('A trivial computation has no HLO')
return self._hlo
def compile(self) -> 'XlaCompiledComputation':
if self._executable is None:
if self.is_trivial():
self._executable = XlaCompiledComputation.from_trivial_jaxpr(
*self.compile_args)
else:
self._executable = XlaCompiledComputation.from_xla_computation(
self.name, self.hlo(), *self.compile_args)
return self._executable
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs):
return lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs).compile().unsafe_call
_xla_callable = lu.cache(_xla_callable_uncached)
# TODO(phawkins): refactor this code to share more with the xla.py version.
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
avals_in, arg_devices = util.unzip2(arg_specs)
jaxpr, avals_out, consts = pe.trace_to_jaxpr_final(
fun, avals_in, pe.debug_info_final(fun, "jit"))
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
pruned_arg_specs = (a for i, a in enumerate(arg_specs) if i in kept_var_idx)
avals_in, arg_devices = util.unzip2(pruned_arg_specs)
donated_invars = [
x for i, x in enumerate(donated_invars) if i in kept_var_idx
]
map(dispatch.prefetch, itertools.chain(consts, dispatch.jaxpr_literals(jaxpr)))
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
nreps = dispatch.jaxpr_replicas(jaxpr)
device = dispatch._xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if not jaxpr.eqns:
return XlaComputation(name, None, True, jaxpr, consts, device, avals_in,
avals_out, kept_var_idx)
if not dispatch._on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logging.log(log_priority, "Compiling %s (%s) for args %s.",
fun.__name__, id(fun), avals_in)
if nreps > 1:
warnings.warn(f"The jitted function {fun.__name__} includes a pmap. Using "
"jit-of-pmap can lead to inefficient data movement, as the outer jit "
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.")
if nreps > xb.device_count(backend):
raise ValueError(
f"compiling computation that requires {nreps} replicas, but only "
f"{xb.device_count(backend)} XLA devices are available")
if xb.process_count() > 1 and (nreps > 1 or dispatch.jaxpr_has_pmap(jaxpr)):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")
ctx = LoweringContext(backend.platform, xla.AxisEnv(nreps, (), ()), "")
if backend.runtime_type == "iree":
tuple_args = False
ctx = ctx.replace(tuple_results=False)
else:
tuple_args = len(avals_in) > 100 # pass long arg lists as tuple for TPU
with ctx.context, ir.Location.unknown(ctx.context):
# XLA doesn't have a runtime representation of tokens, so we prune them out
# of the arguments and return values of the top-level function. This is fine
# since the purpose of tokens is to preserve ordering inside compiled
# functions.
lower_jaxpr_to_fun(ctx, "main", core.ClosedJaxpr(jaxpr, consts),
public=True, prune_tokens=True)
assert not any(donated_invars), donated_invars
# TODO(b/203122001): implement buffer donation.
# if backend.platform in ("gpu", "tpu"):
# donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
# if any(donated_invars):
# # TODO(tomhennigan): At call time we should mark these buffers as deleted.
# unused_donations = [str(c.GetShape(a))
# for a, d in zip(xla_args, donated_invars) if d]
# warn("Some donated buffers were not usable: {}".format(", ".join(unused_donations)))
ctx.module.operation.verify()
output = io.StringIO()
ctx.module.operation.print(file=output, #enable_debug_info=True,
print_generic_op_form=False)
module = output.getvalue()
# print("MLIR module to be compiled:")
# print(module)
return XlaComputation(
name, module, False, nreps, device, backend, tuple_args, avals_in,
avals_out, kept_var_idx)
def _xla_call_impl_mlir(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(dispatch.arg_spec, args))
return compiled_fun(*args)
@util.cache()
def _xla_primitive_callable(prim, *arg_specs: dispatch.ArgSpec, **params):
avals, arg_devices = util.unzip2(arg_specs)
donated_invars = (False,) * len(arg_specs)
device = dispatch._device_from_arg_devices(arg_devices)
def prim_fun(*args):
out = prim.bind(*args, **params)
if prim.multiple_results:
return out
else:
return out,
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
prim.name, donated_invars, *arg_specs)
if not prim.multiple_results:
return lambda *args, **kw: compiled(*args, **kw)[0]
else:
return compiled
def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
compiled_fun = _xla_primitive_callable(prim, *unsafe_map(dispatch.arg_spec, args),
**params)
return compiled_fun(*args)
# MLIR lowerings for lax primitives
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,

View File

@@ -356,6 +356,14 @@ def primitive_subcomputation(platform: str, prim: core.Primitive,
# sharding annotation set) and replicated.
_replicated_param = object()
def _token_param_shape():
"""Shape used in place of tokens as top-level computation arguments."""
return xc.Shape.array_shape(np.dtype(np.bool_), [])
def _make_token_return_value(c):
"""Value used in place of tokens as a top-level computation return value."""
return xops.Constant(c, np.zeros((), dtype=np.dtype(np.bool_)))
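
What the two helpers above produce, shown with xla_client directly (a sketch; xc here is jax._src.lib.xla_client at this revision):

import numpy as np
from jax._src.lib import xla_client as xc

# A token parameter is replaced by this rank-0 boolean shape, pred[]; the
# real token is then re-created in-graph with xops.CreateToken.
shape = xc.Shape.array_shape(np.dtype(np.bool_), [])
print(shape)  # pred[]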
def _xla_callable_args(
c, avals, tuple_args, *,
replicated=None,
@@ -375,8 +383,8 @@ def _xla_callable_args(
parts = [_replicated_param if part is None else part
for part in partitions]
counts = it.count()
xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto)
if not (filter_tokens and a is abstract_token) else xops.CreateToken(c)
xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto,
filter_tokens)
for (a, r, p) in safe_zip(avals, replicated, parts)
for xla_shape in aval_to_xla_shapes(a)]
if donated_invars is not None:
@@ -395,25 +403,33 @@
else:
tuple_parts = tuple(partitions)
tuple_shape = xc.Shape.tuple_shape(
[shape for a in avals for shape in aval_to_xla_shapes(a)
if not (filter_tokens and a is abstract_token)])
tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts, partitions_proto)
xla_inputs = iter(xla_destructure(c, tuple_param))
xla_args = [next(xla_inputs) if not (filter_tokens and a is abstract_token)
else xops.CreateToken(c) for a in avals]
assert next(xla_inputs, None) is None
[shape if not (filter_tokens and a is abstract_token)
else _token_param_shape()
for a in avals for shape in aval_to_xla_shapes(a)])
tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts,
partitions_proto, filter_tokens)
xla_args = [v if not (filter_tokens and a is abstract_token)
else xops.CreateToken(c)
for a, v in zip(avals, xla_destructure(c, tuple_param))]
return xla_args, donated_invars
def _xla_param(builder, param_num, xla_shape, replicated, partitions, parts_proto):
def _xla_param(builder, param_num, xla_shape, replicated, partitions,
parts_proto, filter_tokens):
is_token = xla_shape.is_token()
if filter_tokens and is_token:
xla_shape = _token_param_shape()
make_param = partial(xb.parameter, builder, param_num, xla_shape,
replicated=replicated)
with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
if partitions is None:
return make_param()
out = make_param()
elif partitions is _replicated_param:
return with_sharding(builder, None, make_param)
out = with_sharding(builder, None, make_param)
else:
return with_sharding(builder, partitions, make_param)
out = with_sharding(builder, partitions, make_param)
if filter_tokens and is_token:
out = xops.CreateToken(builder)
return out
### compiling jaxprs

View File

@@ -132,7 +132,8 @@ core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval
xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x
dispatch.device_put_handlers[SparseArray] = sparse_array_device_put_handler
dispatch.xla_result_handlers[AbstractSparseArray] = sparse_array_result_handler
dispatch.result_handlers[AbstractSparseArray] = sparse_array_result_handler
dispatch.num_buffers_handlers[AbstractSparseArray] = lambda _: 2
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
xla.register_constant_handler(SparseArray, sparse_array_constant_handler)
@@ -260,7 +261,8 @@ core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval
xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty()
xla.canonicalize_dtype_handlers[Empty] = lambda x: x
dispatch.device_put_handlers[Empty] = lambda _, __: ()
dispatch.xla_result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
dispatch.result_handlers[AbstractEmpty] = lambda _, __: lambda: Empty(AbstractEmpty())
dispatch.num_buffers_handlers[AbstractEmpty] = lambda _: 0
xla.xla_shape_handlers[AbstractEmpty] = lambda _: ()
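
Putting the extended handler tables together: a hedged, minimal registration for a hypothetical two-buffer object, following the SparseArray pattern above. AbstractPair/Pair are invented for illustration; dispatch and xla are the JAX-internal modules at this revision.

from jax import core
from jax._src import dispatch
from jax.interpreters import xla

class AbstractPair(core.AbstractValue):
  pass

class Pair:
  def __init__(self, aval, first, second):
    self.aval, self.first, self.second = aval, first, second

# Runtime representation: exactly two buffers per Pair.
dispatch.num_buffers_handlers[AbstractPair] = lambda _: 2
# Boxing: raw buffers -> user-facing value (sticky_device is unused here).
dispatch.result_handlers[AbstractPair] = (
    lambda sticky_device, aval: lambda b0, b1: Pair(aval, b0, b1))
# Unboxing: user-facing value -> flat tuple of buffers for device_put.
dispatch.device_put_handlers[Pair] = lambda p, device: (
    *xla.device_put(p.first, device), *xla.device_put(p.second, device))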