Add experimental support for XLA infeed/outfeed.
This commit is contained in:
parent 361adfe482
commit b8a5473614
jax/abstract_arrays.py

@@ -164,6 +164,11 @@ class ConcreteArray(ShapedArray):
     return str(self.val)
 
 
+class AbstractToken(core.AbstractValue): pass
+
+abstract_token = AbstractToken()
+
+
 def make_shaped_array(x):
   dtype = xla_bridge.canonicalize_dtype(onp.result_type(x))
   return ShapedArray(onp.shape(x), dtype)
@@ -197,6 +202,8 @@ def raise_to_shaped(aval):
     return ShapedArray(aval.shape, aval.dtype)
   elif aval is core.abstract_unit:
     return core.abstract_unit
+  elif aval is abstract_token:
+    return abstract_token
   else:
     raise TypeError(type(aval))
 
jax/interpreters/pxla.py

@@ -208,7 +208,7 @@ def compile_replicated(jaxpr, backend, axis_name, axis_size, global_axis_size,
   axis_env = xla.AxisEnv(num_replicas, [axis_name], [global_axis_size], devices)
   arg_shapes = list(map(aval_to_xla_shape, abstract_args))
   built_c = xla.jaxpr_computation(jaxpr, backend, axis_env, consts, (), arg_shapes,
-                                  tuple_args=tuple_args)
+                                  tuple_args=tuple_args, inner=False)
   compiled = built_c.Compile(
       compile_options=xb.get_compile_options(num_replicas, device_assignment),
       backend=xb.get_backend(backend))
jax/interpreters/xla.py

@@ -31,8 +31,9 @@ from .. import core
 from .. import ad_util
 from .. import tree_util
 from .. import linear_util as lu
-from ..abstract_arrays import (ConcreteArray, ShapedArray, make_shaped_array,
-                               array_types, raise_to_shaped)
+from ..abstract_arrays import (ConcreteArray, ShapedArray, AbstractToken,
+                               make_shaped_array, array_types, raise_to_shaped,
+                               abstract_token)
 from ..core import valid_jaxtype, Literal
 from ..util import partial, partialmethod, cache, safe_map, prod, unzip2
 from ..lib import xla_bridge as xb
@@ -152,7 +153,7 @@ def primitive_computation(prim, *xla_shapes, **params):
           prim, (), params))
   ))
   platform = xb.get_backend(backend).platform
-  xla_args = map(c.ParameterWithShape, xla_shapes)
+  xla_args = (_parameter_or_create_token(c, shape) for shape in xla_shapes)
   if prim in backend_specific_translations[platform]:
     rule = backend_specific_translations[platform][prim]
     rule(c, *xla_args, **new_params)  # return val set as a side-effect on c
@@ -178,7 +179,8 @@ def primitive_computation(prim, *xla_shapes, **params):
 
 def _execute_compiled_primitive(prim, compiled, backend, result_handler, *args):
   device_num, = compiled.DeviceOrdinals()
-  input_bufs = [device_put(x, device_num, backend=backend) for x in args]
+  input_bufs = [device_put(x, device_num, backend=backend) for x in args
+                if x is not token]
   out_buf = compiled.Execute(input_bufs)
   if FLAGS.jax_debug_nans:
     check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)
@@ -202,7 +204,7 @@ def _check_nans(name, xla_shape, buf):
 
 ### compiling jaxprs
 
-def compile_jaxpr(jaxpr, device, backend, axis_env, const_vals, tuple_args,
+def _compile_jaxpr(jaxpr, device, backend, axis_env, const_vals, tuple_args,
                   *abstract_args):
   if axis_env.nreps > xb.device_count(backend):
     msg = ("compiling computation that requires {} replicas, but only {} XLA "
@@ -210,7 +212,7 @@ def compile_jaxpr(jaxpr, device, backend, axis_env, const_vals, tuple_args,
     raise ValueError(msg.format(axis_env.nreps, xb.device_count(backend)))
   arg_shapes = tuple(map(aval_to_xla_shape, abstract_args))
   built_c = jaxpr_computation(jaxpr, backend, axis_env, const_vals, (), arg_shapes,
-                              tuple_args=tuple_args)
+                              tuple_args=tuple_args, inner=False)
   device_assignment = (device.id,) if device else None
   compile_opts = xb.get_compile_options(num_replicas=axis_env.nreps,
                                         device_assignment=device_assignment)
@@ -220,7 +222,7 @@ def compile_jaxpr(jaxpr, device, backend, axis_env, const_vals, tuple_args,
 def build_jaxpr(jaxpr, backend, axis_env, const_vals, tuple_args, *abstract_args):
   arg_shapes = map(aval_to_xla_shape, abstract_args)
   return jaxpr_computation(jaxpr, backend, axis_env, const_vals, (), arg_shapes,
-                           tuple_args=tuple_args)
+                           tuple_args=tuple_args, inner=False)
 
 def prefetch(x):
   if isinstance(x, DeviceArray):
@@ -245,21 +247,37 @@ def eqn_literals(eqn):
     if type(v) is core.Literal:
       yield v.val
 
 
+def _parameter_or_create_token(c, shape):
+  # Token values cannot be passed as parameters to top-level computations.
+  # Instead, we manufacture a fresh token as an argument. However, tokens make
+  # sense and are allowed as arguments to inner computations.
+  if shape.xla_element_type() == xc.PrimitiveType.TOKEN:
+    return c.CreateToken()
+  else:
+    return c.ParameterWithShape(shape)
+
 def jaxpr_computation(jaxpr, backend, axis_env, const_vals, freevar_shapes,
-                      arg_shapes, tuple_args=False):
+                      arg_shapes, tuple_args=False, inner=False):
+  # If inner is True, tokens can be passed as parameters; if inner is False,
+  # token parameters become CreateToken instructions.
   c = xb.make_computation_builder("jaxpr_computation")  # TODO(mattjj): name
   _map(prefetch, it.chain(const_vals, jaxpr_literals(jaxpr)))
   consts = _map(c.Constant, const_vals)
   if tuple_args:
-    freevar_shapes, arg_shapes = list(freevar_shapes), list(arg_shapes)
-    tuple_shape = xc.Shape.tuple_shape(freevar_shapes + arg_shapes)
+    freevar_shapes = list(freevar_shapes)
+    arg_shapes = list(arg_shapes)
+    tuple_shape = xc.Shape.tuple_shape(
+        [s for s in it.chain(freevar_shapes, arg_shapes)
+         if s.xla_element_type() != xc.PrimitiveType.TOKEN])
     tuple_arg = c.ParameterWithShape(tuple_shape)
     nfreevars, nargs = len(freevar_shapes), len(arg_shapes)
     freevars = [c.GetTupleElement(tuple_arg, i) for i in range(nfreevars)]
     args = [c.GetTupleElement(tuple_arg, i + nfreevars) for i in range(nargs)]
   else:
-    freevars = _map(c.ParameterWithShape, freevar_shapes)
-    args = _map(c.ParameterWithShape, arg_shapes)
+    make_parameter = (c.ParameterWithShape if inner
+                      else partial(_parameter_or_create_token, c))
+    freevars = _map(make_parameter, freevar_shapes)
+    args = _map(make_parameter, arg_shapes)
   out_nodes = jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, freevars, *args)
   return c.Build(c.Tuple(*out_nodes))
@@ -409,8 +427,8 @@ def _xla_callable(fun, device, backend, *abstract_args):
     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
     assert not env  # no subtraces here (though cond might eventually need them)
   axis_env = AxisEnv(jaxpr_replicas(jaxpr), [], [])
-  compiled = compile_jaxpr(jaxpr, device, backend, axis_env, consts,
-                           tuple_args, *abstract_args)
+  compiled = _compile_jaxpr(jaxpr, device, backend, axis_env, consts,
+                            tuple_args, *abstract_args)
   del master, consts, jaxpr, env
   result_handlers = tuple(map(_pval_to_result_handler, pvals))
   if axis_env.nreps == 1:
@@ -427,7 +445,8 @@ def _pval_to_result_handler(pval):
 
 def _execute_compiled(compiled, backend, handlers, tuple_args, *args):
   device_num, = compiled.DeviceOrdinals()
-  input_bufs = [device_put(x, device_num, backend=backend) for x in args]
+  input_bufs = [device_put(x, device_num, backend=backend) for x in args
+                if x is not token]
   if tuple_args:
     input_bufs = [make_tuple(input_bufs, device_num, backend)]
   out_bufs = compiled.Execute(input_bufs).destructure()
@@ -435,8 +454,9 @@ def _execute_compiled(compiled, backend, handlers, tuple_args, *args):
   return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
 
 def _execute_replicated(compiled, backend, handlers, tuple_args, *args):
-  input_bufs = [[device_put(x, device_num, backend=backend) for x in args]
-                for device_num in compiled.DeviceOrdinals()]
+  input_bufs = [
+      [device_put(x, device_num, backend=backend) for x in args if x is not token]
+      for device_num in compiled.DeviceOrdinals()]
   if tuple_args:
     input_bufs = [[make_tuple(bufs, device_num)] for bufs, device_num in
                   zip(input_bufs, compiled.DeviceOrdinals())]
@@ -512,7 +532,8 @@ def lower_fun(fun, instantiate=False, initial_style=False):
     pvals = [pe.PartialVal((a, core.unit)) for a in avals]
     jaxpr, _, consts = pe.trace_to_jaxpr(
         lu.wrap_init(fun, new_params), pvals, instantiate=True)
-    built_c = jaxpr_computation(jaxpr, backend, axis_env, consts, (), xla_shapes)
+    built_c = jaxpr_computation(jaxpr, backend, axis_env, consts, (),
+                                xla_shapes, inner=True)
     return c.Call(built_c, xla_args)
   return f
 
@@ -525,6 +546,15 @@ def _aval_from_xla_shape(xla_shape):
 
 ### device-persistent data
 
+class Token(object): pass
+token = Token()
+
+pytype_aval_mappings[Token] = lambda _: abstract_token
+xla_shape_handlers[AbstractToken] = lambda _: xc.Shape.token_shape()
+xla_result_handlers[AbstractToken] = lambda _: lambda _: token
+canonicalize_dtype_handlers[Token] = lambda x: x
+
+
 class DeviceValue(object):
   """A DeviceValue represents a value backed by device memory."""
   __slots__ = ["aval", "device_buffer", "__weakref__"]
jax/lax/lax.py (102 changed lines)
@@ -37,7 +37,8 @@ from .. import linear_util as lu
 from ..config import flags
 from ..core import Primitive
 from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
-                               array_types, make_shaped_array, raise_to_shaped)
+                               AbstractToken, array_types, make_shaped_array,
+                               raise_to_shaped, abstract_token)
 from ..interpreters import partial_eval as pe
 from ..interpreters import xla
 from ..interpreters import pxla
@@ -47,6 +48,7 @@ from ..interpreters import masking
 from ..interpreters.masking import ShapeExpr, ShapeError
 from ..util import curry, cache, safe_zip, unzip2, prod
 from ..tree_util import build_tree, tree_unflatten, tree_map
+from ..lib import pytree
 from ..lib import xla_bridge
 from ..lib import xla_client
@@ -3708,9 +3710,9 @@ def _select_and_gather_add_translation(
                                 canonicalize_types=False)
 
   if double_word_reduction:
-  # XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
-  # we implement a pair-wise ReduceWindow by packing two k-bit values into
-  # 2k-bit unsigned integer using bit tricks.
+    # XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
+    # we implement a pair-wise ReduceWindow by packing two k-bit values into
+    # 2k-bit unsigned integer using bit tricks.
     word_dtype = _UINT_DTYPES[nbits]
     double_word_dtype = _UINT_DTYPES[nbits * 2]
     word_type = xla_client.dtype_to_etype(word_dtype)
@@ -4054,6 +4056,98 @@ xla.translations[stop_gradient_p] = lambda c, x: x
 ad.primitive_jvps[stop_gradient_p] = _stop_gradient_jvp_rule
 batching.primitive_batchers[stop_gradient_p] = _stop_gradient_batch_rule
 
+
+def create_token(x):
+  """Creates an XLA token value with no preconditions for sequencing effects.
+
+  Experimental.
+
+  Args:
+    x: a dummy argument used to tie the CreateToken operator into a trace. The
+      value of `x` is ignored.
+  """
+  # x is a dummy argument used to tie the operator into a trace.
+  return create_token_p.bind(x)
+
+create_token_p = Primitive("create_token")
+create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
+create_token_p.def_abstract_eval(lambda _: abstract_token)
+xla.translations[create_token_p] = lambda c, _: c.CreateToken()
+
+
+def after_all(*operands):
+  """Merges one or more XLA token values. Experimental.
+
+  Wraps the XLA AfterAll operator."""
+  return after_all_p.bind(*operands)
+
+def _after_all_abstract_eval(*operands):
+  if any(x is not abstract_token for x in operands):
+    raise TypeError("Arguments to after_all must be tokens")
+  return abstract_token
+
+
+def _after_all_translation_rule(c, *operands):
+  return c.AfterAll(operands)
+
+after_all_p = Primitive("after_all")
+after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
+after_all_p.def_abstract_eval(_after_all_abstract_eval)
+xla.translations[after_all_p] = _after_all_translation_rule
+
+
+def infeed(token, shape=None):
+  """Consumes an infeed value of `shape` from the host. Experimental.
+
+  `token` is used to sequence infeed and outfeed effects.
+  """
+  flat_shapes, treedef = pytree.flatten(shape)
+  for flat_shape in flat_shapes:
+    if not isinstance(flat_shape, ShapedArray):
+      raise TypeError("shape argument to infeed must be a pytree of "
+                      "ShapedArray values, got {}".format(shape))
+  xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes))
+  return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1])
+
+def _infeed_abstract_eval(token, shapes=None):
+  if token is not abstract_token:
+    raise TypeError("First argument to infeed must be a token")
+  return shapes + (abstract_token,)
+
+
+def _infeed_translation_rule(c, token, shapes=None):
+  shape = tuple(map(xla.aval_to_xla_shape, shapes))
+  xs_and_token = c.Infeed(xla_client.Shape.tuple_shape(shape), token)
+  xs = c.GetTupleElement(xs_and_token, 0)
+  token = c.GetTupleElement(xs_and_token, 1)
+  outs = [c.GetTupleElement(xs, i) for i in range(len(shapes))] + [token]
+  return c.Tuple(*outs)
+
+infeed_p = Primitive("infeed")
+infeed_p.multiple_results = True
+infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
+infeed_p.def_abstract_eval(_infeed_abstract_eval)
+xla.translations[infeed_p] = _infeed_translation_rule
+
+
+def outfeed(token, xs):
+  """Outfeeds value `xs` to the host. Experimental.
+
+  `token` is used to sequence infeed and outfeed effects.
+  """
+  flat_xs, _ = pytree.flatten(xs)
+  return outfeed_p.bind(token, *flat_xs)
+
+def _outfeed_abstract_eval(token, *xs):
+  if token is not abstract_token:
+    raise TypeError("First argument to outfeed must be a token")
+  return abstract_token
+
+
+def _outfeed_translation_rule(c, token, *xs):
+  return c.Outfeed(c.Tuple(*xs), token)
+
+outfeed_p = Primitive("outfeed")
+outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
+outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
+xla.translations[outfeed_p] = _outfeed_translation_rule
+
+
 ### util
 
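Editor's usage notes (not part of the commit). The xla.py pieces compose into a simple calling convention: token values never cross the host/device boundary as real buffers. At the top level, _parameter_or_create_token turns a would-be token parameter into a CreateToken instruction, and the `if x is not token` filters drop the host-side sentinel before execution; only inner computations (built with inner=True) declare real token parameters. Below is a standalone toy model of that convention; TOKEN, host_input_buffers, and make_parameter are purely illustrative names, not JAX APIs.

TOKEN = object()  # stands in for the `token` sentinel added in xla.py

def host_input_buffers(args):
  # Mirrors the `if x is not token` filters in _execute_compiled and friends:
  # the sentinel is dropped rather than transferred to the device.
  return [a for a in args if a is not TOKEN]

def make_parameter(shape, inner):
  # Mirrors _parameter_or_create_token: a top-level computation (inner=False)
  # conjures a fresh token instead of declaring a token-typed parameter.
  if shape == "token" and not inner:
    return "CreateToken()"
  return "Parameter({})".format(shape)

assert host_input_buffers([TOKEN, 1, 2]) == [1, 2]
assert make_parameter("token", inner=False) == "CreateToken()"
assert make_parameter("token", inner=True) == "Parameter(token)"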
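A minimal sketch of the create_token API added above, assuming this era's lax module layout; jax.make_jaxpr only stages out the trace, so nothing executes on a device.

import jax
import jax.numpy as jnp
from jax import lax

# The dummy argument x exists only to tie the CreateToken op into the trace.
print(jax.make_jaxpr(lambda x: lax.create_token(x))(jnp.float32(0.)))
# prints roughly: { lambda ; a. let b = create_token a in (b,) }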
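A sketch of after_all merging two independent token chains, so any effect sequenced on the result is ordered after both; again this assumes the API added above.

import jax
from jax import lax

@jax.jit
def merged_token(x, y):
  t1 = lax.create_token(x)
  t2 = lax.create_token(y)
  return lax.after_all(t1, t2)  # one token that depends on both chains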
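A sketch of the infeed calling convention. Actually invoking it blocks until the host feeds the device's infeed queue through backend-specific transfer calls, so this only shows how the token and the pytree of ShapedArray shapes are threaded; the ShapedArray import path is this era's.

import jax
import jax.numpy as jnp
from jax import lax
from jax.abstract_arrays import ShapedArray

@jax.jit
def pull_one(x):
  token = lax.create_token(x)
  # infeed returns (unflattened values, new token)
  y, token = lax.infeed(token, shape=ShapedArray((3,), jnp.float32))
  return x + y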
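And a matching outfeed sketch: the token orders the transfer relative to other effects, and the host must drain the outfeed queue (again via backend-specific calls) for the program to make progress.

import jax
from jax import lax

@jax.jit
def push_double(x):
  token = lax.create_token(x)
  token = lax.outfeed(token, x * 2)  # effect sequenced by the token
  return x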