rocm_jax/jax/experimental/host_callback.py

798 lines
31 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Primitives for calling user-defined Python functions
on the host, even from compiled and transformed code.
**Experimental: please give feedback, and expect changes.**
Tapping works even for code executed on accelerators and
even for code under JAX transformations. A few examples::
# calls func(2x) and returns 2x
y = id_tap(func, x * 2)
# calls func((2x, 3x)) and returns (2x, 3x)
y, z = id_tap(func, (x * 2, x * 3)) # The argument can be a pytree
# calls func(2x) and returns y
y = id_tap(func, x * 2, result=y)
# calls func(2x, what='x') and returns 2x
y = id_tap(func, x * 2, what='x')
# calls func(dict(x=x, y=y), what='a dict') and returns dict(x=x, y=y)
d = id_tap(func, dict(x=x, y=y), what='a dict')
The order of execution is by data dependency: after all the arguments are
computed and before the result is used. At least one of the returned values
must be used in the rest of the computation, or else this operation has no effect.
*At the moment*, in order to use the callback primitives in compiled code,
one must wrap any invocation with an :func:`outfeed_receiver` (an exception is
raised otherwise)::
with outfeed_receiver():
...calls to compiled code that may invoke callback primitives...
The printing and the tap functions execute in separate threads that are started
by :func:`outfeed_receiver`. Exceptions from the primitives are printed along with
the traceback, but the execution does not stop until the body of the
:func:`outfeed_receiver` terminates and all the outfeeds are received.
At that point, a ``TapFunctionException`` is raised.
We describe the behaviour under transformations in the context of the
following function definition::
  def power3(x):
    y = x * x
    _, y = id_print((x, y), what="x,x^2")
    return y * x
For :func:`jax.vmap` the arguments are batched, ``batch`` is appended
to ``transforms``, and ``batch_dims`` is added to specify the tuple of
batched dimensions::
jax.vmap(power3)(np.arange(3.))
# what=x,x^2 transforms=batch batch_dims=(0, 0): ([0, 1, 2], [0, 1, 4])
For :func:`jax.jvp` there will be two callbacks, one with the values of the primals and
one with the tangents::
jax.jvp(power3, (3.,), (0.1,))
# what=x,x^2: (3., 9.)
# what=x,x^2 transforms=jvp: (0.1, 0.6)
For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the values of
the adjoints for the arguments. You may also see a callback with the values of
the primals from the forward pass, if those values are needed for the
backward pass::
jax.grad(power3)(3.)
# what=x,x^2: (3., 9.) # from forward pass, since y is needed in backward pass
# what=x,x^2 transforms=(jvp, transpose): (0., 3.) # from backward pass, adjoints of _, y
See documentation for :func:`id_tap` and :func:`id_print`.
For usage example, see tests/host_callback_test.py.
Still to do:
* Performance tests.
* Add flags for logging.
* Add unit tests with mocks.
* Improve the XLA compilation code.
* Improve the ergonomics of starting the consumer loop. Currently, when
invoking jit-ed code, one must start a consumer loop. This is not needed
when invoking code that does not involve jit. There is an error when
attempting to start a compiled computation without starting the outfeed
receiver. Perhaps we can put the receiver threads in the runtime.
* Explore a simpler API that uses Python program-order, instead of
data dependency-order. Need to add support to JAX for stateful primitives.
* Explore implementation with outside compilation.
"""
from collections import defaultdict, namedtuple
from concurrent import futures
from contextlib import contextmanager
from functools import partial
import io
import itertools
from jax import abstract_arrays
from jax import ad_util
from jax import api
from jax import core
from jax import dtypes
from jax import lax
from jax.lib import pytree, xla_bridge
from jax.interpreters import ad, xla, batching, masking, pxla
from jax.interpreters import partial_eval as pe
from jax import pprint_util as ppu
from jax import util
from jaxlib import xla_client
from jaxlib import xla_extension
import logging
import msgpack # type: ignore
import numpy as onp
import sys
import traceback
from typing import Any, Callable, Dict, Iterable, List, Optional, NamedTuple, Sequence, Tuple
# Shorthand for the XLA op builders; used below to create tokens, tuples,
# constants and outfeed operations.
xops = xla_client._xla.ops
# TODO(necula): fix mypy errors if I define the type aliases below
# For now every alias is `Any`; the intended concrete type is named in the
# trailing comment of each line.
XlaOp = Any # xla_extension.XlaOp
XlaShape = Any # xla_client.Shape
XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder
XlaDevice = Any # xla_client.Device
# TODO: add a flag
# When true, the outfeed receiver loop logs every outfeed it receives.
_LOGGING = True
def id_tap(func: Callable, arg, *,
           result=None,
           **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  values of the argument.

  Args:
    * func: the tap function; it is invoked as ``func(arg, **kwargs)``.
    * arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    * result: if given, then specifies the return value of ``id_tap``. By
      default, the return value is ``arg``.
    * kwargs: will be passed directly to the tap function. Can be anything,
      these are kept in the host Python process.

  Returns:
    * the value of ``result`` or otherwise ``arg``

  Tapping works even for code executed on accelerators and
  even for code under JAX transformations.
  For more details see the `module documentation <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  # The two special consumer markers are plain ints, not callables.
  if func not in (_end_consumer, _unknown_consumer):
    api._check_callable(func)
  flat_args, arg_treedef = pytree.flatten(arg)
  api._check_args(flat_args)
  params = dict(kwargs)  # we pass a copy of params to the primitive
  # See definition of id_tap_p for what parameters it takes
  params["func"] = func
  params["arg_treedef"] = arg_treedef
  if result is not None:
    flat_results, result_treedef = pytree.flatten(result)
    api._check_args(flat_results)
    params["nr_untapped"] = len(flat_results)
    all_args = flat_args + flat_results
  else:
    all_args = flat_args
  flat_outs = id_tap_p.bind(*all_args, **params)  # Always a tuple of all args
  if result is not None:
    # The untapped results follow the tapped arguments. Slice from the front
    # rather than with a negative index: when `result` has no leaves,
    # `flat_outs[-0:]` would wrongly select every output.
    return result_treedef.unflatten(flat_outs[len(flat_args):])
  else:
    return arg_treedef.unflatten(flat_outs)
def id_print(arg, *, result=None, output_stream=None, threshold=None,
             **kwargs):
  """Like :func:`id_tap` with a printing tap function.

  **Experimental: please give feedback, and expect changes!**

  On each invocation of the printing tap, the ``kwargs`` if present
  will be printed first (sorted by keys). Then arg will be printed,
  with the arrays stringified with ``numpy.array2string``.

  See the :func:`id_tap` documentation.

  Additional keyword arguments:
    * ``output_stream`` if given then it will be used instead of the
      built-in ``print``. The string will be passed as
      ``output_stream.write(s)``.
    * ``threshold`` is passed to ``numpy.array2string``.
  """
  # The printer options go first, followed by the user-supplied kwargs.
  tap_kwargs = dict(output_stream=output_stream, threshold=threshold)
  tap_kwargs.update(kwargs)
  return id_tap(_print_consumer, arg, result=result, **tap_kwargs)
# A registry of outfeed consumers
class _ConsumerCallable(NamedTuple):
  """Host-side description of one outfeed consumer.

  Instances are used as keys of the consumer registry, so they are compared
  and hashed as a whole tuple.
  """
  # The Python function to invoke with the received data.
  func: Callable
  # The keyword arguments for `func`, stored as a tuple of (name, value) pairs.
  kwargs: Tuple[Tuple[str, Any], ...]
  # The treedef used to rebuild the tapped argument from the flat arrays.
  arg_treedef: Any
# Consumer -> consumer id, and the reverse mapping used when receiving.
_consumer_registry: Dict[_ConsumerCallable, int] = dict()
_consumer_registry_by_id: Dict[int, _ConsumerCallable] = dict()


def _register_consumer(cons: _ConsumerCallable) -> int:
  """Returns the id for a consumer, registering it on first use.

  Consumers that compare equal (same function, kwargs and treedef) share
  one id. A new consumer gets `id(cons)` as its id and is recorded in both
  registries.
  """
  try:
    return _consumer_registry[cons]
  except KeyError:
    pass
  new_id = id(cons)
  _consumer_registry[cons] = new_id
  _consumer_registry_by_id[new_id] = cons
  return new_id
def _print_consumer(arg, *, output_stream=None,
                    threshold=1024, **kwargs):
  """The consumer for id_print.

  Emits one line with the `kwargs` (sorted by key, skipping the internal
  "consumer_id" and "nr_untapped" entries) if there are any, then the
  pretty-printed `arg`. Output goes to `output_stream.write` when a stream
  is given, otherwise to `print`.
  """
  def emit_str(s: str):
    if output_stream is None:
      print(s)
    else:
      output_stream.write(s + "\n")

  internal_keys = ("consumer_id", "nr_untapped")
  kv_pairs = " ".join(f"{key}: {kwargs[key]}"
                      for key in sorted(kwargs)
                      if key not in internal_keys)
  if kv_pairs:
    emit_str(kv_pairs)

  def pp_val(val) -> ppu.PrettyPrint:
    # Sequences and dicts are rendered recursively, one element per line.
    if isinstance(val, (tuple, list)):
      elements = ppu.vcat([pp_val(e) for e in val])
      return ppu.pp('[ ') >> elements >> ppu.pp(' ]')
    if isinstance(val, dict):
      entries = ppu.vcat([ppu.pp(f"{k}=") >> pp_val(v)
                          for k, v in sorted(val.items())])
      return ppu.pp('{ ') >> entries >> ppu.pp(' }')
    if isinstance(val, onp.ndarray):
      return ppu.pp(onp.array2string(val, threshold=threshold))
    return ppu.pp(str(val))

  emit_str(str(pp_val(arg)))
"""The id_tap primitive acts like the identity function. It has a number of
positional arguments and parameters:
* func: the actual (Python) function to invoke with the positional arguments
and the parameters.
* nr_untapped: how many positional arguments (from the tail) should not be
passed to the tap function.
* arg_treedef: the treedef of the tapped positional arguments
* transforms: a tuple of the transformations that have been applied.
* batch_dims: a tuple of the dims that have been batched, for vmap
* logical_shapes: a tuple of evaluated logical shapes, for mask
* the remaining parameters are passed to the tap function.
"""
# TODO: handle multiple vmap and mask
# The primitive returns all of its positional arguments, hence it is marked
# as having multiple results.
id_tap_p = core.Primitive("id_tap")
id_tap_p.multiple_results = True
def _add_transform_name(params: Dict, transform: str) -> Dict:
  """Returns a copy of `params` with `transform` appended to "transforms".

  The input dictionary is not modified. A missing "transforms" entry is
  treated as an empty tuple.
  """
  updated = dict(params)
  updated["transforms"] = params.get("transforms", ()) + (transform,)
  return updated
def _id_tap_impl(*arrays, func=None, nr_untapped=0, arg_treedef=None,
                 **params):
  """Implementation rule for id_tap: calls `func` and returns all `arrays`.

  The last `nr_untapped` arrays are not passed to `func`; the remaining ones
  are unflattened with `arg_treedef` into the tap argument. Any exception
  from that work is re-raised as a TapFunctionException chained to the
  original one.
  """
  assert isinstance(func, Callable)
  tap_kwargs = dict(params)
  # TODO: consolidate logic with the outfeed receiver
  try:
    assert nr_untapped <= len(arrays)
    if nr_untapped > 0:
      tapped = arrays[:-nr_untapped]
    else:
      tapped = arrays
    func(api.tree_unflatten(arg_treedef, tapped), **tap_kwargs)
  except Exception as e:
    raise TapFunctionException from e
  return arrays  # return all


id_tap_p.def_impl(_id_tap_impl)
def _id_tap_abstract_eval(*args_a: pe.AbstractValue,
                          **params) -> Sequence[pe.AbstractValue]:
  """Abstract evaluation rule: the outputs have the avals of the inputs."""
  return args_a


id_tap_p.def_abstract_eval(_id_tap_abstract_eval)
# TODO: there must be a better way to do this.
def _instantiate_zeros(arg, tan):
  """Turn special ad.zero tangents into arrays of 0s.

  Tangents other than `ad.zero` are returned unchanged.
  """
  if tan is ad.zero:
    try:
      return ad.instantiate_zeros_aval(arg.aval, tan)
    except (AttributeError, KeyError):
      # The AttributeError is for regular Python values, the KeyError is for
      # ConcreteArray.
      return ad.zeros_like_jaxval(arg)
  return tan
def _id_tap_jvp_rule(primals, tangents, *, func, nr_untapped=0, **params):
  """JVP rule: taps the primals and the tangents with two separate binds.

  The tangent tap has "jvp" appended to its transforms.
  """
  # Put primals through id_tap separately, so that partial evaluation
  # can do its job for grad
  primal_outs = id_tap_p.bind(*primals, func=func, nr_untapped=nr_untapped,
                              **params)
  tangent_inputs = tuple(_instantiate_zeros(p, t)
                         for p, t in zip(primals, tangents))
  # Add one primal output as untapped, to create dependency.
  tangent_params = _add_transform_name(params, "jvp")
  tangent_outs = id_tap_p.bind(*tangent_inputs, primal_outs[0],
                               func=func, nr_untapped=nr_untapped + 1,
                               **tangent_params)
  # Drop the extra primal that was threaded through the tangent tap.
  return tuple(primal_outs), tuple(tangent_outs[:-1])


ad.primitive_jvps[id_tap_p] = _id_tap_jvp_rule
def _id_tap_transpose_rule(cts, *args, func=None, nr_untapped=0, **params):
  """Transpose rule: taps the cotangents, with "transpose" appended to transforms."""
  assert len(cts) == len(args)
  instantiated_cts = tuple(_instantiate_zeros(a, ct)
                           for a, ct in zip(args, cts))
  transposed_params = _add_transform_name(params, "transpose")
  return id_tap_p.bind(*instantiated_cts, func=func, nr_untapped=nr_untapped,
                       **transposed_params)


ad.primitive_transposes[id_tap_p] = _id_tap_transpose_rule
def _id_tap_batching_rule(batched_args, batch_dims, **params):
  """Batching rule: taps the batched values.

  "batch" is appended to the transforms and the `batch_dims` are recorded as
  a parameter. The outputs keep the batch dimensions of the inputs.
  """
  batched_params = dict(_add_transform_name(params, "batch"),
                        batch_dims=batch_dims)
  outs = id_tap_p.bind(*batched_args, **batched_params)
  return outs, batch_dims


batching.primitive_batchers[id_tap_p] = _id_tap_batching_rule
# def _id_tap_shape_rule(*operands, **params):
# return tuple([op.shape for op in operands])
# TODO: these disappeared
# masking.shape_rules[id_tap_p] = _id_tap_shape_rule # type: ignore[module-attr]
def _id_tap_masking_rule(operands, operands_logical_shapes, **params):
  """Masking rule: taps the operands.

  "mask" is appended to the transforms and the logical shapes are recorded
  as the `logical_shapes` parameter.
  """
  masked_params = dict(_add_transform_name(params, "mask"),
                       logical_shapes=operands_logical_shapes)
  return id_tap_p.bind(*operands, **masked_params)


masking.masking_rules[id_tap_p] = _id_tap_masking_rule
#### XLA compilation ####
####
# Special consumer to mark the end of outfeed stream for a device.
# These markers are passed to id_tap in place of a tap function and are also
# used directly as the consumer id sent on the outfeed.
_end_consumer = 0
_unknown_consumer = 1 # for testing error cases
def _id_print_translation_rule_outfeed(comp: XlaComputationBuilder,
                                       *args_op: XlaOp, func=None,
                                       nr_untapped=0, arg_treedef=None,
                                       **params):
  """XLA translation rule for id_tap: emits the tapped operands to the outfeed.

  The last operand is expected to be the current token. When it is an array
  instead, a fresh token is created and all operands are treated as data.
  Returns a tuple of the operands, with the token (if there was one)
  replaced by the token produced by the outfeed.
  """
  params = dict(params)
  if func is _end_consumer:
    consumer_id = _end_consumer
  elif func is _unknown_consumer:
    consumer_id = _unknown_consumer  # Will trigger an error, for testing
  else:
    consumer_id = _register_consumer(
        _ConsumerCallable(func, tuple(params.items()), arg_treedef))
  params["consumer_id"] = consumer_id

  # We expect the current token at the end
  current_token = args_op[-1]
  # TODO: the last operand is an array when we partially eval some
  # primitives: we impl them with JIT, but we did not rewrite them
  has_token = not comp.GetShape(current_token).is_array()
  if not has_token:
    current_token = xops.CreateToken(comp)

  nr_args_to_emit = len(args_op) - nr_untapped
  if has_token:
    nr_args_to_emit -= 1
  next_token = _emit_outfeed(comp, current_token,
                             args_op[0:nr_args_to_emit], params["consumer_id"])
  if has_token:
    results = args_op[:-1] + (next_token,)
  else:
    results = args_op
  return xops.Tuple(comp, results)


xla.translations[id_tap_p] = _id_print_translation_rule_outfeed
# The data on the outfeed follows a protocol that allows multiplexing the
# outfeed among multiple consumers, and communicates in-stream shape and
# type of the data.
# Each batch of array data is preceded by a header message, of type
# uint32[_OUTFEED_HEADER_LENGTH]:
# [0]: special header value 2178
# [1, 2]: a consumer id (64-bits, big-endian encoding as uint32[2]). The
# consumer id encodes the tap function (by id), the
# descriptor of the arrays to be outfed, and the kwargs (a sorted tuple
# of keys and values).
# [3]: the metadata length in bytes. The metadata is a msgpack-encoded value of type:
# [ (type_code, (d0, d1, ...)), ...] # for each array, element type code
# # and the dimensions.
# padded with 0s to _OUTFEED_HEADER_LENGTH
#
#
_OUTFEED_HEADER_LENGTH = 32 # In uint32 words
_OUTFEED_HEADER_START = 2178 # [0]
# consumer_id [1, 2]
# metadata_length in bytes [3]
# Maximum metadata size in bytes: every header word after the first four,
# at 4 bytes per word.
_OUTFEED_HEADER_METADATA_LENGTH = 4 * (_OUTFEED_HEADER_LENGTH - 4)
# Element type codes used in the outfeed metadata, and the dtype each stands for.
_CODE_TO_DTYPE = {
    0: onp.dtype(onp.int8),
    1: onp.dtype(onp.int16),
    2: onp.dtype(onp.int32),
    3: onp.dtype(onp.int64),
    4: onp.dtype(onp.uint8),
    5: onp.dtype(onp.uint16),
    6: onp.dtype(onp.uint32),
    7: onp.dtype(onp.uint64),
    8: onp.dtype(onp.float16),
    9: onp.dtype(onp.float32),
    10: onp.dtype(onp.float64),
    11: onp.dtype(dtypes.bfloat16),
}
# Reverse mapping, keyed by the string form of the dtype.
_DTYPE_STR_TO_CODE = {str(dtype): code
                      for code, dtype in _CODE_TO_DTYPE.items()}
def _emit_outfeed(comp: XlaComputationBuilder, token: XlaOp,
                  arrays: Sequence[XlaOp], consumer_id: int) -> XlaOp:
  """Emits the arrays to the outfeed for the current device.

  Two outfeed operations are emitted: first a uint32 header carrying the
  start marker, the consumer id split into two 32-bit words, the metadata
  length and the msgpack-encoded (type code, dimensions) of every array,
  then a tuple with the arrays themselves.

  Returns the token produced by the last outfeed operation.
  Raises ValueError if the encoded metadata does not fit in the header.
  """
  arrays_shape = [comp.GetShape(a) for a in arrays]
  # One (element_type_code, (d0, d1, ..., dn)) descriptor per array.
  descriptors = tuple(
      (_DTYPE_STR_TO_CODE[str(onp.dtype(shape.element_type()))],
       shape.dimensions())
      for shape in arrays_shape)
  metadata = msgpack.dumps(descriptors)
  metadata_len = len(metadata)
  if metadata_len > _OUTFEED_HEADER_METADATA_LENGTH:
    # TODO: configurable
    raise ValueError("Outfeed metadata too long")
  metadata += b" " * (-metadata_len % 4)  # pad to a whole number of words

  header = [_OUTFEED_HEADER_START,
            (consumer_id >> 32) & 0xffffffff,
            consumer_id & 0xffffffff,
            metadata_len]
  # Slices past the end of the metadata are empty and encode as 0.
  header.extend(int.from_bytes(metadata[i:i + 4], byteorder="big")
                for i in range(0, _OUTFEED_HEADER_METADATA_LENGTH, 4))
  header.extend([0] * (_OUTFEED_HEADER_LENGTH - len(header)))

  data = xops.ConstantLiteral(comp, onp.array(header, dtype=onp.uint32))
  token = xops.OutfeedWithToken(data, token, comp.GetShape(data))
  # Now send the arrays
  entire_shape = xla_client.Shape.tuple_shape(arrays_shape)
  return xops.OutfeedWithToken(xops.Tuple(comp, arrays), token, entire_shape)
def _receive_outfeed(device: XlaDevice, receiver_name: str
                     ) -> Tuple[int, List]:
  """Receives a set of arrays on the outfeed for the specified device.

  First reads the fixed-size header, then decodes from it the consumer id
  and the shapes of the arrays, and finally reads the arrays themselves.

  Args:
    device: the device whose outfeed is read.
    receiver_name: a name used for debugging and logging

  Returns: a tuple with the consumer_id and the arrays received.

  Raises:
    ValueError: if the header does not start with the expected marker, or
      if it declares more metadata than fits in a header.
  """
  header_shape = xla_client.Shape.array_shape(onp.dtype(onp.uint32),
                                              (_OUTFEED_HEADER_LENGTH,))

  def _get_data(data_shape: XlaShape, device: XlaDevice) -> Any:
    return xla_client.transfer_from_outfeed(data_shape, device)

  header = _get_data(header_shape, device)
  if header[0] != _OUTFEED_HEADER_START:
    raise ValueError(f"Read unexpected outfeed header {header[0]} [{receiver_name}]")
  logging.info(f"[{receiver_name}:{device}] Outfeed read header: {header}")
  # Convert to Python ints before shifting: the header words are 32-bit and
  # the high word of the consumer id must not be lost to fixed-width
  # arithmetic.
  consumer_id = (int(header[1]) << 32) | int(header[2])
  metadata_length = int(header[3])
  if metadata_length > _OUTFEED_HEADER_METADATA_LENGTH:
    raise ValueError(
        f"Outfeed metadata length {metadata_length} exceeds the maximum "
        f"{_OUTFEED_HEADER_METADATA_LENGTH} [{receiver_name}]")
  # The metadata starts at word 4, stored as big-endian 32-bit words.
  metadatas = [int(header[i]).to_bytes(4, byteorder="big")
               for i in range(4, 4 + (metadata_length + 3) // 4)]
  metadata = b"".join(metadatas)[:metadata_length]  # drop the padding
  array_descriptors = msgpack.unpackb(metadata)
  arrays_shape = [xla_client.Shape.array_shape(_CODE_TO_DTYPE[a_descr[0]],
                                               a_descr[1])
                  for a_descr in array_descriptors]
  entire_shape = xla_client.Shape.tuple_shape(arrays_shape)
  arrays = _get_data(entire_shape, device)
  logging.info(f"[{receiver_name}:{device}] Outfeed read data of shape " +
               ",".join([f"{data.dtype}{data.shape}" for data in arrays]))
  return (consumer_id, arrays)
class TapFunctionException(Exception):
  """Signals that some tap function had exceptions.

  Raised by :func:`outfeed_receiver`, and by the id_tap implementation rule
  when the tap function fails.
  """
# True while the body of an `outfeed_receiver` context manager is active.
_outfeed_receiver_started = False


@contextmanager
def outfeed_receiver(*,
                     timeout_sec=10,
                     backends: Optional[Sequence[str]] = None,
                     devices: Optional[Sequence[XlaDevice]] = None,
                     receiver_name=""):
  """Starts receivers for the :func:`id_tap` outfeed from several devices.

  The receivers will run in a threadpool. The tapped functions will be invoked
  in those threads. If a tap function raises an exception, an error is
  logged, but the receiving continues until the body of the context manager
  terminates and all outfeeds from all devices have been received. Only then
  will a :exc:`TapFunctionException` be raised.

  Args:
    timeout_sec: (optional) how long to wait, after the body terminates, for
      all the receivers to finish.
    backends: (optional) sequence of backend names for which to listen.
      Will listen to all devices on those backends. By default, listen to
      all devices on all known backends.
    devices: (optional) sequence of devices to listen to. At most one
      of `backends` or `devices` must be given.
    receiver_name: (optional) a name to use with debug logging

  Raises:
    ValueError: if another receiver is already running, or if both
      `backends` and `devices` are given.
    TapFunctionException: on exit, if any tap function raised or an outfeed
      was received for an unknown consumer.

  Usage::

    with outfeed_receiver():
      jax.jit(func)(args)
      ...
      jax.pmap(another_func)(args)

  The ``outfeed_receiver`` must be started outside any jitted computation.
  """
  # TODO: better timeout management.
  global _outfeed_receiver_started
  if _outfeed_receiver_started:
    raise ValueError("At most one outfeed_receiver can be running at once.")
  if not devices:
    backends = backends or xla_client._get_local_backends().keys()
    devices = tuple(itertools.chain(*[api.devices(backend)
                                      for backend in backends]))
  elif backends:
    raise ValueError("At most one of `devices` or `backends` must be given.")

  executor = futures.ThreadPoolExecutor(
      thread_name_prefix=f"outfeed_receiver_{receiver_name}",
      max_workers=len(devices))
  count_tap_exceptions = 0

  def device_receiver_loop(device: XlaDevice) -> XlaDevice:
    """Polls the outfeed for a device in a loop."""
    nonlocal count_tap_exceptions
    while True:
      consumer_id, arrays = _receive_outfeed(device, receiver_name)
      if _LOGGING:
        logging.info(f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} " +
                     (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
      if consumer_id == _end_consumer:
        assert not arrays
        if _LOGGING:
          logging.info(f"[{receiver_name}:{device}] Outfeed received END_OUTFEED")
        return device
      consumer = _consumer_registry_by_id.get(consumer_id)
      if consumer is None:
        logging.error("Ignoring received outfeed for unknown tap consumer")
        count_tap_exceptions += 1
        continue  # We need to read the entire outfeed
      try:
        arg = api.tree_unflatten(consumer.arg_treedef, arrays)
        consumer.func(arg, **dict(consumer.kwargs))  # type: ignore[attribute-error]
      except Exception as e:
        logging.error(f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}")
        count_tap_exceptions += 1
        # We continue for now, we need to keep reading the outfeed

  # Mark the receiver as running so that a second one is rejected; the flag
  # is cleared again in the outer `finally` below.
  _outfeed_receiver_started = True
  try:
    receiver_futures = [executor.submit(device_receiver_loop, d)
                        for d in devices]
    # Register a callback to raise errors if any. These exception come from
    # bugs in our code, not from the tap functions.
    for rf in receiver_futures:
      rf.add_done_callback(lambda rf: rf.result())
    xla.can_execute_outfeed_computations = True
    try:
      yield
    finally:
      for d in devices:  # Signal the end of printing
        api.jit(lambda x: id_tap(_end_consumer, None, result=x), device=d)(0)  # type: ignore[arg-type]
      xla.can_execute_outfeed_computations = False
      for f in futures.as_completed(receiver_futures, timeout=timeout_sec):
        finished_device = f.result()  # Throw exceptions here
        if _LOGGING:
          logging.info(f"[{receiver_name}:{finished_device}] Outfeed receiver finished")
      if count_tap_exceptions > 0:
        raise TapFunctionException
  finally:
    # Always restore the global state, even if shutting down the receivers
    # failed or timed out.
    xla.can_execute_outfeed_computations = False
    _outfeed_receiver_started = False
    # Do not wait here: after a timeout a receiver thread may still be
    # blocked reading the outfeed.
    executor.shutdown(wait=False)
#### Jaxpr rewriting logic
####
def _jaxpr_var_defs(jaxpr: core.Jaxpr) -> Iterable[int]:
  """Yields the `count` of every variable defined at the top-level of a Jaxpr.

  The input variables come first, then the constant variables, then the
  output variables of each equation in order.
  """
  yield from (v.count for v in jaxpr.invars)
  yield from (v.count for v in jaxpr.constvars)
  yield from (ov.count for eqn in jaxpr.eqns for ov in eqn.outvars)
def _jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
  """Finds if there are outfeed primitives anywhere inside a Jaxpr.

  Checks the equations of `jaxpr` itself first, then recurses into the
  sub-jaxprs reported by `core.subjaxprs`.
  """
  if any(eqn.primitive is id_tap_p for eqn in jaxpr.eqns):
    return True
  return any(_jaxpr_uses_outfeed(sub) for sub in core.subjaxprs(jaxpr))
def _rewrite_typed_jaxpr(tjaxpr: core.TypedJaxpr,
                         has_input_token: bool,
                         has_output_token: bool) -> Tuple[core.TypedJaxpr, bool]:
  """Rewrites a TypedJaxpr to thread the token, if needed.

  The literals are kept; the input and output avals are recomputed from the
  variables of the rewritten Jaxpr.

  Returns the rewritten Jaxpr, and whether it uses outfeed.
  """
  rewritten, uses_outfeed = _rewrite_jaxpr(tjaxpr.jaxpr, has_input_token,
                                           has_output_token)
  in_avals = tuple(v.aval for v in rewritten.invars)
  out_avals = tuple(v.aval for v in rewritten.outvars)
  typed = core.TypedJaxpr(rewritten, tjaxpr.literals, in_avals, out_avals)
  return (typed, uses_outfeed)
def _rewrite_jaxpr(jaxpr: core.Jaxpr,
                   has_input_token: bool,
                   has_output_token: bool) -> Tuple[core.Jaxpr, bool]:
  """Rewrite a Jaxpr to thread the token, if needed.

  Every id_tap equation, and every while/cond/scan/xla_call equation whose
  body uses id_tap, gets the current token as an extra input and produces a
  new token as an extra output.

  Args:
    jaxpr: the Jaxpr to rewrite.
    has_input_token: whether the rewritten Jaxpr takes the token as an
      additional last input. If not, a token is created by the first
      equation.
    has_output_token: whether the rewritten Jaxpr returns the last token as
      an additional last output. Requires `has_input_token`.

  Returns:
    a pair of the Jaxpr and a boolean. The original Jaxpr and False are
    returned when there is no input token and no outfeed is used; otherwise
    a new Jaxpr and True.

  Raises:
    NotImplementedError: for outfeed in the conditional of a while, and for
      pmap.
  """
  assert has_input_token or not has_output_token

  uses_outfeed = _jaxpr_uses_outfeed(jaxpr)
  if not has_input_token and not uses_outfeed:
    return (jaxpr, False)

  # New variables are numbered after the highest existing count. The default
  # covers a Jaxpr that defines no variables at all, for which `max` would
  # otherwise raise ValueError.
  max_var_count = max(_jaxpr_var_defs(jaxpr), default=-1)
  mk_new_id = itertools.count(start=max_var_count + 1)

  def mk_new_var(aval: core.AbstractValue) -> core.Var:
    return core.Var(next(mk_new_id), '', aval)

  eqns: List[core.JaxprEqn] = []
  last_token_var = mk_new_var(core.abstract_token)
  if has_input_token:
    invars = jaxpr.invars + [last_token_var]
  else:
    invars = jaxpr.invars
    # NOTE(review): this uses the first input variable as the operand of
    # create_token, so it assumes the Jaxpr has at least one input.
    eqns.append(core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var],
                                   lax.create_token_p, {}))

  for eqn in jaxpr.eqns:
    if eqn.primitive is id_tap_p:
      new_token_var = mk_new_var(core.abstract_token)
      eqns.append(core.new_jaxpr_eqn(eqn.invars + [last_token_var],
                                     eqn.outvars + [new_token_var],
                                     eqn.primitive, eqn.params))
      last_token_var = new_token_var
    elif eqn.primitive is lax.while_p:
      cond_jaxpr, _cond_nconsts, body_jaxpr, _body_nconsts = util.split_dict(
          eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
      if _jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
        raise NotImplementedError("outfeed not supported in the conditional of a while")
      if not _jaxpr_uses_outfeed(body_jaxpr.jaxpr):
        eqns.append(eqn)
        continue
      new_token_var = mk_new_var(core.abstract_token)
      # The body threads the token through; the conditional only receives it.
      eqns.append(core.new_jaxpr_eqn(
          eqn.invars + [last_token_var],
          eqn.outvars + [new_token_var],
          eqn.primitive,
          dict(eqn.params,
               body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0],
               cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)[0])))
      last_token_var = new_token_var
    elif eqn.primitive is lax.cond_p:
      true_jaxpr, false_jaxpr, linear = util.split_dict(
          eqn.params, ["true_jaxpr", "false_jaxpr", "linear"])
      branches_use_outfeed = (_jaxpr_uses_outfeed(true_jaxpr.jaxpr) or
                              _jaxpr_uses_outfeed(false_jaxpr.jaxpr))
      if not branches_use_outfeed:
        eqns.append(eqn)
        continue
      nr_true_invars = len(true_jaxpr.jaxpr.invars)
      pred, true_invars, false_invars = util.split_list(eqn.invars,
                                                        [1, nr_true_invars])
      new_token_var = mk_new_var(core.abstract_token)
      # Both branches are rewritten to take and return the token.
      new_invars = pred + true_invars + [last_token_var] + false_invars + [last_token_var]
      eqns.append(core.new_jaxpr_eqn(
          new_invars, eqn.outvars + [new_token_var],
          eqn.primitive,
          dict(eqn.params,
               true_jaxpr=_rewrite_typed_jaxpr(true_jaxpr, True, True)[0],
               false_jaxpr=_rewrite_typed_jaxpr(false_jaxpr, True, True)[0],
               linear=linear + (False, False))))
      last_token_var = new_token_var
    elif eqn.primitive is lax.scan_p:
      num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
          eqn.params, ["num_consts", "num_carry", "jaxpr", "linear",
                       "reverse", "length"])
      if not _jaxpr_uses_outfeed(carry_jaxpr.jaxpr):
        eqns.append(eqn)
        continue
      nr_const_and_carry = num_consts + num_carry
      new_invars = (eqn.invars[0:nr_const_and_carry] + [last_token_var] +
                    eqn.invars[nr_const_and_carry:])
      new_token_var = mk_new_var(core.abstract_token)
      new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
      # The rewrite put the token carry at end, it has to be at end of carry
      new_jaxpr_invars = new_jaxpr.jaxpr.invars
      new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
                          [new_jaxpr_invars[-1]] +
                          new_jaxpr_invars[nr_const_and_carry:-1])
      new_jaxpr.jaxpr.invars = new_jaxpr_invars
      new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

      new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
      new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
                           [new_jaxpr_outvars[-1]] +
                           new_jaxpr_outvars[num_carry:-1])
      new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
      new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
      eqns.append(core.new_jaxpr_eqn(
          new_invars,
          # Output token is at the end of carry result
          eqn.outvars[0:num_carry] + [new_token_var] + eqn.outvars[num_carry:],
          eqn.primitive,
          dict(eqn.params,
               jaxpr=new_jaxpr,
               num_carry=num_carry + 1,
               linear=linear + (False,))))
      last_token_var = new_token_var
    elif eqn.primitive is xla.xla_call_p:
      call_jaxpr = eqn.params["call_jaxpr"]
      if not _jaxpr_uses_outfeed(call_jaxpr):
        eqns.append(eqn)
        continue
      new_token_var = mk_new_var(core.abstract_token)
      eqns.append(core.new_jaxpr_eqn(
          eqn.invars + [last_token_var],
          eqn.outvars + [new_token_var],
          eqn.primitive,
          dict(eqn.params, call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0])))
      last_token_var = new_token_var
    elif eqn.primitive is pxla.xla_pmap_p:
      raise NotImplementedError("rewrite of pmap")
    else:
      # Check no more subjaxprs
      for param in eqn.params.values():
        if type(param) is core.Jaxpr or type(param) is core.TypedJaxpr:
          assert False
      eqns.append(eqn)

  outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
  return (core.Jaxpr(jaxpr.constvars, invars, outvars, eqns), True)


xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)