mirror of
https://github.com/ROCm/jax.git
synced 2025-04-28 05:56:06 +00:00
901 lines
36 KiB
Python
901 lines
36 KiB
Python
# Copyright 2020 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Primitives for calling user-defined Python functions
|
|
on the host, even from compiled and transformed code.
|
|
|
|
**Experimental: please give feedback, and expect changes.**
|
|
|
|
Host callbacks work even for code executed on accelerators and
|
|
even for code under JAX transformations. A few examples::
|
|
|
|
# calls func(2x) and returns 2x
|
|
y = id_tap(func, x * 2)
|
|
# calls func((2x, 3x)) and returns (2x, 3x)
|
|
y, z = id_tap(func, (x * 2, x * 3)) # The argument can be a pytree
|
|
# calls func(2x) and returns y
|
|
y = id_tap(func, x * 2, result=y)
|
|
# calls func(2x, what='x') and returns 2x
|
|
y = id_tap(func, x * 2, what='x')
|
|
# calls func(dict(x=x, y=y), what='foo') and returns dict(x=x, y=y)
|
|
x, y = id_tap(func, dict(x=x, y=y), what='a dict')
|
|
|
|
The order of execution is by data dependency: after all the arguments are
|
|
computed and before the result is used. At least one of the returned values
|
|
must be used in the rest of the computation, or else this operation has no effect.
|
|
|
|
**At the moment**, in order to use the callback primitives one must wrap any
|
|
code that uses them with an :func:`outfeed_receiver` (an error is
|
|
raised otherwise)::
|
|
|
|
with outfeed_receiver():
|
|
id_print(x)
|
|
jax.jit(func_with_taps)(x)
|
|
|
|
The printing and the tap functions execute in separate threads that are started
|
|
by :func:`outfeed_receiver`. There is one thread per device. This ensures
|
|
that the outfeed from a certain device will come in order. Exceptions from
|
|
the user-define tap functions are printed along with the traceback, but the
|
|
outfeed listening does not stop until the body of the
|
|
:func:`outfeed_receiver` terminates and all the outfeeds are received.
|
|
At that point, a ``TapFunctionException`` is raised if there was an exception
|
|
in one of the tap functions.
|
|
|
|
**We intend to implement an alternative outfeed receiving mechanism that will
|
|
not require the user to start the `outfeed_receiver`.**
|
|
|
|
The current implementation uses the outfeed mechanism provided by XLA. The
|
|
mechanism itself is quite primitive in the sense that a receiver must know
|
|
exactly the shape of each incoming packet, and how many packets are expected.
|
|
This makes it hard to use for multiple kinds of data in the same computation,
|
|
and it is practically impossible to use it under conditionals.
|
|
|
|
The tapping API introduced here makes it easy to share the outfeed mechanism
|
|
for multiple purposes. Instead of using it directly, just use :func:`id_tap`.
|
|
|
|
We describe the behaviour under transformations in the context of the
|
|
following function definition::
|
|
|
|
def power3(x):
|
|
y = x * x
|
|
_, y = id_print(x, y, what="x,x^2")
|
|
return y * x
|
|
|
|
During JAX transformations the special parameters ``transforms`` is extended
|
|
with a dictionary, containing the key ``name`` holding the name of the
|
|
transformation and additional keys holding transformation parameters, if
|
|
applicable.
|
|
|
|
For :func:`jax.vmap` the arguments are batched, and ``transforms`` is extended
|
|
with transformation name ``batch`` and ``batch_dims`` set to the the tuple of
|
|
batched dimensions (one entry per argument, ``None`` denotes an argument that
|
|
was broadcast)::
|
|
|
|
jax.vmap(power3)(np.arange(3.))
|
|
# what=x,x^2 transforms=({name=batch, batch_dims=(0, 0)}): ([0, 1, 2], [0, 1, 4])
|
|
|
|
For :func:`jax.jvp` there will be two callbacks, one with the values of
|
|
the primals and one with the tangents::
|
|
|
|
jax.jvp(power3, (3.,), (0.1,))
|
|
# what=x,x^2: (3., 9.)
|
|
# what=x,x^2 transforms={name=jvp}: (0.1, 0.6)
|
|
|
|
For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the
|
|
values of the adjoints for the arguments. You may also see a callback with
|
|
the values of the primals from the forward pass, if those values are needed for
|
|
the backward pass::
|
|
|
|
jax.grad(power3)(3.)
|
|
# what=x,x^2: (3., 9.) # from forward pass, since y is needed in backward pass
|
|
# what=x,x^2 transforms=({name=jvp}, {name=transpose}): (0., 3.) # from backward pass, adjoints of _, y
|
|
|
|
See documentation for :func:`id_tap` and :func:`id_print`.
|
|
For usage example, see tests/host_callback_test.py.
|
|
|
|
Still to do:
|
|
* Performance tests.
|
|
* Add flags for logging.
|
|
* Add unit tests with mocks.
|
|
* Improve the ergonomics of starting the consumer loop. Currently, when
|
|
invoking jit-ed code, one must start a consumer loop. This is not needed
|
|
when invoking code that does not involve jit. There is an error when
|
|
attempting to start a compiled computation without starting the outfeed
|
|
receiver. Perhaps we can put the receiver threads in the runtime.
|
|
* Explore a simpler API that uses Python program-order, instead of
|
|
data dependency-order.
|
|
* Explore implementation with outside compilation.
|
|
|
|
"""
|
|
from concurrent import futures
|
|
from contextlib import contextmanager
|
|
import itertools
|
|
|
|
from jax import api
|
|
from jax import core
|
|
from jax import dtypes
|
|
from jax import lax
|
|
from jax.lib import pytree
|
|
from jax.interpreters import ad, xla, batching, masking
|
|
from jax.interpreters import partial_eval as pe
|
|
from jax import pprint_util as ppu
|
|
from jax import util
|
|
from jaxlib import xla_client
|
|
from jaxlib import xla_extension
|
|
from jaxlib import version as jaxlib_version
|
|
|
|
import logging
|
|
import msgpack # type: ignore
|
|
import numpy as np
|
|
import traceback
|
|
from typing import (Any, Callable, Dict, Iterable, List, Optional, NamedTuple,
|
|
Sequence, Tuple, cast)
|
|
|
|
xops = xla_client._xla.ops
|
|
|
|
# TODO(necula): fix mypy errors if I define the type aliases below
|
|
XlaOp = Any # xla_extension.XlaOp
|
|
XlaShape = Any # xla_client.Shape
|
|
XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder
|
|
XlaDevice = Any # xla_client.Device
|
|
|
|
# TODO: add a flag
|
|
_LOGGING = True
|
|
|
|
# Starting with 0.1.46 the outfeed is now on the Device
|
|
# TODO(necula): remove once we fix XLA
|
|
_jaxlib_version = tuple(int(x)for x in jaxlib_version.__version__.split('.'))
|
|
if _jaxlib_version == (0, 1, 45):
|
|
_OUTFEED_MECHANISM = "xla_client"
|
|
elif _jaxlib_version >= (0, 1, 47):
|
|
_OUTFEED_MECHANISM = "device"
|
|
else:
|
|
_OUTFEED_MECHANISM = "none"
|
|
|
|
|
|
def id_tap(func: Callable, arg, *,
|
|
result=None,
|
|
**kwargs):
|
|
"""Host-callback tap primitive, like identity function with a call to ``func``.
|
|
|
|
**Experimental: please give feedback, and expect changes!**
|
|
|
|
``id_tap`` behaves semantically like the identity function but has the side-effect
|
|
that a user-defined Python function is called with the runtime values of the
|
|
argument.
|
|
|
|
Args:
|
|
* arg: the argument passed to the tap function, can be a pytree of JAX
|
|
types.
|
|
* result: if given, specifies the return value of ``id_tap``. This value is
|
|
not passed to the tap function, and in fact is not sent from the device
|
|
to the host. If the ``result`` parameter is not specified then the return
|
|
value of ``id_tap`` is ``arg``.
|
|
* kwargs: will be passed directly to the tap function. Can be anything that
|
|
is hashable, these are kept in the host Python process until outfeeds are
|
|
received.
|
|
|
|
Returns:
|
|
* ``arg``, or ``result`` if given.
|
|
|
|
The order of execution is by data dependency: after all the arguments and
|
|
the value of ``result`` if present, are computed and before the returned
|
|
value is used. At least one of the returned values of ``id_tap`` must be
|
|
used in the rest of the computation, or else this operation has no effect.
|
|
|
|
If you want to tap a constant value, you should use the ``result`` parameter
|
|
to control when it is tapped, otherwise it will be tapped during tracing
|
|
of the function::
|
|
|
|
x = id_tap(42, result=x)
|
|
|
|
Tapping works even for code executed on accelerators and even for code under
|
|
JAX transformations. Code that uses taps must be run embedded in
|
|
:func:`outfeed_receiver`.
|
|
|
|
For more details see the
|
|
`module documentation <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
|
|
"""
|
|
if _OUTFEED_MECHANISM == "none":
|
|
raise NotImplementedError("id_tap works only with jaxlib 0.1.47 and higher")
|
|
if func not in (_end_consumer, _unknown_testing_consumer):
|
|
api._check_callable(func)
|
|
flat_args, arg_treedef = pytree.flatten(arg)
|
|
for arg in flat_args: api._check_arg(arg)
|
|
params = dict(kwargs) # we pass a copy of params to the primitive
|
|
# See definition of id_tap_p for what parameters it takes
|
|
params["func"] = func
|
|
params["arg_treedef"] = arg_treedef
|
|
if result is not None:
|
|
flat_results, result_treedef = pytree.flatten(result)
|
|
for result in flat_results: api._check_arg(result)
|
|
all_args = flat_args + flat_results
|
|
params["nr_untapped"] = len(flat_results)
|
|
else:
|
|
all_args = flat_args
|
|
flat_outs = id_tap_p.bind(*all_args, **params) # Returns all_args
|
|
if result is not None:
|
|
return result_treedef.unflatten(flat_outs[-params["nr_untapped"]:]) # type: ignore[unsupported-operands]
|
|
else:
|
|
return arg_treedef.unflatten(flat_outs)
|
|
|
|
|
|
def id_print(arg, *, result=None, output_stream=None, threshold=None,
|
|
**kwargs):
|
|
"""Like :func:`id_tap` with a printing tap function.
|
|
|
|
**Experimental: please give feedback, and expect changes!**
|
|
|
|
On each invocation of the printing tap, the ``kwargs`` if present
|
|
will be printed first (sorted by keys). Then arg will be printed,
|
|
with the arrays stringified with ``numpy.array2string``.
|
|
|
|
See the :func:`id_tap` documentation.
|
|
|
|
Additional keyword arguments:
|
|
|
|
* ``output_stream`` if given then it will be used instead of the
|
|
built-in ``print``. The string will be passed as ``output_stream.write(s)``.
|
|
* ``threshold`` is passed to ``numpy.array2string``.
|
|
"""
|
|
return id_tap(_print_consumer, arg,
|
|
result=result, output_stream=output_stream,
|
|
threshold=threshold, **kwargs)
|
|
|
|
|
|
# A registry of outfeed consumers, used upon receiving outfeeds
|
|
class _ConsumerCallable(NamedTuple):
|
|
"""Host-side information for an outfeed consumer."""
|
|
func: Callable
|
|
kwargs: Tuple[Tuple[str, Any], ...]
|
|
arg_treedef: Any
|
|
|
|
def unpack_kwargs(self):
|
|
kwargs = dict(self.kwargs)
|
|
transforms = kwargs.get("transforms")
|
|
if transforms is None:
|
|
return kwargs
|
|
else:
|
|
def unpack_transform(name, *params):
|
|
if name == "batch":
|
|
return dict(name=name, batch_dims=params[0])
|
|
elif name == "mask":
|
|
return dict(name=name, logical_shapes=params[0])
|
|
else:
|
|
assert not params, f"{name}, {params}"
|
|
return dict(name=name)
|
|
return dict(kwargs,
|
|
transforms=tuple([unpack_transform(*t) for t in transforms]))
|
|
|
|
_consumer_registry: Dict[_ConsumerCallable, int] = dict()
|
|
_consumer_registry_by_id: Dict[int, _ConsumerCallable] = dict()
|
|
|
|
|
|
def _register_consumer(cons: _ConsumerCallable) -> int:
|
|
"""Registers a tap function, cache by hash of cons."""
|
|
cons_id = _consumer_registry.get(cons)
|
|
if cons_id is not None:
|
|
return cons_id
|
|
cons_id = hash(cons)
|
|
_consumer_registry[cons] = cons_id
|
|
_consumer_registry_by_id[cons_id] = cons
|
|
return cons_id
|
|
|
|
def _print_consumer(arg, *, output_stream=None,
|
|
threshold=1024, **kwargs):
|
|
"""The consumer for id_print.
|
|
|
|
We provide this as a simple tapping function for printing.
|
|
This is **experimental**
|
|
and may not want to add many features to it; it should be easy for the user
|
|
to roll their own printing function.
|
|
|
|
Args:
|
|
output_stream: a function whose `write` method is called with the strings
|
|
to be output.
|
|
threshold: the value of numpy.array2string threshold parameter.
|
|
**kwargs: all other keyword args are printed before printing `arg`.
|
|
"""
|
|
def emit_str(s: str):
|
|
if output_stream is not None:
|
|
output_stream.write(s + "\n")
|
|
else:
|
|
print(s)
|
|
kv_pairs = " ".join([f"{k}: {v}"
|
|
for k, v in sorted(kwargs.items())
|
|
if k not in ("consumer_id", "nr_untapped")])
|
|
if kv_pairs:
|
|
emit_str(kv_pairs)
|
|
|
|
def pp_val(arg) -> ppu.PrettyPrint:
|
|
if isinstance(arg, (tuple, list)):
|
|
return (ppu.pp('[ ') >>
|
|
ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(' ]'))
|
|
elif isinstance(arg, dict):
|
|
return (ppu.pp('{ ') >>
|
|
ppu.vcat(
|
|
[ppu.pp(f"{k}=") >> pp_val(v)
|
|
for k, v in sorted(arg.items())]) >>
|
|
ppu.pp(' }'))
|
|
elif isinstance(arg, np.ndarray):
|
|
return ppu.pp(np.array2string(arg, threshold=threshold))
|
|
else:
|
|
return ppu.pp(str(arg))
|
|
|
|
emit_str(str(pp_val(arg)))
|
|
|
|
|
|
"""The id_tap primitive acts like the identity function. It has a number of
|
|
positional arguments and parameters:
|
|
* func: the actual (Python) function to invoke with the positional arguments
|
|
and the parameters.
|
|
* nr_untapped: how many positional arguments (from the tail) should not be
|
|
passed to the tap function.
|
|
* arg_treedef: the treedef of the tapped positional arguments.
|
|
* transforms: a tuple of the transformations that have been applied. Each
|
|
element of the tuple is itself a tuple with the first element the name
|
|
of the transform. The remaining elements depend on the transform. For
|
|
example, for `batch`, the parameters are the dimensions that have been
|
|
batched, and for `mask` the logical shapes. These are unpacked by
|
|
_ConsumerCallable before passing to the user function.
|
|
|
|
* the remaining parameters are passed to the tap function.
|
|
"""
|
|
id_tap_p = core.Primitive("id_tap")
|
|
id_tap_p.multiple_results = True
|
|
xla.outfeed_primitives.add(id_tap_p)
|
|
|
|
def _add_transform(params: Dict, name: str,
|
|
*transform_params) -> Dict:
|
|
"""Adds the `transform` to the params["transforms"].
|
|
|
|
Uses a tuple representation internally, will be unpacked before the
|
|
callback by _ConsumerCallable.
|
|
"""
|
|
new_transform = (name, *transform_params)
|
|
return dict(params, transforms=(params.get("transforms", ())
|
|
+ (new_transform,)))
|
|
|
|
def _id_tap_impl(*arrays, **params):
|
|
# We use the jitted-version of the primitive even for eager execution, both
|
|
# so that we do not duplicate logic, but also so that all outfeed is received
|
|
# by the outfeed_listeners, in the same thread from a given device. If we were
|
|
# to process the tap here, it would be coming from the main thread. Also,
|
|
# even in eager execution some primitives, such as while, are compiled.
|
|
# It would be confusing to process a sequence "id_tap; while" in two
|
|
# different threads.
|
|
return xla.apply_primitive(id_tap_p, *arrays, **params)
|
|
|
|
id_tap_p.def_impl(_id_tap_impl)
|
|
|
|
def _id_tap_abstract_eval(*args_a: pe.AbstractValue, **params) \
|
|
-> Sequence[pe.AbstractValue]:
|
|
return args_a
|
|
|
|
|
|
id_tap_p.def_abstract_eval(_id_tap_abstract_eval)
|
|
|
|
# TODO(necula): there must be a better way to do this.
|
|
# The AttributeError is for regular values, the KeyError is for ConcreteArray
|
|
def _instantiate_zeros(arg, tan):
|
|
"""Turn special ad.zero tangents into arrays of 0s."""
|
|
if type(tan) is not ad.Zero:
|
|
return tan
|
|
|
|
try:
|
|
aval = arg.aval
|
|
return ad.instantiate_zeros_aval(aval, tan)
|
|
except (AttributeError, KeyError):
|
|
# We get here for regular Python values
|
|
return ad.zeros_like_jaxval(arg)
|
|
|
|
|
|
def _id_tap_jvp_rule(primals, tangents, *, func, nr_untapped=0, **params):
|
|
# Put primals through id_tap separately, so that partial evaluation
|
|
# can do its job when they are known (for grad)
|
|
out_primals = id_tap_p.bind(*primals, func=func, nr_untapped=nr_untapped, **params)
|
|
# Add one primal output as untapped, to create data dependency.
|
|
tangent_zeros = tuple(map(_instantiate_zeros, primals, tangents))
|
|
out_tangents_extra = id_tap_p.bind(*tangent_zeros, out_primals[0],
|
|
func=func, nr_untapped=nr_untapped + 1,
|
|
**_add_transform(params, "jvp"))
|
|
return tuple(out_primals), tuple(out_tangents_extra[:-1])
|
|
|
|
ad.primitive_jvps[id_tap_p] = _id_tap_jvp_rule
|
|
|
|
|
|
def _id_tap_transpose_rule(cts, *args, func=None, nr_untapped=0, **params):
|
|
assert len(cts) == len(args)
|
|
cts_zeros = tuple(map(_instantiate_zeros, args, cts))
|
|
ct_args = id_tap_p.bind(*cts_zeros, func=func, nr_untapped=nr_untapped,
|
|
**_add_transform(params, "transpose"))
|
|
return ct_args
|
|
|
|
ad.primitive_transposes[id_tap_p] = _id_tap_transpose_rule
|
|
|
|
|
|
def _id_tap_batching_rule(batched_args, batch_dims, **params):
|
|
new_params = _add_transform(params, "batch", batch_dims)
|
|
res = id_tap_p.bind(*batched_args, **new_params)
|
|
return res, batch_dims
|
|
|
|
batching.primitive_batchers[id_tap_p] = _id_tap_batching_rule
|
|
|
|
# def _id_tap_shape_rule(*operands, **params):
|
|
# return tuple([op.shape for op in operands])
|
|
|
|
# TODO(necula): these disappeared
|
|
# masking.shape_rules[id_tap_p] = _id_tap_shape_rule # type: ignore[module-attr]
|
|
|
|
def _id_tap_masking_rule(operands, operands_logical_shapes, **params):
|
|
new_params = _add_transform(params, "mask",
|
|
operands_logical_shapes)
|
|
return id_tap_p.bind(*operands, **new_params)
|
|
|
|
masking.masking_rules[id_tap_p] = _id_tap_masking_rule
|
|
|
|
####
|
|
#### XLA compilation ####
|
|
####
|
|
# Special consumer to mark the end of outfeed stream for a device
|
|
_end_consumer = 0
|
|
_unknown_testing_consumer = 1 # for testing error cases
|
|
|
|
|
|
def _id_tap_translation_rule_outfeed(comp: XlaComputationBuilder,
|
|
*args_op: XlaOp, func=None,
|
|
nr_untapped=0, arg_treedef=None,
|
|
**params):
|
|
params = dict(params)
|
|
if func in (_end_consumer, _unknown_testing_consumer):
|
|
params["consumer_id"] = func
|
|
else:
|
|
params["consumer_id"] = _register_consumer(
|
|
_ConsumerCallable(func, tuple(sorted(params.items())), arg_treedef))
|
|
|
|
# We expect the current token at the end, inserted by _rewrite_jaxpr.
|
|
current_token = args_op[-1]
|
|
assert not comp.get_shape(current_token).is_array(), "The last argument must be a token"
|
|
|
|
nr_args_to_emit = len(args_op) - nr_untapped - 1
|
|
next_token = _emit_outfeed(comp, current_token,
|
|
args_op[0:nr_args_to_emit], params["consumer_id"])
|
|
results = (args_op[:-1] + (next_token,))
|
|
return xops.Tuple(comp, results)
|
|
|
|
xla.translations[id_tap_p] = _id_tap_translation_rule_outfeed
|
|
|
|
####
|
|
#### Jaxpr rewriting logic to thread the tokens through stateful primitives.
|
|
####
|
|
|
|
|
|
# TODO: we should really get rid of TypedJaxpr
|
|
def _mk_typed_jaxpr(jaxpr: core.Jaxpr, literals: Sequence) -> core.TypedJaxpr:
|
|
return core.TypedJaxpr(jaxpr, literals,
|
|
tuple(map(lambda v: v.aval, jaxpr.invars)),
|
|
tuple(map(lambda v: v.aval, jaxpr.outvars)))
|
|
|
|
def _rewrite_typed_jaxpr(tjaxpr: core.TypedJaxpr,
|
|
has_input_token: bool,
|
|
has_output_token: bool) -> Tuple[core.TypedJaxpr, bool]:
|
|
"""Rewrites a TypedJaxpr to thread the token, if needed.
|
|
|
|
Returns the rewritten Jaxpr, and whether it uses outfeed."""
|
|
new_jaxpr, uses_outfeed = _rewrite_jaxpr(tjaxpr.jaxpr, has_input_token,
|
|
has_output_token)
|
|
return _mk_typed_jaxpr(new_jaxpr, tjaxpr.literals), uses_outfeed
|
|
|
|
|
|
def _rewrite_jaxpr(jaxpr: core.Jaxpr,
|
|
has_input_token: bool,
|
|
has_output_token: bool) -> Tuple[core.Jaxpr, bool]:
|
|
"""Rewrite a Jaxpr to thread the token, if needed."""
|
|
assert has_input_token or not has_output_token
|
|
|
|
if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
|
|
return (jaxpr, False)
|
|
|
|
mk_new_var = core.gensym([jaxpr])
|
|
|
|
eqns: List[core.JaxprEqn] = []
|
|
last_token_var = mk_new_var(core.abstract_token) # store the incoming token
|
|
if has_input_token:
|
|
invars = jaxpr.invars + [last_token_var]
|
|
else:
|
|
invars = jaxpr.invars
|
|
eqns.append(core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var],
|
|
lax.create_token_p, {}))
|
|
|
|
for eqn in jaxpr.eqns:
|
|
if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
|
|
eqns.append(eqn)
|
|
else:
|
|
output_token_var = mk_new_var(core.abstract_token)
|
|
_rewrite_eqn(eqn, eqns, last_token_var, output_token_var, mk_new_var)
|
|
last_token_var = output_token_var
|
|
|
|
outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
|
|
new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
|
|
return (new_jaxpr, True)
|
|
|
|
def _rewrite_eqn(eqn: core.JaxprEqn,
|
|
eqns: List[core.JaxprEqn],
|
|
input_token_var: core.Var,
|
|
output_token_var: core.Var,
|
|
mk_new_var: Callable[[core.AbstractValue], core.Var]):
|
|
"""Rewrite an `eqn` and append equations to `eqns`. Assume that the
|
|
current token is in `input_token_var` and the resulting token must end in
|
|
`output_token_var`."""
|
|
if eqn.primitive is id_tap_p:
|
|
eqns.append(core.new_jaxpr_eqn(eqn.invars + [input_token_var],
|
|
eqn.outvars + [output_token_var],
|
|
eqn.primitive, eqn.params))
|
|
elif eqn.primitive is lax.while_p:
|
|
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
|
|
eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
|
|
if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
|
|
_rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
|
|
mk_new_var)
|
|
return
|
|
|
|
eqns.append(core.new_jaxpr_eqn(
|
|
eqn.invars + [input_token_var],
|
|
eqn.outvars + [output_token_var],
|
|
eqn.primitive,
|
|
dict(eqn.params,
|
|
body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0],
|
|
cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)[0])))
|
|
elif eqn.primitive is lax.cond_p:
|
|
branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
|
|
nr_operands = len(branches[0].jaxpr.invars)
|
|
index, *operands = eqn.invars
|
|
new_invars = [index, *operands, input_token_var]
|
|
eqns.append(core.new_jaxpr_eqn(
|
|
new_invars, eqn.outvars + [output_token_var],
|
|
eqn.primitive,
|
|
dict(eqn.params,
|
|
branches=tuple(
|
|
_rewrite_typed_jaxpr(jaxpr, True, True)[0]
|
|
for jaxpr in branches),
|
|
linear=(*linear, False))))
|
|
elif eqn.primitive is lax.scan_p:
|
|
num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
|
|
eqn.params, ["num_consts", "num_carry", "jaxpr", "linear",
|
|
"reverse", "length"])
|
|
# We add the token right at the end of carry
|
|
nr_const_and_carry = num_consts + num_carry
|
|
new_invars = eqn.invars[0:nr_const_and_carry] + [input_token_var] + eqn.invars[nr_const_and_carry:]
|
|
new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
|
|
# The rewrite has put the token at end, it has to be at end of carry
|
|
new_jaxpr_invars = new_jaxpr.jaxpr.invars
|
|
new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
|
|
[new_jaxpr_invars[-1]] +
|
|
new_jaxpr_invars[nr_const_and_carry:-1])
|
|
new_jaxpr.jaxpr.invars = new_jaxpr_invars
|
|
new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]
|
|
|
|
new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
|
|
new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
|
|
[new_jaxpr_outvars[-1]] +
|
|
new_jaxpr_outvars[num_carry:-1])
|
|
new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
|
|
new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
|
|
eqns.append(core.new_jaxpr_eqn(
|
|
new_invars,
|
|
# Output token is at the end of carry result
|
|
eqn.outvars[0:num_carry] + [output_token_var] + eqn.outvars[num_carry:],
|
|
eqn.primitive,
|
|
dict(eqn.params,
|
|
jaxpr=new_jaxpr,
|
|
num_carry=num_carry + 1,
|
|
linear=linear + (False,))))
|
|
elif eqn.primitive is xla.xla_call_p:
|
|
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
|
|
eqns.append(core.new_jaxpr_eqn(
|
|
eqn.invars + [input_token_var],
|
|
eqn.outvars + [output_token_var],
|
|
eqn.primitive,
|
|
dict(eqn.params, call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0])))
|
|
else:
|
|
raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
|
|
|
|
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn,
|
|
eqns: List[core.JaxprEqn],
|
|
input_token_var: core.Var,
|
|
output_token_var: core.Var,
|
|
mk_new_var: Callable):
|
|
"""Rewrite a while whose cond has outfeed"""
|
|
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
|
|
eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
|
|
transformed_cond_jaxpr, _ = _rewrite_typed_jaxpr(cond_jaxpr, True, True)
|
|
carry_invars = eqn.invars[cond_nconsts+body_nconsts:]
|
|
# pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token)
|
|
pred1_and_token1 = [mk_new_var(ov.aval)
|
|
for ov in transformed_cond_jaxpr.jaxpr.outvars]
|
|
eqns.append(core.new_jaxpr_eqn(
|
|
eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var],
|
|
pred1_and_token1,
|
|
xla.xla_call_p,
|
|
dict(call_jaxpr=transformed_cond_jaxpr.jaxpr,
|
|
name="cond_before")))
|
|
# Make a new cond "lambda pred, carry, token: pred"
|
|
new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
|
|
new_cond_invars = (
|
|
[new_cond_pred_invar] +
|
|
[mk_new_var(cv.aval) for cv in carry_invars] +
|
|
[mk_new_var(core.abstract_token)])
|
|
new_cond_jaxpr = _mk_typed_jaxpr(core.Jaxpr([], new_cond_invars,
|
|
[new_cond_pred_invar], []),
|
|
[])
|
|
# Make a new body: "lambda cond_constvars, body_constvars, pred, carry, token:
|
|
# carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
|
|
# pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
|
|
# (pred2, carry2, token3)
|
|
transformed_body_jaxpr, _ = _rewrite_typed_jaxpr(body_jaxpr, True, True)
|
|
new_body_invars_cond_constvars = [mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]]
|
|
new_body_invars_body_constvars = [mk_new_var(v.aval)
|
|
for v in eqn.invars[cond_nconsts:cond_nconsts+body_nconsts]]
|
|
new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0])
|
|
new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars]
|
|
new_body_invars_token = mk_new_var(core.abstract_token)
|
|
|
|
new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars]
|
|
new_body_token2 = mk_new_var(core.abstract_token)
|
|
new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0])
|
|
new_body_token3 = mk_new_var(core.abstract_token)
|
|
|
|
new_body_eqns = [
|
|
core.new_jaxpr_eqn(
|
|
new_body_invars_body_constvars +
|
|
new_body_invars_carry + [new_body_invars_token],
|
|
new_body_carry2 + [new_body_token2],
|
|
xla.xla_call_p,
|
|
dict(call_jaxpr=transformed_body_jaxpr.jaxpr,
|
|
name="body")),
|
|
core.new_jaxpr_eqn(
|
|
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
|
|
[new_body_pred2, new_body_token3],
|
|
xla.xla_call_p,
|
|
dict(call_jaxpr=transformed_cond_jaxpr.jaxpr,
|
|
name="cond_body"))
|
|
]
|
|
new_body_jaxpr = _mk_typed_jaxpr(
|
|
core.Jaxpr([],
|
|
(new_body_invars_cond_constvars + new_body_invars_body_constvars +
|
|
[new_body_invars_pred] + new_body_invars_carry +[new_body_invars_token]),
|
|
([new_body_pred2] + new_body_carry2 + [new_body_token3]),
|
|
new_body_eqns), [])
|
|
|
|
pred_out = mk_new_var(cond_jaxpr.out_avals[0])
|
|
eqns.append(core.new_jaxpr_eqn(
|
|
(eqn.invars[0:cond_nconsts+body_nconsts] + [pred1_and_token1[0]] +
|
|
carry_invars + [pred1_and_token1[1]]),
|
|
([pred_out] + eqn.outvars + [output_token_var]),
|
|
lax.while_p,
|
|
dict(cond_jaxpr=new_cond_jaxpr, cond_nconsts=0,
|
|
body_jaxpr=new_body_jaxpr, body_nconsts=cond_nconsts + body_nconsts))
|
|
)
|
|
|
|
|
|
xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)
|
|
|
|
# The data on the outfeed follows a protocol that allows multiplexing the
|
|
# outfeed among multiple consumers, and communicates in-stream shape and
|
|
# type of the data.
|
|
# Each batch of array data is preceeded by a header message, of type
|
|
# uint32[_OUTFEED_HEADER_LENGTH]:
|
|
# [0]: special header value (271828)
|
|
# [1, 2]: a consumer id (64-bits, big-endian encoding as uint32[2]). The
|
|
# consumer id encodes the tap function (by hash), the
|
|
# descriptor of the arrays to be outfed, and the kwargs (a sorted tuple
|
|
# of keys and values).
|
|
# [3]: the metadata length in bytes. The metadata is a msgpack-encoded value of type:
|
|
# [ (type_code, (d0, d1, ...)), ...] # for each array, element type code
|
|
# # and the dimensions.
|
|
# padded with 0s to _OUTFEED_HEADER_LENGTH
|
|
#
|
|
#
|
|
_OUTFEED_HEADER_LENGTH = 32 # In uint32 words
|
|
_OUTFEED_HEADER_START = 271828 # [0]
|
|
# consumer_id [1, 2]
|
|
# metadata_length in bytes [3]
|
|
_OUTFEED_HEADER_METADATA_LENGTH = 4 * (_OUTFEED_HEADER_LENGTH - 4)
|
|
|
|
_CODE_TO_DTYPE = {
|
|
0: np.dtype(np.int8),
|
|
1: np.dtype(np.int16),
|
|
2: np.dtype(np.int32),
|
|
3: np.dtype(np.int64),
|
|
4: np.dtype(np.uint8),
|
|
5: np.dtype(np.uint16),
|
|
6: np.dtype(np.uint32),
|
|
7: np.dtype(np.uint64),
|
|
8: np.dtype(np.bool),
|
|
9: np.dtype(np.float16),
|
|
10: np.dtype(np.float32),
|
|
11: np.dtype(np.float64),
|
|
12: np.dtype(dtypes.bfloat16),
|
|
}
|
|
_DTYPE_STR_TO_CODE = dict([(str(d), c) for c, d in _CODE_TO_DTYPE.items()])
|
|
|
|
|
|
def _emit_outfeed(comp: XlaComputationBuilder, token: XlaOp,
|
|
arrays: Sequence[XlaOp], consumer_id: int) -> XlaOp:
|
|
"""Emits the arrays to the outfeed for the current device."""
|
|
arrays_shape = [comp.get_shape(a) for a in arrays]
|
|
def _array_shape_to_tuple(a_shape: XlaShape):
|
|
# (element_type_code, (d0, d1, ..., dn))
|
|
return (_DTYPE_STR_TO_CODE[str(np.dtype(a_shape.element_type()))],
|
|
a_shape.dimensions())
|
|
metadata = msgpack.dumps(tuple(map(_array_shape_to_tuple, arrays_shape)))
|
|
metadata_len = len(metadata)
|
|
if metadata_len > _OUTFEED_HEADER_METADATA_LENGTH:
|
|
# TODO(necula): configurable metadata length
|
|
raise ValueError("Outfeed metadata too long")
|
|
metadata += b" " * (((metadata_len + 3) // 4) * 4 - metadata_len) # pad
|
|
header = ((_OUTFEED_HEADER_START,
|
|
(consumer_id >> 32) & 0xffffffff, (consumer_id & 0xffffffff),
|
|
metadata_len) +
|
|
tuple([int.from_bytes(metadata[i:i+4], byteorder="big")
|
|
for i in range(0, _OUTFEED_HEADER_METADATA_LENGTH, 4)]))
|
|
header += (0,) * (_OUTFEED_HEADER_LENGTH - len(header))
|
|
data = xops.ConstantLiteral(comp, np.array(header, dtype=np.uint32))
|
|
token = xops.OutfeedWithToken(data, token, comp.get_shape(data))
|
|
|
|
# Now send the arrays, all at once
|
|
entire_shape = xla_client.Shape.tuple_shape(arrays_shape)
|
|
token = xops.OutfeedWithToken(xops.Tuple(comp, arrays), token, entire_shape)
|
|
return token
|
|
|
|
def _receive_outfeed(device: XlaDevice, receiver_name: str
|
|
) -> Tuple[int, List]:
|
|
"""Receives a set of arrays on the outfeed for the specificied device.
|
|
Args:
|
|
receiver_name: a name used for debugging and logging
|
|
Returns: a tuple with the consumer_id, the arrays received, and
|
|
a kwargs dictionary that was passed to _emit_outfeed.
|
|
"""
|
|
header_shape = xla_client.Shape.array_shape(np.dtype(np.uint32),
|
|
(_OUTFEED_HEADER_LENGTH,))
|
|
|
|
def _get_data(data_shape: XlaShape, device: XlaDevice) -> XlaShape:
|
|
if _OUTFEED_MECHANISM == "device":
|
|
return device.transfer_from_outfeed(data_shape)
|
|
else:
|
|
return xla_client.transfer_from_outfeed(data_shape, device)
|
|
|
|
header = _get_data(header_shape, device)
|
|
if header[0] != _OUTFEED_HEADER_START:
|
|
raise ValueError(f"Read unexpected outfeed header {header[0]} [{receiver_name}]")
|
|
logging.info(f"[{receiver_name}:{device}] Outfeed read header: {header}")
|
|
consumer_id = (header[1] << 32) + header[2]
|
|
metadata_length = header[3]
|
|
assert metadata_length <= _OUTFEED_HEADER_METADATA_LENGTH
|
|
metadatas = [int(header[i]).to_bytes(4, byteorder="big")
|
|
for i in range(4, 4 + (metadata_length + 3) // 4)]
|
|
metadata = b"".join(metadatas)[:metadata_length]
|
|
array_descriptors = msgpack.unpackb(metadata)
|
|
arrays_shape = [xla_client.Shape.array_shape(_CODE_TO_DTYPE[a_descr[0]],
|
|
a_descr[1])
|
|
for a_descr in array_descriptors]
|
|
entire_shape = xla_client.Shape.tuple_shape(arrays_shape)
|
|
arrays = _get_data(entire_shape, device)
|
|
logging.info(f"[{receiver_name}:{device}] Outfeed read data of shape "
|
|
",".join([f"{data.dtype}{data.shape}" for data in arrays]))
|
|
return (consumer_id, arrays)
|
|
|
|
|
|
class TapFunctionException(Exception):
|
|
"""Signals that some tap function had exceptions.
|
|
|
|
Raised by :func:`outfeed_receiver`.
|
|
"""
|
|
pass
|
|
|
|
_outfeed_receiver_started = False
|
|
@contextmanager
|
|
def outfeed_receiver(*,
|
|
timeout_sec=10,
|
|
backends: Optional[Sequence[str]] = None,
|
|
devices: Optional[Sequence[XlaDevice]] = None,
|
|
receiver_name=""):
|
|
# TODO: better timeout management.
|
|
"""Starts receivers for the :func:`id_tap` outfeed from several devices.
|
|
|
|
The receivers will run in a threadpool. The tapped functions will be invoked
|
|
in those threads. If a tap function raises an exception, an error is
|
|
printed, but the receiving continues until the body of the context manager
|
|
terminates and all outfeeds from all devices have been received. Only then
|
|
will a :exc:`TapFunctionException` be raised.
|
|
|
|
Args:
|
|
backends: (optional) sequence of backend names for which to listen.
|
|
Will listen to all devices on those backends. By default, listed to
|
|
all devices on all known backends.
|
|
devices: (optional) sequence of devices to listed to. At most one
|
|
of `backends` or `devices` must be given.
|
|
receiver_name: (optional) a name to use with debug logging
|
|
Usage::
|
|
|
|
with outfeed_receiver():
|
|
jax.jit(func)(args)
|
|
...
|
|
jax.pmap(another_func)(args)
|
|
|
|
The ``outfeed_receiver`` must be started outside any jitted computation.
|
|
|
|
"""
|
|
if not devices:
|
|
backends = backends or xla_client._get_local_backends().keys()
|
|
devices = tuple(itertools.chain(
|
|
*[api.local_devices(api.host_id(backend), backend)
|
|
for backend in backends if backend != "interpreter"]))
|
|
else:
|
|
if backends:
|
|
raise ValueError("At most one of `devices` or `backends` must be given.")
|
|
executor = futures.ThreadPoolExecutor(
|
|
thread_name_prefix=f"outfeed_receiver_{receiver_name}",
|
|
max_workers=len(devices))
|
|
|
|
count_tap_exceptions = 0
|
|
def device_receiver_loop(device: XlaDevice) -> XlaDevice:
|
|
"""Polls the outfeed for a device in a loop."""
|
|
nonlocal count_tap_exceptions
|
|
while (True):
|
|
consumer_id, arrays = _receive_outfeed(device, receiver_name)
|
|
if _LOGGING:
|
|
logging.info(f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} " +
|
|
(" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
|
|
if consumer_id == _end_consumer:
|
|
assert not arrays
|
|
if _LOGGING:
|
|
logging.info(f"[{receiver_name}:{device}] Outfeed received END_OUTFEED")
|
|
return device
|
|
consumer = _consumer_registry_by_id.get(consumer_id)
|
|
if consumer is None:
|
|
logging.error(f"Ignoring received outfeed for unknown tap consumer")
|
|
count_tap_exceptions += 1
|
|
continue # We need to read the entire outfeed
|
|
try:
|
|
arg = api.tree_unflatten(consumer.arg_treedef, arrays)
|
|
consumer.func(arg, **consumer.unpack_kwargs()) # type: ignore[attribute-error]
|
|
except Exception as e:
|
|
logging.error(f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}")
|
|
count_tap_exceptions += 1
|
|
# We continue for now, we need to keep reading the outfeed
|
|
|
|
receiver_futures = [executor.submit(device_receiver_loop, d) for d in devices]
|
|
# Register a callback to raise errors if any. These exception come from
|
|
# bugs in our code, not from the tap functions.
|
|
for rf in receiver_futures:
|
|
rf.add_done_callback(lambda rf: rf.result())
|
|
global _outfeed_receiver_started
|
|
if _outfeed_receiver_started:
|
|
raise ValueError("At most one outfeed_receiver can be running at once.")
|
|
_outfeed_receiver_started = True
|
|
xla.can_execute_outfeed_computations = True
|
|
try:
|
|
yield
|
|
finally:
|
|
for d in devices: # Signal the end of printing
|
|
api.jit(lambda x: id_tap(_end_consumer, None, result=x), device=d)(0) # type: ignore[arg-type]
|
|
xla.can_execute_outfeed_computations = False
|
|
_outfeed_receiver_started = False
|
|
for f in futures.as_completed(receiver_futures, timeout=timeout_sec):
|
|
finished_device = f.result() # Throw exceptions here
|
|
if _LOGGING:
|
|
logging.info(f"[{receiver_name}:{finished_device} Outfeed receiver finished")
|
|
if count_tap_exceptions > 0:
|
|
raise TapFunctionException
|