# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Primitives for calling from accelerators to Python functions on the host.

**Experimental: please give feedback, and expect changes.**

This module introduces the host callback functions :func:`id_tap` and
:func:`id_print`, which behave like the identity function but have the
side-effect of sending the arguments from the device to the host and
invoking a user-specified Python function (for :func:`id_tap`) or printing the
arguments on the host (for :func:`id_print`). The Python function passed
to :func:`id_tap` takes two positional arguments: the value tapped from the
device computation and the ``transforms`` sequence, described below.
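
The calling convention can be illustrated without a device. The sketch below uses a hypothetical host-only stand-in, ``fake_id_tap``, which is not part of this module. It ignores devices, outfeed and transformations, so it always passes an empty ``transforms`` list. It only mirrors the contract stated above: the tap function is called as ``tap_func(arg, transforms)``, and the call returns ``arg``, or ``result`` when given.

```python
import functools


# Hypothetical host-only stand-in for id_tap: no device, no outfeed and no
# JAX transformations, so ``transforms`` is always an empty list here.
def fake_id_tap(tap_func, arg, *, result=None):
  transforms = []  # the real id_tap extends this under vmap/jvp/grad
  tap_func(arg, transforms)  # the side effect: the tap sees the tapped value
  return arg if result is None else result


calls = []


def record(tap, transforms, what=None):
  calls.append((tap, transforms, what))


x = 3.0
y = fake_id_tap(record, 2 * x)  # identity: returns the tapped value
z = fake_id_tap(record, 2 * x, result=x)  # returns the overriding result
w = fake_id_tap(functools.partial(record, what='activation'), (x, y))
```

Under the real :func:`id_tap`, the tap function runs on the host when the device sends the data, not synchronously as in this sketch.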

A few examples::

  # calls func(2x, []) on host and returns 2x
  y = id_tap(func, 2 * x)
  # calls func((2x, 3x), []) and returns (2x, 3x)
  y, z = id_tap(func, (2 * x, 3 * x))  # The argument can be a pytree
  # calls func(2x, []) and returns y
  y = id_tap(func, 2 * x, result=y)  # override the result of id_tap
  # calls func(2x, [], what='activation') and returns 2x
  y = id_tap(functools.partial(func, what='activation'), 2 * x)
  # calls func(dict(x=x, y=y), what='data') and returns dict(x=x, y=y)
  xy = id_tap(lambda tap, transforms: func(tap, what='data'), dict(x=x, y=y))

The above examples can all be adapted to use :func:`id_print` instead, with
the difference that :func:`id_print` takes one positional argument (to print
on the host), the optional kwarg ``result``, and possibly additional kwargs
that are also printed along with the automatic kwarg ``transforms``.
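
To show what :func:`id_print` emits, here is a simplified, hypothetical mirror of the header logic in ``_print_consumer``, which is defined later in this module. The keyword arguments, plus ``transforms`` when it is non-empty, are sorted by key and joined on one line. The tapped value follows on its own line. The real consumer pretty-prints nested containers and uses ``numpy.array2string`` for arrays; this sketch just uses ``str``.

```python
# Hypothetical, simplified mirror of the output produced for id_print:
# keyword arguments (plus ``transforms``, when non-empty) sorted by key,
# followed by the tapped value on its own line.
def format_tap(arg, transforms, **kwargs):
  if transforms:
    # Transforms without parameters are shown by name only.
    kwargs['transforms'] = [(name, params) if params else name
                            for name, params in transforms]
  lines = []
  kv_pairs = ' '.join(f'{k}: {v}' for k, v in sorted(kwargs.items()))
  if kv_pairs:
    lines.append(kv_pairs)
  lines.append(str(arg))  # the real consumer pretty-prints pytrees and arrays
  return lines


print_lines = format_tap([3.0, 9.0], [('jvp', {})], what='x,x^2')
```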

The order of execution of the tap functions is constrained by data dependency:
the arguments are sent after all the arguments are computed and before the
result of the call is used. **At least one of the returned values must be
used in the rest of the computation, or else this operation has no effect.**
The host tap functions will be executed for each device in the order in which
the send operations were performed on the device.

The data from the devices is received by separate threads managed by the JAX
runtime (one thread per device). The runtime maintains a buffer of
configurable size. When the buffer is full, all the receiving threads are
paused, which eventually pauses the computation on the devices. The runtime has
one additional thread that invokes the Python user functions with the received
data. If the processing of the callbacks is slow, the runtime buffer may fill
up, eventually pausing the computation on the devices when they need to send
something. For more details on the runtime mechanism see the
`runtime code
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
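
The buffering behaviour described above can be modelled with the standard library. This is a hypothetical simulation, not the XLA runtime. ``queue.Queue(maxsize=...)`` plays the role of the configurable buffer. One thread per simulated device puts packets into it, and a single consumer thread runs the callbacks. Each device's packets enter the FIFO in send order and only one thread consumes them, so the per-device order is preserved.

```python
import queue
import threading

# Hypothetical model: one receiver thread per device feeds a bounded buffer,
# and a single extra thread runs the user callbacks. When the buffer is full,
# ``put`` blocks, which is how a slow callback eventually stalls the senders.
BUFFER_SIZE = 4
buffer = queue.Queue(maxsize=BUFFER_SIZE)
received = {}  # device id -> packets, in the order the callback saw them


def receiver(device_id, nr_packets):
  for seq in range(nr_packets):
    buffer.put((device_id, seq))  # blocks while the buffer is full


def callback_thread():
  while True:
    item = buffer.get()
    if item is None:  # sentinel: no more data
      break
    device_id, seq = item
    received.setdefault(device_id, []).append(seq)


consumer = threading.Thread(target=callback_thread)
consumer.start()
receivers = [threading.Thread(target=receiver, args=(d, 10)) for d in range(3)]
for t in receivers:
  t.start()
for t in receivers:
  t.join()
buffer.put(None)  # stop the callback thread once all packets are queued
consumer.join()
```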

Exceptions from the user-defined tap functions are logged along with their
stack traces, but the receiving threads are not stopped.

In order to pause the execution until all data from computations already
started on devices has arrived and has been processed, use :func:`barrier_wait`.
This will also raise :exc:`TapFunctionException` if any exception had occurred
in one of the tap functions.
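
The error policy can be sketched as follows. The names ``run_callbacks``, ``wait_and_check`` and ``TapFailed`` are hypothetical stand-ins for the receiver loop, :func:`barrier_wait` and :exc:`TapFunctionException`. The sketch processes packets synchronously, so there is nothing left to wait for and only the error check remains.

```python
import logging


# Hypothetical stand-in for TapFunctionException.
class TapFailed(Exception):
  pass


def run_callbacks(packets, tap_func):
  # Log each failure with its stack trace, remember it, and keep processing
  # the remaining packets, as the receiving side described above does.
  had_exception = False
  for packet in packets:
    try:
      tap_func(packet)
    except Exception:
      logging.exception('tap function failed on %r', packet)
      had_exception = True
  return had_exception


def wait_and_check(had_exception):
  # Stand-in for barrier_wait: the packets were processed synchronously above,
  # so only the final error check remains.
  if had_exception:
    raise TapFailed('an exception occurred in one of the tap functions')


seen = []


def fragile_tap(value):
  if value == 2:
    raise ValueError('bad value')
  seen.append(value)


failed = run_callbacks([1, 2, 3], fragile_tap)
```

Calling ``wait_and_check(failed)`` then raises ``TapFailed``, even though the values 1 and 3 were still delivered.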

The current implementation uses the outfeed mechanism provided by XLA. The
mechanism itself is quite primitive in the sense that a receiver must know
exactly the shape of each incoming packet, and how many packets are expected.
This makes it hard to use for multiple kinds of data in the same computation,
and it is practically impossible to use it under conditionals or in loops
of non-constant iteration count. Furthermore, code that uses the outfeed
mechanism directly cannot be transformed by JAX. All these limitations are
addressed by the host callback functions. The tapping API introduced here
makes it easy to share the outfeed mechanism for multiple purposes, while
supporting all transformations.
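
One way to picture how the tapping API shares a single outfeed channel: every packet carries a consumer ID, and the host uses a registry to route the packet to the right callback. The sketch below is a hypothetical, device-free version. Its ID derivation copies the one in ``_register_consumer``, later in this module: mask the hash, clear the two low bits, and add 1 so that ID 0 stays reserved. The real registry also records the treedef and XLA shape needed to decode each packet.

```python
# Hypothetical sketch of sharing one channel among several consumers: packets
# are tagged with a consumer ID and routed through a registry.
registry = {}        # consumer -> id
registry_by_id = {}  # id -> consumer


def register(consumer):
  if consumer in registry:
    return registry[consumer]  # cached by the consumer's hash
  cons_id = (hash(consumer) & 0xFFFFFFFC) + 1  # never 0, always 1 mod 4
  assert cons_id not in registry_by_id, 'consumer id collision'
  registry[consumer] = cons_id
  registry_by_id[cons_id] = consumer
  return cons_id


def dispatch(packets):
  # Each packet is (consumer_id, payload); route it to its consumer.
  return [registry_by_id[cons_id](payload) for cons_id, payload in packets]


def double(v):
  return 2 * v


def negate(v):
  return -v


id_double, id_negate = register(double), register(negate)
outputs = dispatch([(id_double, 5), (id_negate, 7), (id_double, 1)])
```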

**Note that after you have used the host callback functions, you cannot
use lax.outfeed directly**. You may want to :func:`stop_outfeed_receiver`
if you later need to use lax.outfeed.

We describe the behaviour under transformations in the context of the
following function definition::

  def power3(x):
    y = x * x
    _, y = id_print((x, y), what="x,x^2")  # Must pack multiple arguments
    return y * x

  power3(3.)
  # what: x,x^2 : [3., 9.]

During JAX transformations the special parameter ``transforms`` is added to
contain a list of transformation descriptors in the form
``(transform_name, transform_params)``.
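
The transformation descriptors are built up by ``_add_transform`` and unpacked by ``_unpack_transform``, both defined later in this module. The sketch below copies their logic into standalone functions so that it runs without JAX. It records three transformations on one set of parameters to show the internal flat tuples and the ``(name, params)`` pairs that the tap function receives. The parameter values are made up for illustration.

```python
# Standalone copies of the logic of _add_transform and _unpack_transform.
# Internally each transform is a flat tuple ``(name, *params)``; it is
# unpacked into ``(name, params_dict)`` before the tap function is called.
def add_transform(params, name, *transform_params):
  new_transform = (name, *transform_params)
  return dict(params, transforms=params.get('transforms', ()) + (new_transform,))


def unpack_transform(name, *params):
  if name == 'batch':
    return name, dict(batch_dims=params[0])
  elif name == 'mask':
    return name, dict(logical_shapes=params[0])
  assert not params, f'{name}, {params}'
  return name, dict()


# Record three transformations, one after another, on the same parameters.
params = {}
params = add_transform(params, 'jvp')
params = add_transform(params, 'transpose')
params = add_transform(params, 'batch', (0, None))
unpacked = [unpack_transform(*t) for t in params['transforms']]
```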

For :func:`jax.vmap` the arguments are batched, and ``transforms`` is extended
with transformation name ``batch`` and ``batch_dims`` set to the tuple of
batched dimensions (one entry per argument, ``None`` denotes an argument that
was broadcast)::

  jax.vmap(power3)(np.arange(3.))
  # transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : [[0, 1, 2], [0, 1, 4]]
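
A tap function that receives batched data may want one value per example. The helper ``unbatch`` below is hypothetical and deliberately limited. It handles only ``batch_dims`` entries of ``0`` or ``None``, it uses Python lists in place of arrays, and it reuses a broadcast argument for every example.

```python
# Hypothetical helper that splits a batched tap into per-example tuples.
# Assumes each batched argument has its batch dimension first (entry 0) and
# reuses broadcast arguments (entry None) for every example.
def unbatch(args, batch_dims):
  for d in batch_dims:
    assert d in (0, None), f'unsupported batch dimension {d}'
  sizes = {len(a) for a, d in zip(args, batch_dims) if d is not None}
  assert len(sizes) == 1, 'batched arguments must share one batch size'
  (size,) = sizes
  return [tuple(a[i] if d == 0 else a for a, d in zip(args, batch_dims))
          for i in range(size)]


# Shaped like the vmap example above: x = [0, 1, 2] and y = x * x, both
# batched along dimension 0.
per_example = unbatch(([0, 1, 2], [0, 1, 4]), (0, 0))
mixed = unbatch(([0, 1, 2], 10), (0, None))
```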

For :func:`jax.jvp` there will be two callbacks, one with the values of
the primals and one with the tangents::

  jax.jvp(power3, (3.,), (0.1,))
  # what: x,x^2 : [3., 9.]
  # transforms: ['jvp'] what: x,x^2 : [0.1, 0.6]
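
The tangent values in the jvp example follow from the product rule. For ``y = x * x`` the tangent is ``2 * x * dx``, so at ``x = 3.`` with ``dx = 0.1`` the tapped tangents are ``(0.1, 0.6)``. The dual-number sketch below reproduces these numbers in plain Python. It illustrates forward-mode differentiation in general; it is not how JAX implements :func:`jax.jvp`. It also records primals and tangents in a single tap, whereas the real module issues two callbacks.

```python
# Hypothetical dual-number sketch of forward-mode differentiation.
class Dual:
  def __init__(self, primal, tangent):
    self.primal = primal
    self.tangent = tangent

  def __mul__(self, other):
    # Product rule: d(a * b) = da * b + a * db
    return Dual(self.primal * other.primal,
                self.tangent * other.primal + self.primal * other.tangent)


def power3_tapped(x, taps):
  y = x * x
  taps.append((x, y))  # stands in for id_print((x, y), ...)
  return y * x


taps = []
out = power3_tapped(Dual(3.0, 0.1), taps)
tapped_primals = (taps[0][0].primal, taps[0][1].primal)
tapped_tangents = (taps[0][0].tangent, taps[0][1].tangent)
```

The tangent of ``y`` comes out as ``0.6`` up to floating-point rounding.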

For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the
values of the adjoints for the arguments. You may also see a callback with
the values of the primals from the forward pass, if those values are needed for
the backward pass::

  jax.grad(power3)(3.)
  # what: x,x^2 : [3., 9.]  # from forward pass, since y is used in backward pass
  # transforms: ['jvp', 'transpose'] what: x,x^2 : [0., 3.]  # from backward pass, adjoints of _, y
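
The adjoints ``[0., 3.]`` can be derived by hand. ``id_print((x, y))`` returns a pair whose first element is discarded, so that element's adjoint is zero. The second element is multiplied by ``x``, so its adjoint equals the value of ``x``, which is 3. The hypothetical function below spells out this backward pass in plain Python, without JAX. It also confirms that the total gradient equals ``3 * x**2``, which is 27 at ``x = 3.``.

```python
# Hand-derived backward pass for power3 (hypothetical sketch, no JAX).
# id_print((x, y)) returns a pair (x_out, y_out); power3 discards x_out and
# returns y_out * x.
def power3_vjp(x):
  y = x * x
  x_out, y_out = x, y           # identity, as id_print behaves
  out = y_out * x
  ct_out = 1.0                  # seed for grad of a scalar output
  ct_y_out = ct_out * x         # d(y_out * x) / d y_out = x
  ct_x_out = 0.0                # x_out is unused, so its adjoint is zero
  tapped_adjoints = [ct_x_out, ct_y_out]  # what the 'transpose' tap sees
  # Propagate back through the direct use of x, the identity, and y = x * x.
  ct_x = ct_out * y_out + ct_x_out + ct_y_out * 2 * x
  return out, tapped_adjoints, ct_x


out, adjoints, grad_x = power3_vjp(3.0)
```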

See documentation for :func:`id_tap` and :func:`id_print`.
For more usage examples, see tests/host_callback_test.py.

Still to do:

* Performance tests.
* Add flags for logging.
* Add unit tests with mocks.
* Explore a simpler API that uses Python program-order, instead of
  data dependency-order.
* Explore implementation with outside compilation.
* Explore an extended API that allows the host function to return
  values to the accelerator computation.

"""

from absl import logging
import atexit
import contextlib
import functools
import itertools

from jax import api
from jax import core
from jax import custom_derivatives
from jax import lax
from jax.lib import pytree
from jax.interpreters import ad, xla, batching, masking
from jax.interpreters import partial_eval as pe
from jax import pprint_util as ppu
from jax import source_info_util
from jax import util
from jaxlib import xla_client
from jaxlib import xla_extension

import numpy as np
import threading
import traceback
from typing import (Any, Callable, Dict, List, Optional, NamedTuple, Sequence,
                    Tuple, TypeVar, cast)
import typing
import warnings

xops = xla_client._xla.ops

# TODO(necula): fix mypy errors if I define the type aliases below
XlaOp = Any  # xla_extension.XlaOp
XlaShape = Any  # xla_client.Shape
XlaComputationBuilder = Any  # xla_bridge._JaxComputationBuilder
XlaDevice = Any  # xla_client.Device
XlaLocalClient = Any  # xla_extension.LocalClient

T = TypeVar('T')
U = TypeVar('U')
_Transforms = Sequence[Tuple[str, Dict[str, Any]]]
_TapFunc = Callable[[T, _Transforms], Any]


@typing.overload
def id_tap(tap_func: _TapFunc, arg: T) -> T:
  ...


@typing.overload
def id_tap(tap_func: _TapFunc, arg: T, *, result: U) -> U:
  ...


def id_tap(tap_func, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  value of the argument.

  Args:
    tap_func: tap function to call like ``tap_func(arg, transforms)``, with
      ``arg`` as described below and where ``transforms`` is the sequence of
      applied JAX transformations in the form ``(name, params)``.
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device to
      the host. If the ``result`` parameter is not specified then the return
      value of ``id_tap`` is ``arg``.

  Returns:
    ``arg``, or ``result`` if given.

  The order of execution is by data dependency: after all the arguments, and
  the value of ``result`` if present, are computed, and before the returned
  value is used. At least one of the returned values of ``id_tap`` must be
  used in the rest of the computation, or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing
  of the function::

    x = id_tap(tap_func, 42, result=x)

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the
  `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  if kwargs:
    warnings.warn(
        "Support for **kwargs in ``id_tap`` is deprecated and will be removed "
        "in the future. Instead, pre-apply keyword arguments, either by using "
        "a closure or by passing ``functools.partial(tap_func, **kwargs)`` "
        "instead.",
        FutureWarning, stacklevel=2)
    tap_func = functools.partial(tap_func, **kwargs)
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  for arg in flat_args:
    api._check_arg(arg)
  # See definition of id_tap_p for what parameters it takes
  params = {}
  params["tap_func_"] = tap_func
  params["arg_treedef_"] = arg_treedef
  params["nr_tapped_args_"] = len(flat_args)
  if result is not None:
    flat_results, result_treedef = pytree.flatten(result)
    for result in flat_results:
      api._check_arg(result)
    all_args = flat_args + flat_results
    nr_results = len(flat_results)
    flat_outs = id_tap_p.bind(*all_args, **params)  # Returns all_args
    flat_results = flat_outs[-nr_results:]  # type: ignore[unsupported-operands]
    return result_treedef.unflatten(flat_results)
  else:
    flat_outs = id_tap_p.bind(*flat_args, **params)
    return arg_treedef.unflatten(flat_outs)


def id_print(arg, *, result=None, output_stream=None, threshold=None, **kwargs):
  """Like :func:`id_tap` with a printing tap function.

  **Experimental: please give feedback, and expect changes!**

  On each invocation of the printing tap, the ``kwargs``, if present,
  will be printed first (sorted by keys). Then ``arg`` will be printed,
  with the arrays stringified with ``numpy.array2string``.

  See the :func:`id_tap` documentation.

  Additional keyword arguments:

  * ``output_stream`` if given then it will be used instead of the
    built-in ``print``. The string will be passed as
    ``output_stream.write(s)``.
  * ``threshold`` is passed to ``numpy.array2string``.
  """
  printer = functools.partial(
      _print_consumer,
      output_stream=output_stream,
      threshold=threshold,
      **kwargs,
  )
  return id_tap(printer, arg, result=result)


def _unpack_transform(name, *params):
  if name == "batch":
    return name, dict(batch_dims=params[0])
  elif name == "mask":
    return name, dict(logical_shapes=params[0])
  else:
    assert not params, f"{name}, {params}"
    return name, dict()


# A registry of outfeed consumers, used upon receiving outfeeds
class _ConsumerCallable(NamedTuple):
  """Host-side information for an outfeed consumer."""
  func: Callable
  transforms: Tuple[tuple, ...]
  arg_treedef: Any
  arg_shape: XlaShape  # XlaShape implements __hash__.

  def unpack_transforms(self) -> Tuple[Tuple[str, Dict[str, Any]], ...]:
    return tuple(_unpack_transform(*t) for t in self.transforms)


def _register_consumer(cons: _ConsumerCallable) -> int:
  """Registers a consumer and returns its ID, caching it by the hash of cons."""
  cons_id = _outfeed_receiver.consumer_registry.get(cons)
  if cons_id is not None:
    return cons_id
  cons_id = hash(cons) & 0xFFFFFFFC  # pybind11 has trouble here with large ints
  cons_id += 1  # Reserve the consumer ID 0
  # Check the ID-keyed registry; consumer_registry is keyed by consumer.
  assert cons_id not in _outfeed_receiver.consumer_registry_by_id, (
      "consumer id collision")
  _outfeed_receiver.consumer_registry[cons] = cons_id
  _outfeed_receiver.consumer_registry_by_id[cons_id] = cons
  return cons_id


def _print_consumer(
    arg, transforms, *, output_stream=None, threshold=1024, **kwargs):
  """The consumer for id_print.

  We provide this as a simple tapping function for printing.
  This is **experimental**, and we may not want to add many features to it;
  it should be easy for users to roll their own printing function.

  Args:
    output_stream: an object whose `write` method is called with the strings
      to be output.
    threshold: the value of numpy.array2string threshold parameter.
    **kwargs: all other keyword args are printed before printing `arg`.
  """

  def emit_str(s: str):
    if output_stream is not None:
      output_stream.write(s + "\n")
    else:
      print(s)

  if transforms:
    kwargs['transforms'] = [(name, params) if params else name
                            for name, params in transforms]
  kv_pairs = " ".join([
      f"{k}: {v}" for k, v in sorted(kwargs.items())
  ])
  if kv_pairs:
    emit_str(kv_pairs)

  def pp_val(arg) -> ppu.PrettyPrint:
    if isinstance(arg, (tuple, list)):
      return (
          ppu.pp("[ ") >> ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(" ]"))
    elif isinstance(arg, dict):
      return (ppu.pp("{ ") >> ppu.vcat([
          ppu.pp(f"{k}=") >> pp_val(v) for k, v in sorted(arg.items())
      ]) >> ppu.pp(" }"))
    elif isinstance(arg, np.ndarray):
      return ppu.pp(np.array2string(arg, threshold=threshold))
    else:
      return ppu.pp(str(arg))

  emit_str(str(pp_val(arg)))


"""The id_tap_p primitive acts like the identity function.

It has a number of positional arguments. The results of the primitive are
the positional arguments.

The primitive has the following parameters:

* has_token_: a boolean, when True it means that the last positional argument
  is the current token. In this case, the results of the primitive are the
  non-token positional arguments, followed by the updated token. The tokens
  and this parameter are added after all the JAX transformations, just before
  staging XLA.
* nr_tapped_args_: how many positional arguments from the head should be
  passed to the tap function. The remaining positional arguments are there
  for data dependency, for implementing the "result" feature, and for
  the current token.
* arg_treedef_: the treedef of the tapped positional arguments.
* tap_func_: the actual (Python) function to invoke with the tapped positional
  arguments (unflattened according to arg_treedef_) and the sequence of
  applied transforms. Any extra keyword arguments given to the id_tap API
  function have already been bound to it with functools.partial.
* transforms: a tuple of the transformations that have been applied. Each
  element of the tuple is itself a tuple with the first element the name
  of the transform. The remaining elements depend on the transform. For
  example, for `batch`, the parameters are the dimensions that have been
  batched, and for `mask` the logical shapes. These are unpacked by
  _ConsumerCallable before passing to the user function.
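
The token threading can be illustrated with a toy model. The equation format ``(primitive_name, invars, outvars)`` and the function ``thread_tokens`` are hypothetical and much simpler than ``core.Jaxpr``. The sketch only mirrors the loop in ``_rewrite_jaxpr``. Each outfeed-using equation gets the latest token as its last input and produces a new token as its last output, which matches the ``has_token_`` convention above.

```python
import itertools


# Toy model of token threading: outfeed-using equations are chained through
# tokens, so their relative order is fixed by data dependencies.
def thread_tokens(eqns, outfeed_prims):
  counter = itertools.count()
  last_token = f'tok{next(counter)}'  # the incoming (or created) token
  new_eqns = []
  for prim, invars, outvars in eqns:
    if prim not in outfeed_prims:
      new_eqns.append((prim, invars, outvars))
    else:
      new_token = f'tok{next(counter)}'
      new_eqns.append((prim, invars + [last_token], outvars + [new_token]))
      last_token = new_token
  return new_eqns, last_token


eqns = [('mul', ['x', 'x'], ['y']),
        ('id_tap', ['x', 'y'], ['x1', 'y1']),
        ('mul', ['y1', 'x'], ['z']),
        ('id_tap', ['z'], ['z1'])]
rewritten, final_token = thread_tokens(eqns, {'id_tap'})
```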
"""
id_tap_p = core.Primitive("id_tap")
id_tap_p.multiple_results = True
xla.outfeed_primitives.add(id_tap_p)


def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
  """Adds the `transform` to the params["transforms"].

  Uses a tuple representation internally, will be unpacked before the
  callback by _ConsumerCallable.
  """
  new_transform = (name, *transform_params)
  return dict(
      params, transforms=(params.get("transforms", ()) + (new_transform,)))


def _id_tap_impl(*arrays, **params):
  # We use the jitted-version of the primitive even for eager execution, both
  # so that we do not duplicate logic, but also so that all outfeed is received
  # by the outfeed_listeners, in the same thread from a given device. If we were
  # to process the tap here, it would be coming from the main thread. Also,
  # even in eager execution some primitives, such as while, are compiled.
  # It would be confusing to process a sequence "id_tap; while" in two
  # different threads.
  return xla.apply_primitive(id_tap_p, *arrays, **params)


id_tap_p.def_impl(_id_tap_impl)


def _id_tap_abstract_eval(*args_a: pe.AbstractValue, **params) \
    -> Sequence[pe.AbstractValue]:
  return args_a


id_tap_p.def_abstract_eval(_id_tap_abstract_eval)


# TODO(necula): there must be a better way to do this.
# The AttributeError is for regular values, the KeyError is for ConcreteArray
def _instantiate_zeros(arg, tan):
  """Turns special ``ad.Zero`` tangents into arrays of 0s."""
  if type(tan) is not ad.Zero:
    return tan

  try:
    aval = arg.aval
    return ad.instantiate_zeros_aval(aval, tan)
  except (AttributeError, KeyError):
    # We get here for regular Python values
    return ad.zeros_like_jaxval(arg)
|
|
|
|
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
def _id_tap_jvp_rule(primals, tangents, **params):
|
2020-05-08 17:18:11 +03:00
|
|
|
# Put primals through id_tap separately, so that partial evaluation
|
|
|
|
# can do its job when they are known (for grad)
|
2020-07-04 18:12:58 +03:00
|
|
|
out_primals = id_tap_p.bind(
|
|
|
|
*primals, **params)
|
2020-05-08 17:18:11 +03:00
|
|
|
# Add one primal output as untapped, to create data dependency.
|
|
|
|
tangent_zeros = tuple(map(_instantiate_zeros, primals, tangents))
|
2020-07-04 18:12:58 +03:00
|
|
|
out_tangents_extra = id_tap_p.bind(
|
|
|
|
*tangent_zeros,
|
|
|
|
out_primals[0],
|
|
|
|
**_add_transform(params, "jvp"))
|
2020-05-08 17:18:11 +03:00
|
|
|
return tuple(out_primals), tuple(out_tangents_extra[:-1])
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
ad.primitive_jvps[id_tap_p] = _id_tap_jvp_rule
|
|
|
|
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
def _id_tap_transpose_rule(cts, *args, **params):
|
2020-05-08 17:18:11 +03:00
|
|
|
assert len(cts) == len(args)
|
|
|
|
cts_zeros = tuple(map(_instantiate_zeros, args, cts))
|
2020-07-04 18:12:58 +03:00
|
|
|
ct_args = id_tap_p.bind(
|
|
|
|
*cts_zeros,
|
|
|
|
**_add_transform(params, "transpose"))
|
2020-05-08 17:18:11 +03:00
|
|
|
return ct_args
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
ad.primitive_transposes[id_tap_p] = _id_tap_transpose_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _id_tap_batching_rule(batched_args, batch_dims, **params):
|
2020-05-23 13:49:27 +03:00
|
|
|
new_params = _add_transform(params, "batch", batch_dims)
|
2020-05-08 17:18:11 +03:00
|
|
|
res = id_tap_p.bind(*batched_args, **new_params)
|
|
|
|
return res, batch_dims
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
batching.primitive_batchers[id_tap_p] = _id_tap_batching_rule
|
|
|
|
|
|
|
|
# def _id_tap_shape_rule(*operands, **params):
|
|
|
|
# return tuple([op.shape for op in operands])
|
|
|
|
|
|
|
|
|
|
|
|
def _id_tap_masking_rule(operands, operands_logical_shapes, **params):
|
2020-07-04 18:12:58 +03:00
|
|
|
new_params = _add_transform(params, "mask", operands_logical_shapes)
|
2020-05-08 17:18:11 +03:00
|
|
|
return id_tap_p.bind(*operands, **new_params)
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
masking.masking_rules[id_tap_p] = _id_tap_masking_rule
|
|
|
|
|
|
|
|
####
|
|
|
|
#### XLA compilation ####
|
|
|
|
####
|
2020-07-04 18:12:58 +03:00
|
|
|
|
|
|
|
|
|
|
|
def _id_tap_translation_rule(comp: XlaComputationBuilder,
|
|
|
|
*args_op: XlaOp,
|
|
|
|
tap_func_=None,
|
|
|
|
nr_tapped_args_,
|
|
|
|
arg_treedef_=None,
|
|
|
|
has_token_=False,
|
2020-09-14 02:47:28 -07:00
|
|
|
transforms=()):
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
# We expect the current token at the end, inserted by _rewrite_jaxpr.
|
2020-07-04 18:12:58 +03:00
|
|
|
assert has_token_
|
2020-05-08 17:18:11 +03:00
|
|
|
current_token = args_op[-1]
|
2020-07-04 18:12:58 +03:00
|
|
|
assert not comp.get_shape(current_token).is_array(), (
|
|
|
|
"The last argument must be a token")
|
|
|
|
|
|
|
|
args_to_outfeed = args_op[0:nr_tapped_args_]
|
|
|
|
consumer_id = _register_consumer(
|
2020-09-14 02:47:28 -07:00
|
|
|
_ConsumerCallable(tap_func_, transforms, arg_treedef_,
|
2020-07-04 18:12:58 +03:00
|
|
|
comp.get_shape(xops.Tuple(comp, args_to_outfeed))))
|
|
|
|
next_token = _outfeed_receiver.receiver.add_outfeed(comp, current_token,
|
|
|
|
consumer_id,
|
|
|
|
args_to_outfeed)
|
2020-05-08 17:18:11 +03:00
|
|
|
results = (args_op[:-1] + (next_token,))
|
|
|
|
return xops.Tuple(comp, results)
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
|
|
|
xla.translations[id_tap_p] = _id_tap_translation_rule


####
#### Jaxpr rewriting logic to thread the tokens through stateful primitives.
####


def _rewrite_closed_jaxpr(
    cjaxpr: core.ClosedJaxpr, has_input_token: bool,
    has_output_token: bool) -> core.ClosedJaxpr:
  """Rewrites a ClosedJaxpr to thread the token, if needed."""
  new_jaxpr = _rewrite_jaxpr(cjaxpr.jaxpr, has_input_token, has_output_token)
  return core.ClosedJaxpr(new_jaxpr, cjaxpr.consts)


def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Rewrite a Jaxpr to thread the token, if needed."""
  assert has_input_token or not has_output_token

  if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
    return jaxpr

  mk_new_var = core.gensym([jaxpr])

  eqns: List[core.JaxprEqn] = []
  last_token_var = mk_new_var(core.abstract_token)  # store the incoming token
  if has_input_token:
    invars = jaxpr.invars + [last_token_var]
  else:
    invars = jaxpr.invars
    eqns.append(
        core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var],
                           lax.create_token_p, {}, source_info_util.current()))

  for eqn in jaxpr.eqns:
    if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      eqns.append(eqn)
    else:
      output_token_var = mk_new_var(core.abstract_token)
      _rewrite_eqn(eqn, eqns, last_token_var, output_token_var, mk_new_var)
      last_token_var = output_token_var

  outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
  new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
  return new_jaxpr


def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
  """Rewrite an `eqn` and append equations to `eqns`.

  Assume that the current token is in `input_token_var` and the resulting
  token must end in `output_token_var`.
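
For intuition, a minimal pure-Python sketch of this token-threading contract,
assuming a toy equation representation: the hypothetical `ToyEqn` tuple below
(primitive name, input names, output names) is not JAX's `core.JaxprEqn`, and
treating only equations named 'id_tap' as side-effecting is a simplification.
Each side-effecting equation consumes the current token and produces a fresh
one, which becomes the current token for the next such equation.

```python
from typing import List, NamedTuple, Tuple


class ToyEqn(NamedTuple):
  # Hypothetical stand-in for core.JaxprEqn: primitive name, inputs, outputs.
  prim: str
  invars: Tuple[str, ...]
  outvars: Tuple[str, ...]


def thread_tokens(eqns: List[ToyEqn],
                  input_token: str) -> Tuple[List[ToyEqn], str]:
  # Appends the current token to the inputs of every side-effecting
  # equation and gives it a fresh output token, which becomes the current
  # token for the next side-effecting equation.
  new_eqns: List[ToyEqn] = []
  current_token = input_token
  num_tokens = 0
  for eqn in eqns:
    if eqn.prim != 'id_tap':
      new_eqns.append(eqn)  # pure equations are kept unchanged
      continue
    num_tokens += 1
    output_token = f'token{num_tokens}'
    new_eqns.append(ToyEqn(eqn.prim, eqn.invars + (current_token,),
                           eqn.outvars + (output_token,)))
    current_token = output_token
  return new_eqns, current_token


eqns = [ToyEqn('mul', ('x', 'x'), ('a',)),
        ToyEqn('id_tap', ('a',), ('b',)),
        ToyEqn('id_tap', ('b',), ('c',))]
rewritten, last_token = thread_tokens(eqns, 'token0')
for e in rewritten:
  print(e.prim, e.invars, '->', e.outvars)
# mul ('x', 'x') -> ('a',)
# id_tap ('a', 'token0') -> ('b', 'token1')
# id_tap ('b', 'token1') -> ('c', 'token2')
```

The rewrite in this module additionally recurses into `while`, `cond`, `scan`,
`xla_call` and the custom-derivative primitives, so that their inner jaxprs
also receive and return the token.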
  """
  if eqn.primitive is id_tap_p:
    assert "has_token_" not in eqn.params
    eqns.append(
        core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                           eqn.outvars + [output_token_var], eqn.primitive,
                           dict(eqn.params, has_token_=True),
                           eqn.source_info))
  elif eqn.primitive is lax.while_p:
    cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      _rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
                                  mk_new_var)
      return

    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(
                eqn.params,
                body_jaxpr=_rewrite_closed_jaxpr(body_jaxpr, True, True),
                cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True,
                                                 False)), eqn.source_info))
  elif eqn.primitive is lax.cond_p:
    branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
    index, *operands = eqn.invars
    new_invars = [index, *operands, input_token_var]
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                branches=tuple(
                    _rewrite_closed_jaxpr(jaxpr, True, True)
                    for jaxpr in branches),
                linear=(*linear, False)), eqn.source_info))
  elif eqn.primitive is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
         "unroll"])
    # We add the token right at the end of carry
    nr_const_and_carry = num_consts + num_carry
    new_invars = eqn.invars[0:nr_const_and_carry] + [
        input_token_var
    ] + eqn.invars[nr_const_and_carry:]
    new_jaxpr = _rewrite_closed_jaxpr(carry_jaxpr, True, True)
    # The rewrite has put the token at the very end of the invars and outvars;
    # move it to the end of the carry.
    new_jaxpr_invars = new_jaxpr.jaxpr.invars
    new_jaxpr_invars = (
        new_jaxpr_invars[0:nr_const_and_carry] + [new_jaxpr_invars[-1]] +
        new_jaxpr_invars[nr_const_and_carry:-1])
    new_jaxpr.jaxpr.invars = new_jaxpr_invars

    new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
    new_jaxpr_outvars = (
        new_jaxpr_outvars[0:num_carry] + [new_jaxpr_outvars[-1]] +
        new_jaxpr_outvars[num_carry:-1])
    new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars,
            # Output token is at the end of carry result
            eqn.outvars[0:num_carry] + [output_token_var] +
            eqn.outvars[num_carry:],
            eqn.primitive,
            dict(
                eqn.params,
                jaxpr=new_jaxpr,
                num_carry=num_carry + 1,
                linear=linear + (False,)),
            eqn.source_info))
  elif eqn.primitive is xla.xla_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(
                eqn.params,
                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True,
                                          True),
                donated_invars=eqn.params["donated_invars"] + (False,)
            ),
            eqn.source_info))
  elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    new_invars = [*eqn.invars, input_token_var]
    def unreachable_thunk():
      assert False, "Should not be reached"
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
                jvp_jaxpr_thunk=unreachable_thunk
            ),
            eqn.source_info))
  elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    new_invars = [*eqn.invars, input_token_var]
    def unreachable_thunk():
      assert False, "Should not be reached"
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
                fwd_jaxpr_thunk=unreachable_thunk,
                # The following are illegal values for these parameters. They
                # are not needed, because this rewrite runs just before
                # compilation to XLA, which does not use those parameters.
                bwd="illegal param",
                out_trees="illegal param"
            ),
            eqn.source_info))
  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")


def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                                input_token_var: core.Var,
                                output_token_var: core.Var,
                                mk_new_var: Callable):
  """Rewrite a while whose cond has outfeed"""
  cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
      eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
  transformed_cond_jaxpr = _rewrite_closed_jaxpr(cond_jaxpr, True, True)
  carry_invars = eqn.invars[cond_nconsts + body_nconsts:]
  # pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token)
  pred1_and_token1 = [
      mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars
  ]
  eqns.append(
      core.new_jaxpr_eqn(
          eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var],
          pred1_and_token1, xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_before",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info))
  # Make a new cond "lambda pred, carry, token: pred"
  new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
  new_cond_invars = ([new_cond_pred_invar] +
                     [mk_new_var(cv.aval) for cv in carry_invars] +
                     [mk_new_var(core.abstract_token)])
  new_cond_jaxpr = core.ClosedJaxpr(
      core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])
  # Make a new body:
  #   "lambda cond_constvars, body_constvars, pred, carry, token:
  #        carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
  #        pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
  #        (pred2, carry2, token3)
  transformed_body_jaxpr = _rewrite_closed_jaxpr(body_jaxpr, True, True)
  new_body_invars_cond_constvars = [
      mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]
  ]
  new_body_invars_body_constvars = [
      mk_new_var(v.aval)
      for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts]
  ]
  new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_invars_token = mk_new_var(core.abstract_token)

  new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_token2 = mk_new_var(core.abstract_token)
  new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_token3 = mk_new_var(core.abstract_token)

  new_body_eqns = [
      core.new_jaxpr_eqn(
          new_body_invars_body_constvars + new_body_invars_carry +
          [new_body_invars_token], new_body_carry2 + [new_body_token2],
          xla.xla_call_p,
          dict(
              call_jaxpr=transformed_body_jaxpr.jaxpr,
              name="body",
              donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)),
          eqn.source_info),
      core.new_jaxpr_eqn(
          new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
          [new_body_pred2, new_body_token3], xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_body",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info)
  ]
  new_body_jaxpr = core.ClosedJaxpr(
      core.Jaxpr([], (new_body_invars_cond_constvars +
                      new_body_invars_body_constvars + [new_body_invars_pred] +
                      new_body_invars_carry + [new_body_invars_token]),
                 ([new_body_pred2] + new_body_carry2 + [new_body_token3]),
                 new_body_eqns), [])

  pred_out = mk_new_var(cond_jaxpr.out_avals[0])
  eqns.append(
      core.new_jaxpr_eqn(
          (eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] +
           carry_invars + [pred1_and_token1[1]]),
          ([pred_out] + eqn.outvars + [output_token_var]), lax.while_p,
          dict(
              cond_jaxpr=new_cond_jaxpr,
              cond_nconsts=0,
              body_jaxpr=new_body_jaxpr,
              body_nconsts=cond_nconsts + body_nconsts), eqn.source_info))


xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False)


class TapFunctionException(Exception):
  """Signals that some tap function had exceptions.

  Raised by :func:`barrier_wait`, and therefore also on exit from the
  deprecated :func:`outfeed_receiver` context manager.
  """
  pass


@contextlib.contextmanager
def outfeed_receiver():
  """Implements a barrier after a block of code.

  DEPRECATED:
  This function is no longer necessary; it is kept for backwards
  compatibility. At the moment it implements a ``barrier_wait`` after the body
  of the context manager finishes.
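
A minimal sketch of this deprecation pattern, assuming a hypothetical
`barrier` callable in place of `barrier_wait` and a hypothetical
`deprecated_receiver` in place of this function: the generator-based context
manager warns on entry and runs the barrier in a `finally` clause, so the
barrier also runs when the body exits with an error.

```python
import contextlib
import warnings

calls = []


def barrier():
  # Hypothetical stand-in for barrier_wait.
  calls.append('barrier')


@contextlib.contextmanager
def deprecated_receiver():
  warnings.warn('deprecated_receiver is unnecessary; use barrier() instead',
                DeprecationWarning)
  try:
    yield
  finally:
    barrier()  # runs whether or not the body raised


with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  with deprecated_receiver():
    calls.append('body')

print(calls)               # ['body', 'barrier']
print(caught[0].category)  # <class 'DeprecationWarning'>
```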
  """
  warnings.warn(
      "outfeed_receiver is unnecessary and deprecated. In the latest "
      "version the outfeed receiver mechanism is started automatically. Use "
      "barrier_wait if instead you want to wait for outfeeds after "
      "a computation", DeprecationWarning)
  _initialize_outfeed_receiver()
  # We will deprecate the outfeed_receiver context manager, but for now
  # we just turn it into a barrier.
  try:
    yield
  finally:
    # We put a barrier, which will also raise the TapFunctionException
    barrier_wait()


# For now we keep a single outfeed receiver
class _OutfeedReceiverData:
  """Keep track of the outfeed receiver data."""
  receiver: Any
  lock: threading.Lock
  num_tap_exceptions: int
  clients: Tuple[XlaLocalClient, ...]
  devices: Tuple[XlaDevice, ...]
  consumer_registry: Dict[_ConsumerCallable, int]
  consumer_registry_by_id: Dict[int, _ConsumerCallable]

  def __init__(self):
    self.receiver = None  # Initialize lazily, when first needed
    self.lock = threading.Lock()
    self.num_tap_exceptions = 0
    self.clients = ()
    self.devices = ()
    # The consumer registries must stay alive for the lifetime of the program,
    # because we may have cached compilations that embed consumer ids, and we
    # do not want an id reused for other shapes.
    self.consumer_registry = dict()
    self.consumer_registry_by_id = dict()

  def stop(self):
    """Wait for all pending outfeeds and stop the receiver."""
    self.receiver = None  # GC will trigger the destructor
    self.clients = ()
    self.devices = ()
    # Do not clear the consumer registries.


_outfeed_receiver = _OutfeedReceiverData()


# This function is called from C++; it must not allow exceptions through.
def _outfeed_receiver_callback(device, consumer_id, arrays):
  #logging.vlog(
  #  2, f"Outfeed received on device {device} for consumer {consumer_id} " +
  #  (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
  consumer = _outfeed_receiver.consumer_registry_by_id.get(consumer_id)
  assert consumer is not None, "We should have crashed in the runtime"
  try:
    arg = api.tree_unflatten(consumer.arg_treedef, arrays)
    consumer.func(arg, consumer.unpack_transforms())  # type: ignore[attribute-error]
  except Exception as e:
    if isinstance(e, TypeError):
      logging.error("The signature that host_callback.id_tap uses to call "
                    "wrapped functions has changed: ``transforms`` was "
                    "previously passed as a keyword argument, but is now "
                    "passed by position.")
    logging.error("Postponing exception raised in tap function: %s\n%s", str(e),
                  traceback.format_exc())
    _outfeed_receiver.num_tap_exceptions += 1
    return


def _initialize_outfeed_receiver(
    clients: Optional[List[XlaLocalClient]] = None,
    max_callback_queue_size_bytes: int = int(256 * 1e6)):
  """Creates and starts the outfeed_receiver.

  This function is called lazily, only when we compile an id_tap.

  Args:
    * clients: the list of clients (backends) on whose devices to listen.
    * max_callback_queue_size_bytes: an optional integer to bound the maximum
      size of arrays in the callback queue. When this limit is reached the
      device listener pauses.
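
A minimal sketch of the lazy, idempotent start-up used here, assuming a
hypothetical `LazyReceiver` class and a hypothetical `start_fn` in place of
`_OutfeedReceiverData` and `outfeed_receiver_module.start`. The check and the
creation happen under one lock, so concurrent first calls start a single
receiver; clearing the field (as `_OutfeedReceiverData.stop` does) lets the
next call start a fresh one.

```python
import threading


class LazyReceiver:
  # Hypothetical stand-in for _OutfeedReceiverData plus this function.

  def __init__(self, start_fn):
    self._start_fn = start_fn  # hypothetical factory for the receiver
    self._lock = threading.Lock()
    self.receiver = None  # created lazily, on first use

  def initialize(self):
    # Check and create under the same lock, so racing callers start
    # exactly one receiver.
    with self._lock:
      if self.receiver is not None:
        return
      self.receiver = self._start_fn()

  def stop(self):
    self.receiver = None  # the next initialize() starts a fresh receiver


starts = []


def start_fn():
  starts.append('start')
  return object()


lazy = LazyReceiver(start_fn)
threads = [threading.Thread(target=lazy.initialize) for _ in range(8)]
for t in threads:
  t.start()
for t in threads:
  t.join()
print(len(starts))  # 1
lazy.stop()
lazy.initialize()
print(len(starts))  # 2
```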
  """
  try:
    outfeed_receiver_module = xla_extension.outfeed_receiver
  except AttributeError:
    raise NotImplementedError(
        "id_tap works only with jaxlib version 0.1.51 and higher")

  with _outfeed_receiver.lock:
    if _outfeed_receiver.receiver is not None:
      return

    if clients is None:
      # By default, all devices on all backends
      clients = xla_client._get_local_backends().values()  # type: ignore[protected-class]
      # Drop the interpreter clients
      clients = tuple([c for c in clients if c.platform != "interpreter"])  # type: ignore
    devices = list(itertools.chain(*[backend.devices() for backend in clients]))
    _outfeed_receiver.clients = clients  # type: ignore[assignment]
    _outfeed_receiver.devices = devices  # type: ignore[assignment]
    logging.vlog(
        2, f"Starting outfeed_receiver for {[str(d) for d in devices]}. "
        f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
    _outfeed_receiver.receiver = outfeed_receiver_module.start(
        _outfeed_receiver_callback, tuple(clients),
        max_callback_queue_size_bytes)

    def exit_handler():
      # Prevent logging during compilation; it gives errors under pytest.
      xla._on_exit = True
      logging.vlog(2, "Barrier wait atexit")
      barrier_wait()

    atexit.register(exit_handler)  # We wait as long as we have callbacks


def barrier_wait():
  """Blocks the calling thread until all current outfeed is processed.

  Waits until all outfeed from computations already running on all devices
  has been received and processed by the Python callbacks. Raises
  TapFunctionException if there were exceptions while processing the callbacks.

  This works by enqueueing a special tap computation to all devices to which
  we are listening for outfeed. Once all those tap computations are done, we
  return from barrier_wait.

  Note: If any of the devices are busy and cannot accept new computations,
  this will deadlock.
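
A minimal pure-Python sketch of the counting barrier described above, assuming
plain threads running a hypothetical `fake_device` loop in place of devices,
and a hypothetical `barrier_wait_sketch` in place of this function. One
callback is queued behind the pending work of each worker; each callback
decrements a counter under the lock and notifies the condition variable, and
the caller waits until the counter reaches zero. As with the real barrier, the
wait only covers work queued before the barrier callback.

```python
import queue
import threading


def barrier_wait_sketch(work_queues):
  # Counting barrier: one callback per worker, each queued behind that
  # worker's pending work; returns once every callback has run.
  lock = threading.Lock()
  cv = threading.Condition(lock=lock)
  num_at_large = len(work_queues)  # protected by lock

  def barrier_tap():
    nonlocal num_at_large
    with lock:
      num_at_large -= 1
      cv.notify()

  for q in work_queues:
    q.put(barrier_tap)
  with lock:
    cv.wait_for(lambda: num_at_large == 0)


def fake_device(q):
  # Hypothetical device loop: runs queued callables in FIFO order until None.
  while True:
    item = q.get()
    if item is None:
      return
    item()


processed = []
queues = [queue.Queue() for _ in range(3)]
workers = [threading.Thread(target=fake_device, args=(q,)) for q in queues]
for w in workers:
  w.start()
for i, q in enumerate(queues):
  q.put(lambda i=i: processed.append(i))  # pending 'outfeed' work
barrier_wait_sketch(queues)
print(sorted(processed))  # [0, 1, 2]
for q in queues:
  q.put(None)  # stop the workers
for w in workers:
  w.join()
```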
  """
  logging.vlog(2, "barrier_wait: start")
  if not _outfeed_receiver.receiver:
    logging.vlog(2, "barrier_wait: receiver not started")
    return

  lock = threading.Lock()
  cv = threading.Condition(lock=lock)
  num_at_large = len(_outfeed_receiver.devices)  # Protected by lock

  def barrier_tap(dev_idx, _):
    nonlocal num_at_large
    logging.vlog(
        2, f"barrier_wait: thread {threading.current_thread()} for "
        f"device {_outfeed_receiver.devices[dev_idx]} at barrier_tap")
    with lock:
      num_at_large -= 1
      cv.notify()

  for d_idx, d in enumerate(_outfeed_receiver.devices):
    logging.vlog(2, f"barrier_wait: enqueueing barrier on device {d}")
    x_on_dev = api.device_put(d_idx, device=d)
    api.jit(lambda x: id_tap(barrier_tap, x), device=d)(x_on_dev)
  logging.vlog(2, "barrier_wait: waiting for callbacks")
  with lock:
    cv.wait_for(lambda: num_at_large == 0)
  logging.vlog(2, "Done barrier_wait")
  if _outfeed_receiver.num_tap_exceptions > 0:
    _outfeed_receiver.num_tap_exceptions = 0
    raise TapFunctionException(
        "There were exceptions during id_tap processing.")


def stop_outfeed_receiver():
  """Stops the outfeed receiver runtime.

  This waits for all outfeeds from computations already running on all devices,
  and then stops the outfeed receiver runtime. The runtime will be restarted
  next time you use a tap function.

  It should not be necessary to use this function, unless you want to start
  using lax.outfeed directly after having used host callbacks.
  """
  _outfeed_receiver.stop()