Add experimental support for XLA infeed/outfeed.

Peter Hawkins 2019-10-09 15:05:54 -04:00
parent 361adfe482
commit b8a5473614
4 changed files with 154 additions and 23 deletions

jax/abstract_arrays.py

@@ -164,6 +164,11 @@ class ConcreteArray(ShapedArray):
     return str(self.val)
 
 
+class AbstractToken(core.AbstractValue): pass
+
+abstract_token = AbstractToken()
+
+
 def make_shaped_array(x):
   dtype = xla_bridge.canonicalize_dtype(onp.result_type(x))
   return ShapedArray(onp.shape(x), dtype)
@@ -197,6 +202,8 @@ def raise_to_shaped(aval):
     return ShapedArray(aval.shape, aval.dtype)
   elif aval is core.abstract_unit:
     return core.abstract_unit
+  elif aval is abstract_token:
+    return abstract_token
   else:
     raise TypeError(type(aval))
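
A quick check of what this first hunk adds: abstract_token is a singleton abstract value, and raise_to_shaped maps it to itself. A minimal sketch against the JAX of this commit (import path as in the diff):

from jax.abstract_arrays import abstract_token, raise_to_shaped

# Unlike ConcreteArray values, which raise_to_shaped weakens to a ShapedArray,
# the token abstract value is already fully abstract and round-trips unchanged.
assert raise_to_shaped(abstract_token) is abstract_token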

jax/interpreters/pxla.py

@@ -208,7 +208,7 @@ def compile_replicated(jaxpr, backend, axis_name, axis_size, global_axis_size,
   axis_env = xla.AxisEnv(num_replicas, [axis_name], [global_axis_size], devices)
   arg_shapes = list(map(aval_to_xla_shape, abstract_args))
   built_c = xla.jaxpr_computation(jaxpr, backend, axis_env, consts, (), arg_shapes,
-                                  tuple_args=tuple_args)
+                                  tuple_args=tuple_args, inner=False)
   compiled = built_c.Compile(
       compile_options=xb.get_compile_options(num_replicas, device_assignment),
       backend=xb.get_backend(backend))

jax/interpreters/xla.py

@@ -31,8 +31,9 @@ from .. import core
 from .. import ad_util
 from .. import tree_util
 from .. import linear_util as lu
-from ..abstract_arrays import (ConcreteArray, ShapedArray, make_shaped_array,
-                               array_types, raise_to_shaped)
+from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
+                               make_shaped_array, array_types, raise_to_shaped,
+                               abstract_token)
 from ..core import valid_jaxtype, Literal
 from ..util import partial, partialmethod, cache, safe_map, prod, unzip2
 from ..lib import xla_bridge as xb
@@ -152,7 +153,7 @@ def primitive_computation(prim, *xla_shapes, **params):
       prim, (), params))
   ))
   platform = xb.get_backend(backend).platform
-  xla_args = map(c.ParameterWithShape, xla_shapes)
+  xla_args = (_parameter_or_create_token(c, shape) for shape in xla_shapes)
   if prim in backend_specific_translations[platform]:
     rule = backend_specific_translations[platform][prim]
     rule(c, *xla_args, **new_params)  # return val set as a side-effect on c
@@ -178,7 +179,8 @@ def primitive_computation(prim, *xla_shapes, **params):
 def _execute_compiled_primitive(prim, compiled, backend, result_handler, *args):
   device_num, = compiled.DeviceOrdinals()
-  input_bufs = [device_put(x, device_num, backend=backend) for x in args]
+  input_bufs = [device_put(x, device_num, backend=backend) for x in args
+                if x is not token]
   out_buf = compiled.Execute(input_bufs)
   if FLAGS.jax_debug_nans:
     check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)
@@ -202,7 +204,7 @@ def _check_nans(name, xla_shape, buf):
 ### compiling jaxprs
 
-def compile_jaxpr(jaxpr, device, backend, axis_env, const_vals, tuple_args,
+def _compile_jaxpr(jaxpr, device, backend, axis_env, const_vals, tuple_args,
                   *abstract_args):
   if axis_env.nreps > xb.device_count(backend):
     msg = ("compiling computation that requires {} replicas, but only {} XLA "
@@ -210,7 +212,7 @@ def compile_jaxpr(jaxpr, device, backend, axis_env, const_vals, tuple_args,
     raise ValueError(msg.format(axis_env.nreps, xb.device_count(backend)))
   arg_shapes = tuple(map(aval_to_xla_shape, abstract_args))
   built_c = jaxpr_computation(jaxpr, backend, axis_env, const_vals, (), arg_shapes,
-                              tuple_args=tuple_args)
+                              tuple_args=tuple_args, inner=False)
   device_assignment = (device.id,) if device else None
   compile_opts = xb.get_compile_options(num_replicas=axis_env.nreps,
                                         device_assignment=device_assignment)
@@ -220,7 +222,7 @@ def compile_jaxpr(jaxpr, device, backend, axis_env, const_vals, tuple_args,
 def build_jaxpr(jaxpr, backend, axis_env, const_vals, tuple_args, *abstract_args):
   arg_shapes = map(aval_to_xla_shape, abstract_args)
   return jaxpr_computation(jaxpr, backend, axis_env, const_vals, (), arg_shapes,
-                           tuple_args=tuple_args)
+                           tuple_args=tuple_args, inner=False)
 
 def prefetch(x):
   if isinstance(x, DeviceArray):
@@ -245,21 +247,37 @@ def eqn_literals(eqn):
     if type(v) is core.Literal:
       yield v.val
 
+def _parameter_or_create_token(c, shape):
+  # Token values cannot be passed as parameters to top-level computations.
+  # Instead, we manufacture a fresh token as an argument. However, tokens make
+  # sense and are allowed as arguments to inner computations.
+  if shape.xla_element_type() == xc.PrimitiveType.TOKEN:
+    return c.CreateToken()
+  else:
+    return c.ParameterWithShape(shape)
+
 def jaxpr_computation(jaxpr, backend, axis_env, const_vals, freevar_shapes,
-                      arg_shapes, tuple_args=False):
+                      arg_shapes, tuple_args=False, inner=False):
+  # If inner is True, tokens can be passed as parameters; if inner is False,
+  # token parameters become CreateToken instructions.
   c = xb.make_computation_builder("jaxpr_computation")  # TODO(mattjj): name
   _map(prefetch, it.chain(const_vals, jaxpr_literals(jaxpr)))
   consts = _map(c.Constant, const_vals)
   if tuple_args:
-    freevar_shapes, arg_shapes = list(freevar_shapes), list(arg_shapes)
-    tuple_shape = xc.Shape.tuple_shape(freevar_shapes + arg_shapes)
+    freevar_shapes = list(freevar_shapes)
+    arg_shapes = list(arg_shapes)
+    tuple_shape = xc.Shape.tuple_shape(
+        [s for s in it.chain(freevar_shapes, arg_shapes)
+         if s.xla_element_type() != xc.PrimitiveType.TOKEN])
     tuple_arg = c.ParameterWithShape(tuple_shape)
     nfreevars, nargs = len(freevar_shapes), len(arg_shapes)
     freevars = [c.GetTupleElement(tuple_arg, i) for i in range(nfreevars)]
     args = [c.GetTupleElement(tuple_arg, i + nfreevars) for i in range(nargs)]
   else:
-    freevars = _map(c.ParameterWithShape, freevar_shapes)
-    args = _map(c.ParameterWithShape, arg_shapes)
+    make_parameter = (c.ParameterWithShape if inner
+                      else partial(_parameter_or_create_token, c))
+    freevars = _map(make_parameter, freevar_shapes)
+    args = _map(make_parameter, arg_shapes)
   out_nodes = jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, freevars, *args)
   return c.Build(c.Tuple(*out_nodes))
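
An illustration of the check _parameter_or_create_token relies on, assuming the contemporaneous xla_client API (Shape.token_shape and Shape.xla_element_type are both used elsewhere in this diff):

from jax.lib import xla_client as xc

# A token shape reports PrimitiveType.TOKEN as its element type; this is how
# _parameter_or_create_token tells token parameters apart from array
# parameters when building a top-level computation.
assert xc.Shape.token_shape().xla_element_type() == xc.PrimitiveType.TOKEN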
@@ -409,8 +427,8 @@ def _xla_callable(fun, device, backend, *abstract_args):
     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
     assert not env  # no subtraces here (though cond might eventually need them)
   axis_env = AxisEnv(jaxpr_replicas(jaxpr), [], [])
-  compiled = compile_jaxpr(jaxpr, device, backend, axis_env, consts,
-                           tuple_args, *abstract_args)
+  compiled = _compile_jaxpr(jaxpr, device, backend, axis_env, consts,
+                            tuple_args, *abstract_args)
   del master, consts, jaxpr, env
   result_handlers = tuple(map(_pval_to_result_handler, pvals))
   if axis_env.nreps == 1:
@@ -427,7 +445,8 @@ def _pval_to_result_handler(pval):
 def _execute_compiled(compiled, backend, handlers, tuple_args, *args):
   device_num, = compiled.DeviceOrdinals()
-  input_bufs = [device_put(x, device_num, backend=backend) for x in args]
+  input_bufs = [device_put(x, device_num, backend=backend) for x in args
+                if x is not token]
   if tuple_args:
     input_bufs = [make_tuple(input_bufs, device_num, backend)]
   out_bufs = compiled.Execute(input_bufs).destructure()
@@ -435,8 +454,9 @@ def _execute_compiled(compiled, backend, handlers, tuple_args, *args):
   return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
 
 def _execute_replicated(compiled, backend, handlers, tuple_args, *args):
-  input_bufs = [[device_put(x, device_num, backend=backend) for x in args]
-                for device_num in compiled.DeviceOrdinals()]
+  input_bufs = [
+      [device_put(x, device_num, backend=backend) for x in args if x is not token]
+      for device_num in compiled.DeviceOrdinals()]
   if tuple_args:
     input_bufs = [[make_tuple(bufs, device_num)] for bufs, device_num in
                   zip(input_bufs, compiled.DeviceOrdinals())]
@@ -512,7 +532,8 @@ def lower_fun(fun, instantiate=False, initial_style=False):
     pvals = [pe.PartialVal((a, core.unit)) for a in avals]
     jaxpr, _, consts = pe.trace_to_jaxpr(
         lu.wrap_init(fun, new_params), pvals, instantiate=True)
-    built_c = jaxpr_computation(jaxpr, backend, axis_env, consts, (), xla_shapes)
+    built_c = jaxpr_computation(jaxpr, backend, axis_env, consts, (),
+                                xla_shapes, inner=True)
     return c.Call(built_c, xla_args)
   return f
@@ -525,6 +546,15 @@ def _aval_from_xla_shape(xla_shape):
 ### device-persistent data
 
+class Token(object): pass
+token = Token()
+
+pytype_aval_mappings[Token] = lambda _: abstract_token
+xla_shape_handlers[AbstractToken] = lambda _: xc.Shape.token_shape()
+xla_result_handlers[AbstractToken] = lambda _: lambda _: token
+canonicalize_dtype_handlers[Token] = lambda x: x
+
 class DeviceValue(object):
   """A DeviceValue represents a value backed by device memory."""
   __slots__ = ["aval", "device_buffer", "__weakref__"]
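
Taken together, the four registrations above let the Python-level token sentinel flow through tracing and execution like any other value. A hedged sketch of what they imply (names as defined in jax.interpreters.xla by this diff):

from jax.interpreters import xla

# The sentinel's abstract value is the abstract_token singleton...
assert xla.pytype_aval_mappings[xla.Token](xla.token) is xla.abstract_token
# ...dtype canonicalization leaves the sentinel untouched...
assert xla.canonicalize_dtype_handlers[xla.Token](xla.token) is xla.token
# ...and a token output from a computation is handled back to the sentinel.
assert xla.xla_result_handlers[xla.AbstractToken](None)(None) is xla.token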

jax/lax/lax.py

@@ -37,7 +37,8 @@ from .. import linear_util as lu
 from ..config import flags
 from ..core import Primitive
 from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
-                               array_types, make_shaped_array, raise_to_shaped)
+                               AbstractToken, array_types, make_shaped_array,
+                               raise_to_shaped, abstract_token)
 from ..interpreters import partial_eval as pe
 from ..interpreters import xla
 from ..interpreters import pxla
@@ -47,6 +48,7 @@ from ..interpreters import masking
 from ..interpreters.masking import ShapeExpr, ShapeError
 from ..util import curry, cache, safe_zip, unzip2, prod
 from ..tree_util import build_tree, tree_unflatten, tree_map
+from ..lib import pytree
 from ..lib import xla_bridge
 from ..lib import xla_client
@@ -3708,9 +3710,9 @@ def _select_and_gather_add_translation(
       canonicalize_types=False)
 
   if double_word_reduction:
-  # XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
-  # we implement a pair-wise ReduceWindow by packing two k-bit values into
-  # 2k-bit unsigned integer using bit tricks.
+    # XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
+    # we implement a pair-wise ReduceWindow by packing two k-bit values into
+    # a 2k-bit unsigned integer using bit tricks.
     word_dtype = _UINT_DTYPES[nbits]
     double_word_dtype = _UINT_DTYPES[nbits * 2]
     word_type = xla_client.dtype_to_etype(word_dtype)
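
The bit trick the comment describes can be sketched in pure NumPy, independent of XLA (an illustration only, ignoring the sign handling the real translation rule must also do): for non-negative IEEE floats the unsigned bit patterns order the same way as the float values, so putting the value in the high word makes a single integer max select the (value, companion) pair together.

import numpy as onp

def pack(value, companion):
  # Value bits go in the high half so comparisons on the packed 64-bit word
  # are ordered primarily by `value`.
  v = onp.float32(value).view(onp.uint32).astype(onp.uint64)
  c = onp.float32(companion).view(onp.uint32).astype(onp.uint64)
  return (v << onp.uint64(32)) | c

def unpack(word):
  v = (word >> onp.uint64(32)).astype(onp.uint32).view(onp.float32)
  c = (word & onp.uint64(0xffffffff)).astype(onp.uint32).view(onp.float32)
  return v, c

assert unpack(max(pack(1.0, 10.0), pack(2.0, 20.0))) == (2.0, 20.0)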
@@ -4054,6 +4056,98 @@ xla.translations[stop_gradient_p] = lambda c, x: x
 ad.primitive_jvps[stop_gradient_p] = _stop_gradient_jvp_rule
 batching.primitive_batchers[stop_gradient_p] = _stop_gradient_batch_rule
 
+
+def create_token(x):
+  """Creates an XLA token value with no preconditions for sequencing effects.
+
+  Experimental.
+
+  Args:
+    x: a dummy argument used to tie the CreateToken operator into a trace. The
+       value of `x` is ignored.
+  """
+  # x is a dummy argument used to tie the operator into a trace.
+  return create_token_p.bind(x)
+
+create_token_p = Primitive("create_token")
+create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
+create_token_p.def_abstract_eval(lambda _: abstract_token)
+xla.translations[create_token_p] = lambda c, _: c.CreateToken()
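
A usage sketch for the new lax.create_token (jax.make_jaxpr stages the trace without touching a device, so this runs anywhere):

import jax
from jax import lax

# The argument only ties the create_token op into the trace; its value is
# ignored, per the docstring above.
print(jax.make_jaxpr(lambda x: lax.create_token(x))(0.))
# expected output, roughly: { lambda ; a. let b = create_token a in [b] }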
+
+def after_all(*operands):
+  """Merges one or more XLA token values. Experimental.
+
+  Wraps the XLA AfterAll operator."""
+  return after_all_p.bind(*operands)
+
+def _after_all_abstract_eval(*operands):
+  if any(x is not abstract_token for x in operands):
+    raise TypeError("Arguments to after_all must be tokens")
+  return abstract_token
+
+def _after_all_translation_rule(c, *operands):
+  return c.AfterAll(operands)
+
+after_all_p = Primitive("after_all")
+after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
+after_all_p.def_abstract_eval(_after_all_abstract_eval)
+xla.translations[after_all_p] = _after_all_translation_rule
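
And a sketch of merging tokens with after_all, again only staged with make_jaxpr:

import jax
from jax import lax

def merge(x):
  t1 = lax.create_token(x)
  t2 = lax.create_token(x)
  # A single token that is ordered after both t1 and t2.
  return lax.after_all(t1, t2)

print(jax.make_jaxpr(merge)(0.))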
+
+def infeed(token, shape=None):
+  """Consumes an infeed value of `shape` from the host. Experimental.
+
+  `token` is used to sequence infeed and outfeed effects.
+  """
+  flat_shapes, treedef = pytree.flatten(shape)
+  for flat_shape in flat_shapes:
+    if not isinstance(flat_shape, ShapedArray):
+      raise TypeError("shape argument to infeed must be a pytree of "
+                      "ShapedArray values, got {}".format(flat_shape))
+  xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes))
+  return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1])
+
+def _infeed_abstract_eval(token, shapes=None):
+  if token is not abstract_token:
+    raise TypeError("First argument to infeed must be a token")
+  return shapes + (abstract_token,)
+
+def _infeed_translation_rule(c, token, shapes=None):
+  shape = tuple(map(xla.aval_to_xla_shape, shapes))
+  xs_and_token = c.Infeed(xla_client.Shape.tuple_shape(shape), token)
+  xs = c.GetTupleElement(xs_and_token, 0)
+  token = c.GetTupleElement(xs_and_token, 1)
+  outs = [c.GetTupleElement(xs, i) for i in range(len(shapes))] + [token]
+  return c.Tuple(*outs)
+
+infeed_p = Primitive("infeed")
+infeed_p.multiple_results = True
+infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
+infeed_p.def_abstract_eval(_infeed_abstract_eval)
+xla.translations[infeed_p] = _infeed_translation_rule
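
A device-side usage sketch for infeed. Staging with make_jaxpr avoids blocking; actually executing f would block until the host enqueues a matching ((3,) float32,) tuple on the device's infeed queue (the host-side transfer call lives in the runtime, not in this diff):

import numpy as onp
import jax
from jax import lax
from jax.abstract_arrays import ShapedArray

def f(x):
  token = lax.create_token(x)
  # shape is a pytree of ShapedArrays; a one-element tuple here.
  (y,), token = lax.infeed(token, shape=(ShapedArray((3,), onp.float32),))
  return x + y

print(jax.make_jaxpr(f)(onp.zeros(3, onp.float32)))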
+
+def outfeed(token, xs):
+  """Outfeeds value `xs` to the host. Experimental.
+
+  `token` is used to sequence infeed and outfeed effects.
+  """
+  flat_xs, _ = pytree.flatten(xs)
+  return outfeed_p.bind(token, *flat_xs)
+
+def _outfeed_abstract_eval(token, *xs):
+  if token is not abstract_token:
+    raise TypeError("First argument to outfeed must be a token")
+  return abstract_token
+
+def _outfeed_translation_rule(c, token, *xs):
+  return c.Outfeed(c.Tuple(*xs), token)
+
+outfeed_p = Primitive("outfeed")
+outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
+outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
+xla.translations[outfeed_p] = _outfeed_translation_rule
+
 ### util
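
Finally, a matching producer-side sketch for the outfeed API above, staged the same way (reading the value back on the host again requires the runtime's outfeed-transfer facility, which this diff does not touch):

import numpy as onp
import jax
from jax import lax

def g(x):
  token = lax.create_token(x)
  # Enqueue x * 2 on the outfeed; the returned token sequences later effects.
  token = lax.outfeed(token, x * 2)
  return x

print(jax.make_jaxpr(g)(onp.ones(3, onp.float32)))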