2020-05-08 17:18:11 +03:00
|
|
|
# Copyright 2020 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2020-10-17 11:15:51 +03:00
|
|
|
"""Primitives for calling from JAX accelerator code to Python functions on the host.
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
**Experimental: please give feedback, and expect changes.**
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
This module introduces the host callback functions :func:`call`,
|
|
|
|
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
|
|
|
|
to the host and invoke user-defined Python functions on the host, optionally
|
|
|
|
returning results back to the device computation.
|
|
|
|
|
|
|
|
We show below how these functions can be used. We start with :func:`call`,
|
|
|
|
and we discuss examples of calling from JAX to NumPy CPU custom kernels,
|
|
|
|
or to TensorFlow functions, or to JAX running on another device. In the latter
|
|
|
|
two cases we show how we can support JAX autodiff for the host callbacks,
|
|
|
|
by deferring to the reverse-mode AD on the target platform. Then we
|
|
|
|
show uses of :func:`id_tap` and :func:`id_print`, which have the restriction
|
|
|
|
that they cannot return values from the host to the device.
|
|
|
|
These primitives are generally faster than :func:`call`
|
|
|
|
because they are executed asynchronously with the device code and they also
|
|
|
|
support the whole spectrum of JAX transformations. In particular, they can be
|
|
|
|
used to tap into and to debug JAX-transformed code.
|
|
|
|
|
|
|
|
Using :func:`call` to call a host function and return results to device
|
|
|
|
-----------------------------------------------------------------------
|
|
|
|
|
|
|
|
Use :func:`call` to invoke a computation on the host and return
|
|
|
|
NumPy arrays to the device computation.
|
|
|
|
Host computation is useful, e.g., when a device computation needs some data
|
|
|
|
that requires I/O on the host, or it needs a library that is available on the
|
|
|
|
host and you do not want to code it in JAX.
|
2021-01-05 10:31:26 +02:00
|
|
|
For example, eigen decomposition for general matrices in JAX does not work on TPU.
|
2020-10-17 11:15:51 +03:00
|
|
|
We can call the NumPy implementation from any JAX accelerator computation,
|
|
|
|
using a host computation::
|
|
|
|
|
|
|
|
# This function runs on the host
|
|
|
|
def host_eig(m: np.ndarray) -> np.ndarray:
|
|
|
|
return np.linalg.eigvals(m)
|
|
|
|
|
|
|
|
# This function is used in JAX
|
|
|
|
def device_fun(m):
|
|
|
|
# We send "m" to the host, asking it to call "host_eig" and return the result.
|
|
|
|
# We have to specify the result shape and dtype, either in the form of an
|
|
|
|
# example return value or any object that has `shape` and `dtype` attributes,
|
|
|
|
# e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
|
|
|
|
return hcb.call(host_eig, m,
|
|
|
|
# Given an input of shape (..., d, d), eig output has shape (..., d)
|
|
|
|
result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
|
|
|
|
|
|
|
|
|
|
|
|
The :func:`call` function and the Python host function both take a single argument
|
|
|
|
and return a single result, but those can be pytrees. Note that we must tell
|
|
|
|
the :func:`call` what shape and dtype to expect from the host invocation, using
|
|
|
|
the ``result_shape`` kwarg.
|
|
|
|
This is important because the device code is compiled with that expectation.
|
|
|
|
There will be an error raised at runtime if the actual invocation produces a
|
|
|
|
different result shape. In general, **such errors and also exceptions raised
|
|
|
|
by the host computation may be difficult to debug**. See the Debugging section
|
|
|
|
below.
|
|
|
|
This is a problem for :func:`call` but not for :func:`id_tap`.
|
|
|
|
|
|
|
|
The :func:`call` API can be used inside a jit or pmap computation or inside
|
|
|
|
cond/scan/while control flow. When used inside :func:`jax.pmap`, there will be
|
|
|
|
separate calls to the host from each of the participating devices::
|
|
|
|
|
|
|
|
def host_sin(x, *, device):
|
|
|
|
print(f"Invoking host_sin with {x.shape} on {device}")
|
|
|
|
return np.sin(x)
|
|
|
|
|
|
|
|
# Use pmap to run the computation on two devices
|
|
|
|
jax.pmap(lambda x: hcb.call(host_sin, x,
|
|
|
|
result_shape=x,
|
|
|
|
# Ask that the `host_sin` function be passed `device=dev`
|
|
|
|
call_with_device=True))(
|
|
|
|
np.ones((2, 4), dtype=np.float32))
|
|
|
|
|
|
|
|
# prints (in arbitrary order)
|
|
|
|
# Invoking host_sin with (4,) on cpu:0
|
|
|
|
# Invoking host_sin with (4,) on cpu:1
|
|
|
|
|
|
|
|
Note that :func:`call` does not (yet) support any JAX transformations, but as we
|
|
|
|
show in the next section one can make use of the
|
|
|
|
existing support for `Custom differentiation in JAX <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_.
|
|
|
|
|
|
|
|
Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff support
|
|
|
|
------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
Another possible use for host computation is to invoke a library written for
|
|
|
|
another framework, such as TensorFlow.
|
|
|
|
In this case it becomes interesting to support JAX autodiff for host callbacks
|
|
|
|
by deferring to the autodiff mechanism in TensorFlow,
|
|
|
|
using the :func:`jax.custom_vjp` mechanism.
|
|
|
|
|
|
|
|
This is relatively easy to do, once one understands both the JAX custom VJP
|
|
|
|
and the TensorFlow autodiff mechanisms.
|
|
|
|
The code for how this can be done is shown in the ``call_tf_full_ad``
|
|
|
|
function in `host_callback_to_tf_test.py <https://github.com/google/jax/blob/master/tests/host_callback_to_tf_test.py>`_.
|
|
|
|
This example supports arbitrary higher-order differentiation as well.
|
|
|
|
|
|
|
|
Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support
|
|
|
|
------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
It should not be surprising that we can use host computation to invoke a JAX
|
|
|
|
computation on another device. The arguments are sent from the accelerator to
|
|
|
|
the host, and then to the outside device on which the JAX host
|
|
|
|
computation will run, and then the results are sent back to the original accelerator.
|
|
|
|
|
|
|
|
The code for how this can be done is shown in the ``call_jax_other_device`` function
|
|
|
|
in `host_callback_test.py <https://github.com/google/jax/blob/master/tests/host_callback_test.py>`_.
|
|
|
|
|
|
|
|
Using :func:`id_tap` to call a JAX function on another device, with no returned values, but full JAX transformation support
|
|
|
|
---------------------------------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
The :func:`id_tap` and :func:`id_print` behave like the identity function but have the
|
2020-09-23 13:14:36 +03:00
|
|
|
side-effect of sending the arguments from the device to the host and
|
2020-07-04 18:12:58 +03:00
|
|
|
invoking a user-specified Python function (for :func:`id_tap`) or printing the
|
2020-09-23 13:14:36 +03:00
|
|
|
arguments on the host (for :func:`id_print`). The Python function passed
|
2020-12-13 10:44:20 +02:00
|
|
|
to :func:`id_tap` takes two positional arguments (the value tapped
|
|
|
|
from the device computation along with ``transforms`` sequence,
|
|
|
|
described below). Optionally, the function may be passed a keyword argument
|
|
|
|
``device`` with the Device from which the value was tapped.
|
|
|
|
|
2020-09-23 13:14:36 +03:00
|
|
|
A few examples::
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-09-23 13:14:36 +03:00
|
|
|
# calls func(2x, []) on host and returns 2x
|
2020-07-04 18:12:58 +03:00
|
|
|
y = id_tap(func, 2 * x)
|
2020-09-23 13:14:36 +03:00
|
|
|
# calls func((2x, 3x), []) and returns (2x, 3x)
|
2020-07-04 18:12:58 +03:00
|
|
|
y, z = id_tap(func, (2 * x, 3 * x)) # The argument can be a pytree
|
2020-09-23 13:14:36 +03:00
|
|
|
# calls func(2x, []) and returns y
|
|
|
|
y = id_tap(func, 2 * x, result=y) # override the result of id_tap
|
2020-12-13 10:44:20 +02:00
|
|
|
# calls func(2x, [], device=jax.devices()[0])
|
|
|
|
y = id_tap(func, 2 * x, tap_with_device=True) # Pass the device to the tap
|
2020-09-23 13:14:36 +03:00
|
|
|
# calls func(2x, [], what='activation') and returns 2x
|
|
|
|
y = id_tap(functools.partial(func, what='activation'), 2 * x)
|
|
|
|
# calls func(dict(x=x, y=y), what='data') and returns dict(x=x, y=y)
|
|
|
|
x, y = id_tap(lambda tap, transforms: func(tap, what='data'), dict(x=x, y=y))
|
|
|
|
|
|
|
|
The above examples can all be adapted to use :func:`id_print` instead, with
|
|
|
|
the difference that :func:`id_print` takes one positional argument (to print
|
|
|
|
on the host), the optional kwarg ``result``, and possibly additional kwargs
|
|
|
|
that are also printed along with the automatic kwarg ``transforms``.
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
The order of execution of the callback functions is constrained by data dependency:
|
2020-09-24 14:24:02 +03:00
|
|
|
the arguments are tapped after all the arguments are computed and before the
|
2020-12-13 10:44:20 +02:00
|
|
|
result of the call is used. As of September 2020, it is not strictly necessary
|
|
|
|
anymore for the results of the tap to be used in the rest of the computation.
|
|
|
|
You can just do::
|
|
|
|
|
|
|
|
id_tap(func, x)
|
|
|
|
|
|
|
|
The tap function will execute based on program order. However, if this code
|
|
|
|
is subject to transformations, it is possible for the tap to appear to
|
2020-10-17 11:15:51 +03:00
|
|
|
the transformation as dead code and to be removed from the computation. In
|
|
|
|
that case it is best to use the result of the callback.
|
2020-12-13 10:44:20 +02:00
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
Behavior under JAX transformations
|
|
|
|
----------------------------------
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
We describe the behaviour under transformations for :func:`id_tap` and
|
|
|
|
:func:`id_print` in the context of the
|
2020-05-08 17:18:11 +03:00
|
|
|
following function definition::
|
|
|
|
|
|
|
|
def power3(x):
|
|
|
|
y = x * x
|
2020-09-23 13:14:36 +03:00
|
|
|
_, y = id_print((x, y), what="x,x^2") # Must pack multiple arguments
|
2020-05-08 17:18:11 +03:00
|
|
|
return y * x
|
|
|
|
|
2020-09-23 13:14:36 +03:00
|
|
|
power3(3.)
|
|
|
|
# what: x,x^2 : [3., 9.]
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
During JAX transformations the special parameter ``transforms`` is added to
|
2020-09-23 13:14:36 +03:00
|
|
|
contain a list of transformation descriptors in the form
|
|
|
|
``(transform_name, transform_params)``.
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-05-23 13:49:27 +03:00
|
|
|
For :func:`jax.vmap` the arguments are batched, and ``transforms`` is extended
|
|
|
|
with transformation name ``batch`` and ``batch_dims`` set to the tuple of
|
|
|
|
batched dimensions (one entry per argument, ``None`` denotes an argument that
|
|
|
|
was broadcast)::
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
jax.vmap(power3)(np.arange(3.))
|
2020-09-23 13:14:36 +03:00
|
|
|
# transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2 : [[0, 1, 2], [0, 1,
|
|
|
|
4]]
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-05-23 13:49:27 +03:00
|
|
|
For :func:`jax.jvp` there will be two callbacks, one with the values of
|
|
|
|
the primals and one with the tangents::
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
jax.jvp(power3, (3.,), (0.1,))
|
2020-09-23 13:14:36 +03:00
|
|
|
# what: x,x^2: [3., 9.]
|
|
|
|
# transforms: ['jvp'] what: x,x^2 : [0.1, 0.6]
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-05-23 13:49:27 +03:00
|
|
|
For :func:`jax.vjp` or :func:`jax.grad` there will be one callback with the
|
|
|
|
values of the adjoints for the arguments. You may also see a callback with
|
|
|
|
the values of the primals from the forward pass, if those values are needed for
|
|
|
|
the backward pass::
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
jax.grad(power3)(3.)
|
2020-09-23 13:14:36 +03:00
|
|
|
# what=x,x^2: [3., 9.] # from forward pass, since y is used in backward pass
|
|
|
|
# transforms: ['jvp', 'transpose'] what: x,x^2 : [0., 3.] # from backward pass, adjoints of _, y
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
In presence of :func:`jax.pmap` the code will run on multiple devices and
|
|
|
|
each device will tap its values independently.
|
|
|
|
It may be helpful to use the ``tap_with_device`` option for :func:`id_print`
|
|
|
|
or :func:`id_tap`, so that you see which device is sending which data::
|
|
|
|
|
|
|
|
jax.pmap(power3, devices=jax.devices()[0:2])(np.array([3., 4.]))
|
|
|
|
# device=cpu:0 what=x,x^2: [3., 9.] # from the first device
|
|
|
|
# device=cpu:1 what=x,x^2: [4., 16.] # from the second device
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
See documentation for :func:`id_tap` and :func:`id_print`.
|
2020-09-23 13:14:36 +03:00
|
|
|
For more usage examples, see tests/host_callback_test.py.
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-09-25 15:28:23 +03:00
|
|
|
Low-level details and debugging
|
|
|
|
-------------------------------
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
The host callback functions will be executed for each device in the order in which
|
|
|
|
the send operations were performed on the device.
|
|
|
|
|
|
|
|
The host callback functions for multiple devices may be interleaved.
|
|
|
|
The data from the devices is received by separate threads managed by the JAX
|
|
|
|
runtime (one thread per device). The runtime maintains a buffer of
|
|
|
|
configurable size. When the buffer is full, all the receiving threads are paused
|
|
|
|
which eventually pauses the computation on devices. The runtime has one
|
|
|
|
additional thread that invokes the Python user functions with the received data.
|
|
|
|
If the processing of the callbacks is slow, it may actually lead to the runtime
|
|
|
|
buffer filling up, and eventually pausing the computation on the devices
|
|
|
|
when they need to send something. For more details on the outfeed receiver
|
|
|
|
runtime mechanism see
|
|
|
|
`runtime code
|
|
|
|
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
|
|
|
|
|
|
|
|
In order to pause the execution until all data from computations already
|
|
|
|
started on devices has arrived and has been processed, use :func:`barrier_wait`.
|
|
|
|
Note that this is needed only for :func:`id_tap` and :func:`id_print`, which
|
|
|
|
are processed asynchronously with the device computation.
|
|
|
|
|
|
|
|
Exceptions from the user-defined callback functions are logged along with their
|
|
|
|
stack traces, but the receiving threads are not stopped. Instead the last
|
|
|
|
exception is recorded and the subsequent :func:`barrier_wait` will
|
2021-01-05 10:51:32 +02:00
|
|
|
raise :exc:`CallbackException` if any exception had occurred
|
2020-10-17 11:15:51 +03:00
|
|
|
in one of the tap functions. This exception will include the text and the
|
|
|
|
stack trace of the last exception encountered.
|
|
|
|
|
|
|
|
One further complication arises for callback functions that must return
|
|
|
|
results to the call origin device. In order to avoid the device computation
|
|
|
|
being stuck waiting for a result that will never arrive, in case of any
|
|
|
|
error during the processing of the callback (whether raised by the user-code
|
|
|
|
itself or due to a mismatch of the returned value and the expected return_shape)
|
|
|
|
we send the device a "fake" result of shape ``int8[12345]``. This will make the device
|
|
|
|
computation abort because the received data is different than the one that
|
|
|
|
it expects. On CPU the runtime will crash with a distinctive error message:
|
|
|
|
|
|
|
|
```
|
|
|
|
Check failed: buffer->length() == buffer_length (12345 vs. ...)
|
|
|
|
```
|
|
|
|
|
|
|
|
On GPU, the failure is more user-friendly and will be surfaced to the Python
|
|
|
|
program as:
|
|
|
|
|
|
|
|
```
|
|
|
|
RET_CHECK failure ... Mismatch between infeed source buffer shape s8[12345] ...
|
|
|
|
```
|
|
|
|
|
|
|
|
On TPU, there is currently no shape check for infeed, so we take the safer
|
|
|
|
route to not send anything in case of errors, and let the computation hang.
|
|
|
|
|
|
|
|
The current implementation uses the outfeed mechanism provided by XLA. The
|
|
|
|
mechanism itself is quite primitive in the sense that a receiver must know
|
|
|
|
exactly the shape of each incoming packet, and how many packets are expected.
|
|
|
|
This makes it hard to use for multiple kinds of data in the same computation,
|
|
|
|
and it is practically impossible to use it under conditionals or in loops
|
|
|
|
of non-constant iteration count. Furthermore, code that uses the outfeed
|
|
|
|
mechanism directly cannot be transformed by JAX. All these limitations are
|
|
|
|
addressed by the host callback functions. The tapping API introduced here
|
|
|
|
makes it easy to share the outfeed mechanism for multiple purposes, while
|
|
|
|
supporting all transformations.
|
|
|
|
|
|
|
|
**Note that after you have used the host callback functions, you cannot
|
|
|
|
use lax.outfeed directly**. You may want to :func:`stop_outfeed_receiver`
|
|
|
|
if you later need to use lax.outfeed.
|
|
|
|
|
|
|
|
Since the actual calls to your callback functions are made from the C++
|
|
|
|
receiver, it may be hard to debug the calls. In particular, the stack trace
|
|
|
|
will not include the calling code. You can use the flag
|
|
|
|
``jax_inline_host_callback`` (or the environment variable
|
|
|
|
``JAX_INLINE_HOST_CALLBACK``) to ensure that the calls to the callbacks are
|
|
|
|
inlined. This works only if the calls are outside a staging context (``jit``
|
|
|
|
or a control-flow primitive).
|
|
|
|
|
2020-09-25 15:28:23 +03:00
|
|
|
The C++ `receiver
|
|
|
|
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_
|
|
|
|
is started automatically on the first call to :func:`id_tap`. In order to stop
|
|
|
|
it properly, upon start an ``atexit`` handler is registered to call
|
|
|
|
:func:`barrier_wait` with the logging name "at_exit".
|
|
|
|
|
|
|
|
There are a few environment variables that you can use to turn on logging
|
|
|
|
for the C++ outfeed `receiver backend
|
|
|
|
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
|
|
|
|
|
|
|
|
* ``TF_CPP_MIN_LOG_LEVEL=0``: will turn on INFO logging, needed for all below.
|
|
|
|
* ``TF_CPP_MIN_VLOG_LEVEL=3``: will make all VLOG logging up to level 3
|
|
|
|
behave like INFO logs. This may be too much, but you will see which
|
|
|
|
modules are logging relevant info, and then you can select which modules
|
|
|
|
to log from:
|
2020-12-16 09:29:50 +02:00
|
|
|
* ``TF_CPP_VMODULE=<module_name>=3`` (the module name can be either C++ or
|
|
|
|
Python, without the extension).
|
2020-09-25 15:28:23 +03:00
|
|
|
|
|
|
|
You should also use the ``--verbosity=2`` flag so that you see the logs from Python.
|
|
|
|
|
|
|
|
For example:
|
|
|
|
```
|
2020-10-17 11:15:51 +03:00
|
|
|
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_jit_simple
|
2020-09-25 15:28:23 +03:00
|
|
|
```
|
2020-10-17 11:15:51 +03:00
|
|
|
|
|
|
|
(For blaze tests use --test_arg=--vmodule=...)
|
|
|
|
|
|
|
|
Still to do:
|
|
|
|
* More performance tests.
|
|
|
|
* Explore implementation with outside compilation for TPU.
|
|
|
|
* Explore implementation with XLA CustomCall for CPU and GPU.
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
"""
|
2020-07-04 18:12:58 +03:00
|
|
|
import atexit
|
2020-09-14 02:47:28 -07:00
|
|
|
import functools
|
2020-05-08 17:18:11 +03:00
|
|
|
import itertools
|
2020-10-17 11:15:51 +03:00
|
|
|
import threading
|
|
|
|
import traceback
|
|
|
|
from typing import (Any, Callable, Dict, List, Optional, NamedTuple, Sequence,
|
|
|
|
Tuple, TypeVar, cast)
|
|
|
|
import typing
|
|
|
|
from absl import logging
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
from jax import api
|
|
|
|
from jax import core
|
2020-10-17 11:15:51 +03:00
|
|
|
from jax.config import config, bool_env
|
2020-08-12 09:20:26 +03:00
|
|
|
from jax import custom_derivatives
|
2020-10-17 11:15:51 +03:00
|
|
|
from jax import dtypes
|
2020-05-08 17:18:11 +03:00
|
|
|
from jax import lax
|
2020-05-10 19:54:46 +03:00
|
|
|
from jax.lib import pytree
|
2020-12-13 10:44:20 +02:00
|
|
|
from jax.interpreters import ad, xla, batching, masking, pxla
|
2020-05-08 17:18:11 +03:00
|
|
|
from jax.interpreters import partial_eval as pe
|
2020-11-19 06:41:54 -08:00
|
|
|
from jax._src import pprint_util as ppu
|
2020-11-04 11:54:01 -08:00
|
|
|
from jax._src import source_info_util
|
2020-05-08 17:18:11 +03:00
|
|
|
from jax import util
|
|
|
|
from jaxlib import xla_client
|
2020-07-04 18:12:58 +03:00
|
|
|
from jaxlib import xla_extension
|
|
|
|
|
2020-05-10 19:54:46 +03:00
|
|
|
import numpy as np
|
2020-10-17 11:15:51 +03:00
|
|
|
|
|
|
|
|
|
|
|
# Handle to the JAX flag holder; `inline_host_callback` reads the flag below
# through this object.
FLAGS = config.FLAGS
# Boolean flag controlling inlining of host callbacks. Its default is taken
# from the JAX_INLINE_HOST_CALLBACK environment variable via `bool_env`, with
# False passed as the fallback.
config.DEFINE_bool(
    'jax_inline_host_callback',
    bool_env('JAX_INLINE_HOST_CALLBACK', False),
    help='Inline the host_callback, if not in a staged context.'
)
|
|
|
|
|
|
|
|
def inline_host_callback() -> bool:
  """Reports whether the ``jax_inline_host_callback`` flag is turned on.

  Returns:
    The value of ``FLAGS.jax_inline_host_callback``, or ``False`` when that
    attribute cannot be read from ``FLAGS``.
  """
  # TODO: I cannot get this flag to be seen for py3.6 tests in Github
  # The getattr default covers the AttributeError raised in that situation.
  return getattr(FLAGS, "jax_inline_host_callback", False)
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
# Shorthand for the op namespace reached through xla_client's private `_xla`
# module.
xops = xla_client._xla.ops

# TODO(necula): fix mypy errors if I define the type aliases below
# Until then every alias is `Any`; the trailing comment names the intended type.
XlaOp = Any  # xla_extension.XlaOp
XlaShape = Any  # xla_client.Shape
XlaComputationBuilder = Any  # xla_bridge._JaxComputationBuilder
XlaDevice = Any  # xla_client.Device
XlaLocalClient = Any  # xla_extension.LocalClient

# Type variables for the `id_tap` overloads: T is the type of the tapped
# argument, U the type of the optional `result`.
T = TypeVar('T')
U = TypeVar('U')
# The `transforms` value handed to tap functions: a sequence of
# (transform_name, transform_params) pairs.
_Transforms = Sequence[Tuple[str, Dict[str, Any]]]
# A tap function is called positionally as tap_func(arg, transforms).
_TapFunc = Callable[[T, _Transforms], Any]
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
|
2020-09-14 02:47:28 -07:00
|
|
|
# Typing-only overloads for `id_tap`; the implementation that follows handles
# every combination at runtime. Without `result` the call returns the tapped
# argument (type T); with `result` it returns that value (type U).
@typing.overload
def id_tap(tap_func: _TapFunc, arg: T) -> T:
  ...


# `tap_with_device` may be given without `result`, e.g.
# ``id_tap(func, 2 * x, tap_with_device=True)``; this signature lets such
# calls type-check.
@typing.overload
def id_tap(tap_func: _TapFunc, arg: T, *, tap_with_device: bool) -> T:
  ...


@typing.overload
def id_tap(tap_func: _TapFunc, arg: T, *, result: U) -> U:
  ...


@typing.overload
def id_tap(tap_func: _TapFunc, arg: T, *, result: U, tap_with_device: bool) -> U:
  ...
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
def id_tap(tap_func, arg, *, result=None, tap_with_device=False, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  Semantically ``id_tap`` is the identity function, with the side-effect that
  the user-defined Python function ``tap_func`` is called on the host with the
  runtime value of ``arg``.

  Args:
    tap_func: the function to call as ``tap_func(arg, transforms)``, where
      ``arg`` is described below and ``transforms`` is the sequence of applied
      JAX transformations, each in the form ``(name, params)``. When
      ``tap_with_device`` is True the call also carries the originating device
      as a keyword argument: ``tap_func(arg, transforms, device=dev)``.
    arg: the value handed to the tap function; can be a pytree of JAX types.
    result: if given, the value returned by ``id_tap``. It is neither passed
      to the tap function nor sent from the device to the host. When it is
      not specified, ``id_tap`` returns ``arg``.
    tap_with_device: if True, the tap function is invoked with the device
      from which the tap originates as the ``device`` keyword argument.

  Returns:
    ``arg``, or ``result`` if given.

  Raises:
    TypeError: if any other keyword arguments are passed; pre-apply them to
      ``tap_func`` with a closure or ``functools.partial`` instead.

  The order of execution is by data dependency: the tap happens after all the
  arguments (and ``result``, if present) are computed and before the returned
  value is used. Using at least one of the returned values in the rest of the
  computation is the most reliable way to keep the tap alive: under JAX
  transformations an unused tap may look like dead code and be removed.

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations.

  For more details see the
  `module documentation
  <jax.experimental.host_callback.html>`_.
  """
  if kwargs:
    raise TypeError(
        "Support for **kwargs in ``id_tap`` has been removed. Instead, "
        "pre-apply keyword arguments, either by using a closure or by passing "
        "``functools.partial(tap_func, **kwargs)``.")

  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)
  arg_leaves, arg_treedef = pytree.flatten(arg)
  for leaf in arg_leaves:
    api._check_arg(leaf)

  # See definition of id_tap_p for what parameters it takes
  bind_params = dict(tap_func_=tap_func, arg_treedef_=arg_treedef)
  if tap_with_device:
    # The key is only present when device tapping was requested.
    bind_params["tap_with_device_"] = tap_with_device

  if result is None:
    # Plain identity: return the tapped argument, rebuilt as a pytree.
    return arg_treedef.unflatten(id_tap_p.bind(*arg_leaves, **bind_params))

  result_leaves, result_treedef = pytree.flatten(result)
  for leaf in result_leaves:
    api._check_arg(leaf)
  tapped = id_tap_p.bind(*arg_leaves, **bind_params)  # Returns all_args
  assert tapped
  # Tie every leaf of `result` to the first tapped output through
  # id_tap_dep_p, so the returned value depends on the tap.
  anchor = tapped[0]
  tied_leaves = [id_tap_dep_p.bind(leaf, anchor) for leaf in result_leaves]
  return result_treedef.unflatten(tied_leaves)
|
|
|
|
|
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
def id_print(arg, *, result=None, tap_with_device=False,
             output_stream=None, threshold=None, **kwargs):
  """Like :func:`id_tap` with a printing tap function.

  **Experimental: please give feedback, and expect changes!**

  Each time the printing tap is invoked, the ``kwargs`` (if any) are printed
  first, sorted by key, followed by ``arg`` with its arrays stringified using
  ``numpy.array2string``.

  See the :func:`id_tap` documentation.

  Additional keyword arguments:

  * ``tap_with_device`` if True, will print also the device from which
    the value originates.
  * ``output_stream`` if given then it will be used instead of the
    built-in ``print``. The string will be passed as
    ``output_stream.write(s)``.
  * ``threshold`` is passed to ``numpy.array2string``.
  """
  # Bind the printing options and the user's extra kwargs onto the consumer,
  # so that id_tap receives a tap function with no extra keyword arguments.
  print_options = dict(output_stream=output_stream, threshold=threshold)
  tap_func = functools.partial(_print_consumer, **print_options, **kwargs)
  return id_tap(tap_func, arg, result=result, tap_with_device=tap_with_device)
|
2020-09-14 02:47:28 -07:00
|
|
|
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
def call(callback_func: Callable, arg, *,
         result_shape=None,
         call_with_device=False):
  """Make a call to the host, and expect a result.

  **Experimental: please give feedback, and expect changes!**

  Args:
    callback_func: The Python function to invoke on the host as
      ``callback_func(arg)``. If the ``call_with_device`` optional argument is
      True, then the invocation also includes the ``device`` kwarg with the
      device from which the call originates:
      ``callback_func(arg, device=dev)``. This function must return a pytree
      of numpy ndarrays.

    arg: the argument passed to the callback function, can be a pytree of JAX
      types.

    result_shape: a value that describes the expected shape and dtype of the
      result. This can be a numeric scalar, from which a shape and dtype are
      obtained, or an object that has ``.shape`` and ``.dtype`` attributes.
      If the result of the callback is a pytree, then ``result_shape`` should
      also be a pytree with the same structure. In particular, ``result_shape``
      can be `()` or `None` if the function does not have any results.
      The device code containing ``call`` is compiled with the expected result
      shape and dtype, and an error will be raised at runtime if the actual
      ``callback_func`` invocation returns a different kind of result.

    call_with_device: if True then the callback function is invoked with the
      device from which the call originates as a keyword argument.

  Returns:
    the result of the ``callback_func`` invocation.

  Raises:
    ValueError: if ``result_shape`` cannot be converted to shapes and dtypes.
      The underlying error is attached as ``__cause__``.

  For more details see the
  `module documentation
  <jax.experimental.host_callback.html>`_.
  """
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(callback_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  # Distinct loop name so the `arg` parameter is not rebound.
  for flat_arg in flat_args:
    api._check_arg(flat_arg)
  # See definition of outside_call_p for what parameters it takes
  params: Dict[str, Any] = {}
  params["outside_computation"] = callback_func
  params["call_with_device"] = call_with_device
  flat_args_aval = [core.raise_to_shaped(core.get_aval(a)) for a in flat_args]
  params["arg_treedef"] = arg_treedef
  params["flat_args_aval"] = tuple(flat_args_aval)

  # Turn each leaf of `result_shape` into an abstract value (ShapedArray)
  flat_results_shape, result_treedef = pytree.flatten(result_shape)
  try:
    flat_results_aval = [core.ShapedArray(np.shape(r), dtypes.result_type(r))
                         for r in flat_results_shape]
  except Exception as err:
    msg = ("result_shape should be a pytree of values with structure "
           "matching the expected result of the callback function. The "
           "values must be either numeric scalars, or must have 'shape' and "
           f"'dtype' attributes. Got {result_shape}")
    # Chain explicitly so the original conversion failure is reported as the
    # direct cause of this ValueError.
    raise ValueError(msg) from err

  params["result_treedef"] = result_treedef
  params["flat_results_aval"] = tuple(flat_results_aval)
  flat_results = outside_call_p.bind(*flat_args, **params)
  return result_treedef.unflatten(flat_results)
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
|
|
|
|
# A registry of outfeed consumers, used upon receiving outfeeds
|
|
|
|
class _ConsumerCallable(NamedTuple):
|
2020-10-17 11:15:51 +03:00
|
|
|
"""Host-side information for an outfeed consumer.
|
|
|
|
|
|
|
|
Must be hashable.
|
|
|
|
"""
|
|
|
|
# All fields are private
|
2020-05-08 17:18:11 +03:00
|
|
|
func: Callable
|
2020-09-14 02:47:28 -07:00
|
|
|
transforms: Tuple[tuple, ...]
|
2020-12-13 10:44:20 +02:00
|
|
|
tap_with_device: bool
|
2020-05-08 17:18:11 +03:00
|
|
|
arg_treedef: Any
|
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
def _unpack_transforms(self) -> Tuple[Tuple[str, Dict[str, Any]], ...]:
|
2020-10-17 11:15:51 +03:00
|
|
|
def _unpack_transform(name, *params):
|
|
|
|
if name == "batch":
|
|
|
|
return name, dict(batch_dims=params[0])
|
|
|
|
elif name == "mask":
|
|
|
|
return name, dict(logical_shapes=5)
|
|
|
|
else:
|
|
|
|
assert not params, f"{name}, {params}"
|
|
|
|
return name, dict()
|
|
|
|
|
2020-09-14 02:47:28 -07:00
|
|
|
return tuple(_unpack_transform(*t) for t in self.transforms)
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
def invoke(self, arrays, device):
|
|
|
|
arg = api.tree_unflatten(self.arg_treedef, arrays)
|
2020-12-13 10:44:20 +02:00
|
|
|
if self.tap_with_device:
|
|
|
|
return self.func(arg, self._unpack_transforms(), device=device)
|
|
|
|
else:
|
|
|
|
return self.func(arg, self._unpack_transforms())
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def _register_consumer(cons: _ConsumerCallable) -> int:
  """Registers a tap function and returns its consumer id.

  The registration is cached by the hash of `cons`: if `cons` is already in
  the registry its existing id is returned. Otherwise an id is derived from
  `hash(cons)` and recorded in both `consumer_registry` (consumer -> id) and
  `consumer_registry_by_id` (id -> consumer).

  Args:
    cons: the consumer to register; it must be hashable.

  Returns:
    The non-zero integer id associated with `cons`.
  """
  receiver_state = _outfeed_receiver
  cons_id = receiver_state.consumer_registry.get(cons)
  if cons_id is not None:
    return cons_id
  cons_id = hash(cons) & 0xFFFFFFFC  # pybind11 has trouble here with large ints
  cons_id += 1  # Reserve the consumer ID 0
  # The collision check has to consult the id-keyed registry: the keys of
  # `consumer_registry` are consumers, so testing an integer id against it
  # could never detect two consumers mapping to the same id.
  assert cons_id not in receiver_state.consumer_registry_by_id, (
      "consumer id collision")
  receiver_state.consumer_registry[cons] = cons_id
  receiver_state.consumer_registry_by_id[cons_id] = cons
  return cons_id
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-09-14 02:47:28 -07:00
|
|
|
def _print_consumer(
    arg, transforms, *, device=None,
    output_stream=None, threshold=1024, **kwargs):
  """The consumer for id_print.

  We provide this as a simple tapping function for printing.
  This is **experimental** and we may not want to add many features to it;
  it should be easy for the user to roll their own printing function.

  Args:
    arg: the value to print; tuples, lists and dicts are printed recursively.
    transforms: the `(name, params)` pairs of the transformations applied;
      printed as part of the header line when non-empty.
    device: the device from which the value originates (only if
      ``tap_with_device`` was used for :func:`id_print`).
    output_stream: an object whose `write` method is called with the strings
      to be output. If None, `print` is used.
    threshold: the value of numpy.array2string threshold parameter.
    **kwargs: all other keyword args are printed before printing `arg`.
  """

  def _emit(text: str):
    if output_stream is None:
      print(text)
    else:
      output_stream.write(text + "\n")

  # The header line is made of the keyword arguments, sorted by name, plus
  # the transforms and the device when they are present.
  if transforms:
    kwargs['transforms'] = [(t_name, t_params) if t_params else t_name
                            for t_name, t_params in transforms]
  if device is not None:
    kwargs['device'] = device
  header = " ".join(f"{k}: {v}" for k, v in sorted(kwargs.items()))
  if header:
    _emit(header)

  def _enclosed(opening: str, parts, closing: str) -> ppu.PrettyPrint:
    # Stack `parts` vertically between the opening and closing delimiters.
    return ppu.pp(opening) >> ppu.vcat(parts) >> ppu.pp(closing)

  def _pp(val) -> ppu.PrettyPrint:
    if isinstance(val, tuple):
      return _enclosed("( ", [_pp(e) for e in val], " )")
    if isinstance(val, list):
      return _enclosed("[ ", [_pp(e) for e in val], " ]")
    if isinstance(val, dict):
      return _enclosed(
          "{ ",
          [ppu.pp(f"{k}=") >> _pp(v) for k, v in sorted(val.items())],
          " }")
    if isinstance(val, np.ndarray):
      return ppu.pp(np.array2string(val, threshold=threshold))
    return ppu.pp(str(val))

  _emit(str(_pp(arg)))
|
|
|
|
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
def _outside_call_consumer(call_func, expected_result_treedef,
                           expected_flat_results_aval, call_with_device,
                           arg, transforms,
                           *, device):
  """Runs the host function of an outside call and sends results to the infeed.

  Args:
    call_func: the user function to invoke on the host.
    expected_result_treedef: the pytree structure the result must have.
    expected_flat_results_aval: the abstract values expected for the flattened
      result. If empty, nothing is sent back to the device.
    call_with_device: if true, `call_func` is also passed `device=device`.
    arg: the argument for `call_func`.
    transforms: not used by this consumer.
    device: the device whose infeed receives the results.

  Raises:
    TypeError: if the result's pytree structure or abstract values differ
      from the expected ones. Exceptions raised by `call_func` are re-raised
      after an error result has been prepared for the device.
  """
  logging.vlog(2, f"Outside call consumer invoking call_func {call_func} with {arg}")
  # Bind the name before entering `try`, so that the `finally` clause can read
  # it even if something that is not an `Exception` (e.g., KeyboardInterrupt)
  # escapes before the results are computed. Otherwise an UnboundLocalError
  # raised in `finally` would mask the original exception.
  canonical_flat_results = None
  try:
    if call_with_device:
      res = call_func(arg, device=device)
    else:
      res = call_func(arg)

    flat_results, result_treedef = pytree.flatten(res)
    canonical_flat_results = util.safe_map(xla.canonicalize_dtype, flat_results)
    flat_results_aval = [core.raise_to_shaped(core.get_aval(r), weak_type=False)
                         for r in canonical_flat_results]
    logging.vlog(2, f"Outside call consumer {call_func} result {res} : {flat_results_aval}. Sending to infeed.")

    if expected_result_treedef != result_treedef:
      msg = (f"Callback func {call_func} should have returned a result "
             f"with pytree {expected_result_treedef} but returned "
             f"{result_treedef}")
      raise TypeError(msg)

    # Compare the abstract values ignoring weak types.
    if not all(ea.strip_weak_type() == ra.strip_weak_type()
               for ea, ra in util.safe_zip(expected_flat_results_aval,
                                           flat_results_aval)):
      msg = (f"Callback func {call_func} should have returned a result "
             "with abstract values "
             f"{expected_result_treedef.unflatten(expected_flat_results_aval)} "
             f"but returned {result_treedef.unflatten(flat_results_aval)}")
      raise TypeError(msg)
  except Exception as e:
    # Prepare some results to send in case of error. We are sending something
    # with a distinctive shape (int8[12345]), one that is unlikely to be what the device
    # expects. This should have the effect to abort the device computation,
    # with an error message that we recognize. On TPU there seem to be no
    # such check, and if we send anything at all the device computation will
    # use some garbage data. So, on TPU we prefer to not send anything and let
    # the computation hang.
    if device.platform == "tpu":
      canonical_flat_results = None
    else:
      canonical_flat_results = [xla.canonicalize_dtype(np.arange(12345, dtype=np.int8))]
    logging.vlog(2, f"Outside call consumer {call_func} exception {e}. Sending to infeed the error result.")
    raise
  finally:
    # If the device expects results we must send something, otherwise the
    # device computation hangs forever (except for the TPU case above).
    # We must transfer the flattened results, as a tuple.
    if expected_flat_results_aval and canonical_flat_results is not None:
      device.transfer_to_infeed(tuple(canonical_flat_results))
|
|
|
|
|
|
|
|
|
|
|
|
### The id_tap_dep primitive
|
|
|
|
"""
|
|
|
|
The id_tap_dep_p primitive is used to create a dependency of the result of
|
|
|
|
id_tap on the actual tap operation. This is only needed when the
|
|
|
|
id_tap function is used with the `result` parameter. This primitive acts
|
|
|
|
as the identity operator on the first argument.
|
|
|
|
|
|
|
|
For example, given `id_tap(f, (a, b), result=(r, s))`, we convert this to
|
|
|
|
|
|
|
|
a1, b1 = id_tap_p(f, a, b)
|
|
|
|
r1 = id_tap_dep_p(r, a1)
|
|
|
|
s1 = id_tap_dep_p(s, a1)
|
|
|
|
|
|
|
|
There are always two arguments and the result is equal to the first.
|
|
|
|
"""
|
|
|
|
# The primitive takes two operands and yields a single result: the first one.
id_tap_dep_p = core.Primitive("id_tap_dep")
id_tap_dep_p.multiple_results = False
# Eager evaluation and abstract evaluation both return the first operand.
id_tap_dep_p.def_impl(lambda res, _tap: res)
id_tap_dep_p.def_abstract_eval(lambda res_aval, _tap_aval: res_aval)
# When compiling, the first operand is the result; the second is ignored.
xla.translations[id_tap_dep_p] = lambda comp, a_res, a_tap: a_res
# JVP: bind the primitive once on the primals and once on the tangents.
ad.primitive_jvps[id_tap_dep_p] = (
    lambda primals, tangents: (
        id_tap_dep_p.bind(primals[0], primals[1]),
        id_tap_dep_p.bind(tangents[0], tangents[1])))
|
|
|
|
|
|
|
|
|
|
|
|
def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
  """Transpose rule for id_tap_dep_p.

  Both operands must be undefined primals. The cotangent (instantiated if it
  is a symbolic zero) flows to the first operand; the second operand gets a
  symbolic zero.
  """
  assert ad.is_undefined_primal(arg_res)
  assert ad.is_undefined_primal(arg_tap)
  ct_for_res = _instantiate_zeros(arg_res, cts)
  ct_for_tap = ad.Zero(arg_tap.aval)
  return (ct_for_res, ct_for_tap)


ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _id_tap_dep_batching_rule(batched_args, batch_dims):
  """Batching rule for id_tap_dep_p.

  Re-binds the primitive on the two batched operands; the result uses the
  batch dimension of the first operand.
  """
  arg_res, arg_tap = batched_args
  batched_result = id_tap_dep_p.bind(arg_res, arg_tap)
  return batched_result, batch_dims[0]


batching.primitive_batchers[id_tap_dep_p] = _id_tap_dep_batching_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _id_tap_dep_masking_rule(operands, operands_logical_shapes):
  """Masking rule for id_tap_dep_p.

  Re-binds the primitive on the two operands; the logical shapes are not
  used.
  """
  arg_res, arg_tap = operands
  masked_result = id_tap_dep_p.bind(arg_res, arg_tap)
  return masked_result


masking.masking_rules[id_tap_dep_p] = _id_tap_dep_masking_rule
|
|
|
|
|
|
|
|
### The id_tap_p primitive
|
2020-07-04 18:12:58 +03:00
|
|
|
"""The id_tap_p primitive acts like the identity function.
|
|
|
|
|
|
|
|
It has a number of positional arguments. The result of the primitive are
|
|
|
|
the positional arguments.
|
|
|
|
|
|
|
|
The primitive has the following parameters:
|
|
|
|
* tap_func_: the actual (Python) function to invoke with the tapped positional
|
|
|
|
arguments (unflattened according to arg_treedef_) and
|
|
|
|
the parameters that were passed to the id_tap function.
|
2020-10-17 11:15:51 +03:00
|
|
|
* arg_treedef_: the treedef of the tapped positional argument.
|
|
|
|
* tap_with_device_: a boolean that specifies whether the tap function
|
|
|
|
takes an additional device keyword argument.
|
2020-06-02 17:37:20 -07:00
|
|
|
* transforms: a tuple of the transformations that have been applied. Each
|
2020-05-23 13:49:27 +03:00
|
|
|
element of the tuple is itself a tuple with the first element the name
|
2020-06-02 17:37:20 -07:00
|
|
|
of the transform. The remaining elements depend on the transform. For
|
|
|
|
example, for `batch`, the parameters are the dimensions that have been
|
|
|
|
batched, and for `mask` the logical shapes. These are unpacked by
|
2020-05-23 13:49:27 +03:00
|
|
|
_ConsumerCallable before passing to the user function.
|
2020-10-17 11:15:51 +03:00
|
|
|
* has_token_: a boolean, when True it means that the last positional argument
|
|
|
|
is the current token. In this case, the result of the primitive is
|
|
|
|
going to be the non-token positional arguments, along with the updated
|
|
|
|
token. The tokens and this parameter are added after all the JAX
|
|
|
|
transformations, just before staging XLA.
|
2020-05-08 17:18:11 +03:00
|
|
|
"""
|
|
|
|
# The primitive returns one result per positional argument.
id_tap_p = core.Primitive("id_tap")
id_tap_p.multiple_results = True
# Record id_tap_p among the primitives that use the outfeed.
xla.outfeed_primitives.add(id_tap_p)
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
# Even for eager execution we go through the jitted version of the primitive.
# This avoids duplicating logic, and it ensures that all outfeed is received
# by the outfeed_listeners, in the same thread for a given device. Processing
# the tap here would make it come from the main thread instead. Moreover,
# some primitives, such as while, are compiled even in eager execution, and
# it would be confusing to process a sequence "id_tap; while" in two
# different threads.
def _id_tap_impl(*args, tap_func_, arg_treedef_,
                 transforms=(),
                 tap_with_device_=False):
  """Implementation rule for id_tap_p.

  Unless `inline_host_callback()` is true, the primitive is applied through
  `xla.apply_primitive`. Otherwise the tap function is invoked right here,
  with the first of `api.devices()` as the device, and the arguments are
  returned unchanged.
  """
  if not inline_host_callback():
    return xla.apply_primitive(id_tap_p, *args,
                               arg_treedef_=arg_treedef_,
                               tap_func_=tap_func_,
                               transforms=transforms,
                               tap_with_device_=tap_with_device_)

  consumer = _ConsumerCallable(tap_func_, transforms, tap_with_device_, arg_treedef_)
  consumer.invoke(args, api.devices()[0])
  return args


id_tap_p.def_impl(_id_tap_impl)
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def _id_tap_abstract_eval(*args_a: pe.AbstractValue, **params) \
|
|
|
|
-> Sequence[pe.AbstractValue]:
|
|
|
|
return args_a
|
|
|
|
|
|
|
|
|
|
|
|
# Install the abstract evaluation rule for id_tap_p.
id_tap_p.def_abstract_eval(_id_tap_abstract_eval)
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
def _id_tap_translation_rule(comp: XlaComputationBuilder,
                             *args_op: XlaOp,
                             tap_func_=None,
                             arg_treedef_=None,
                             has_token_=False,
                             tap_with_device_=False,
                             transforms=()):
  """XLA translation rule for id_tap_p.

  Registers the consumer for the tap, adds an outfeed of all operands except
  the trailing token, and returns a tuple of those operands followed by the
  token produced by the outfeed.
  """
  # The current token is expected as the last operand; _rewrite_jaxpr is the
  # one that puts it there.
  assert has_token_
  current_token = args_op[-1]
  args_to_outfeed = args_op[:-1]  # everything but the token
  assert not comp.get_shape(current_token).is_array(), (
      "The last argument must be a token")

  consumer = _ConsumerCallable(tap_func_, transforms, tap_with_device_,
                               arg_treedef_)
  consumer_id = _register_consumer(consumer)
  next_token = _outfeed_receiver.receiver.add_outfeed(
      comp, current_token, consumer_id, args_to_outfeed)
  # The tapped operands pass through unchanged, followed by the new token.
  return xops.Tuple(comp, args_to_outfeed + (next_token,))


xla.translations[id_tap_p] = _id_tap_translation_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
|
|
|
|
"""Adds the `transform` to the params["transforms"].
|
|
|
|
|
|
|
|
Uses a tuple representation internally, will be unpacked before the
|
|
|
|
callback by _ConsumerCallable.
|
|
|
|
"""
|
|
|
|
new_transform = (name, *transform_params)
|
|
|
|
return dict(
|
|
|
|
params, transforms=(params.get("transforms", ()) + (new_transform,)))
|
|
|
|
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
# TODO(necula): there must be a better way to do this.
# The AttributeError is for regular values, the KeyError is for ConcreteArray
def _instantiate_zeros(arg, tan):
  """Turn special ad.zero tangents into arrays of 0s.

  A `tan` that is not an `ad.Zero` is returned unchanged. Otherwise the zeros
  are built from `arg.aval` or, if that fails with one of the errors
  mentioned above, from `arg` itself.
  """
  if type(tan) is not ad.Zero:
    return tan

  try:
    return ad.instantiate_zeros_aval(arg.aval, tan)
  except (AttributeError, KeyError):
    # This is the path taken for regular Python values
    return ad.zeros_like_jaxval(arg)
|
|
|
|
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
def _instantiate_zeros_aval(aval, tan):
  """Turn special ad.zero tangents into arrays of 0s, given the abstract value.

  A `tan` that is not an `ad.Zero` is returned unchanged.
  """
  if type(tan) is ad.Zero:
    return ad.instantiate_zeros_aval(aval, tan)
  return tan
|
|
|
|
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
def _id_tap_jvp_rule(primals, tangents, **params):
  """JVP rule for id_tap_p.

  Binds a single tap over the primals followed by the (instantiated)
  tangents, recording a "jvp" transform. The tap function then receives a
  pair of (primals, tangents), each with the structure of the original
  argument.

  Returns:
    A pair of tuples: the tapped primals and the tapped tangents.
  """
  tangents_full = tuple(
      _instantiate_zeros(p, t) for p, t in zip(primals, tangents))
  assert "has_token_" not in params

  orig_treedef = params["arg_treedef_"]
  # The treedef for the jvp tap describes the pair (primals, tangents).
  paired_arg = (orig_treedef.unflatten(primals),
                orig_treedef.unflatten(tangents_full))
  _, jvp_arg_treedef = api.tree_flatten(paired_arg)

  jvp_params = dict(_add_transform(params, "jvp"),
                    arg_treedef_=jvp_arg_treedef)
  tapped = id_tap_p.bind(*primals, *tangents_full, **jvp_params)
  tapped_primals, tapped_tangents = util.split_list(tapped, [len(primals)])
  return tuple(tapped_primals), tuple(tapped_tangents)


ad.primitive_jvps[id_tap_p] = _id_tap_jvp_rule
|
|
|
|
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
def _id_tap_partial_eval_rule(trace, *args, **params):
  """Custom partial evaluation rule for id_tap_p.

  Unless the last transform is "jvp" and some of the arguments are unknown,
  the default processing of the trace is used. Otherwise two taps are
  emitted: one with only the primals (and without the "jvp" transform), and
  the whole tap, from which the unknown outputs are taken.
  """
  # For a "jvp" tap the args were laid out by the jvp rule: primals, tangents
  transforms = params.get("transforms", ())
  in_jvp = bool(transforms) and transforms[-1] == ("jvp",)
  if not in_jvp:
    # We are not in the process of computing VJP
    return trace.default_process_primitive(id_tap_p, args, params)

  assert len(args) % 2 == 0
  nr_primals = len(args) // 2

  known_vals = [tracer.pval.get_known() for tracer in args]
  if all(kv is not None for kv in known_vals):
    return trace.default_process_primitive(id_tap_p, args, params)

  # Split into two taps, one for the knowns and one for the unknowns.
  # Only the case when the primals are known is implemented here, and we make
  # a tap with just the primals.
  primals, _ = util.split_list(args, [nr_primals])
  known_primals, _ = util.split_list(known_vals, [nr_primals])
  assert all([kv is not None for kv in known_primals])

  primals_unflat, _ = params["arg_treedef_"].unflatten(args)
  _, primals_treedef = api.tree_flatten(primals_unflat)

  primals_only_params = dict(params,
                             arg_treedef_=primals_treedef,
                             transforms=transforms[:-1])
  outs_known = trace.default_process_primitive(
      id_tap_p, primals, primals_only_params)

  # Compute the unknowns using the whole tap, then merge with the tapped ones
  outs_all_unknown = trace.default_process_primitive(id_tap_p, args, params)
  outs_primals_unknown, outs_tangents_unknown = util.split_list(
      outs_all_unknown, [nr_primals])
  merged_primals = [
      pe.JaxprTracer(trace, pe.PartialVal.known(out_known), out_unknown.recipe)
      for out_known, out_unknown in util.safe_zip(outs_known,
                                                  outs_primals_unknown)]
  return tuple(merged_primals + outs_tangents_unknown)


pe.custom_partial_eval_rules[id_tap_p] = _id_tap_partial_eval_rule
|
|
|
|
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
def _id_tap_transpose_rule(cts, *args, **params):
  """Transpose rule for id_tap_p.

  If the last transform is not "jvp", all the (instantiated) cotangents are
  tapped, recording a "transpose" transform. Otherwise only the second half
  of the cotangents (those of the tangents) goes through the tap, and the
  first half is returned as is.
  """
  assert len(cts) == len(args)
  cts_full = tuple(_instantiate_zeros(a, ct) for a, ct in zip(args, cts))

  # For a "jvp" tap the args were laid out by the jvp rule: primals, tangents
  transforms = params.get("transforms", ())
  in_jvp = bool(transforms) and transforms[-1] == ("jvp",)
  if not in_jvp:
    # TODO: I should understand better when can this happen. It seems to arise
    # in scan.
    return id_tap_p.bind(*cts_full, **_add_transform(params, "transpose"))

  assert len(args) % 2 == 0
  nr_primals = len(args) // 2

  primals_unflat, _ = params["arg_treedef_"].unflatten(args)
  _, vjp_arg_treedef = api.tree_flatten(primals_unflat)
  # Only the cotangents of the tangents are tapped
  cts_primals, cts_tangents = util.split_list(cts_full, [nr_primals])
  transpose_params = dict(_add_transform(params, "transpose"),
                          arg_treedef_=vjp_arg_treedef)
  cts_tangents_through_tap = id_tap_p.bind(*cts_tangents, **transpose_params)
  return cts_primals + cts_tangents_through_tap


ad.primitive_transposes[id_tap_p] = _id_tap_transpose_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _id_tap_batching_rule(batched_args, batch_dims, **params):
  """Batching rule for id_tap_p.

  Re-binds the primitive on the batched arguments, recording a "batch"
  transform that carries `batch_dims`. The batch dimensions of the results
  are those of the arguments.
  """
  batch_params = _add_transform(params, "batch", batch_dims)
  tapped = id_tap_p.bind(*batched_args, **batch_params)
  return tapped, batch_dims


batching.primitive_batchers[id_tap_p] = _id_tap_batching_rule
|
|
|
|
|
|
|
|
|
|
|
|
def _id_tap_masking_rule(operands, operands_logical_shapes, **params):
  """Masking rule for id_tap_p.

  Taps the pair (arg, arg_logical_shapes), where both components have the
  structure of the original argument, and records a "mask" transform.

  Args:
    operands: the flat operands of the primitive.
    operands_logical_shapes: the logical shapes, one per operand.
    **params: the parameters of the primitive; all of them are forwarded to
      the new tap, except that `arg_treedef_` and `transforms` are updated.

  Returns:
    The results of the tap that correspond to `operands`.
  """
  assert "has_token_" not in params

  assert len(operands) == len(operands_logical_shapes)
  arg_treedef = params["arg_treedef_"]
  # We will send the pair of (arg, arg_logical_shapes)
  packed_operands, packed_arg_tree = api.tree_flatten(
      (api.tree_unflatten(arg_treedef, operands),
       api.tree_unflatten(arg_treedef, operands_logical_shapes)))

  # Forward all the parameters, like the other transformation rules do.
  # Passing only a hand-picked subset would silently drop parameters such as
  # `tap_with_device_`, and the tap function would then be invoked without
  # the device keyword argument it asked for.
  packed_results = id_tap_p.bind(
      *packed_operands,
      **dict(_add_transform(params, "mask"),
             arg_treedef_=packed_arg_tree))
  return packed_results[:len(operands)] + packed_results[len(packed_operands):]


masking.masking_rules[id_tap_p] = _id_tap_masking_rule
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
### The outside_call primitive
|
|
|
|
"""
|
|
|
|
This primitive is used to implement the `call` function. It takes several
|
|
|
|
positional arguments that are the flattening of the argument to `call`.
|
|
|
|
It takes the following parameters:
|
|
|
|
|
|
|
|
* outside_computation: the function to invoke with the unflattened arguments.
|
|
|
|
* arg_treedef, flat_args_aval: the treedef and flat list of abstract values
|
|
|
|
for the argument.
|
|
|
|
* result_treedef, flat_results_aval: the treedef and flat list of abstract
|
|
|
|
values for the expected result.
|
|
|
|
* call_with_device: whether the outside_computation must be invoked with
|
|
|
|
a device keyword argument.
|
|
|
|
"""
|
|
|
|
# The primitive returns a sequence of results.
outside_call_p = core.Primitive("outside_call")
outside_call_p.multiple_results = True
# Record outside_call_p among the primitives that use the outfeed.
xla.outfeed_primitives.add(outside_call_p)
|
|
|
|
|
|
|
|
|
|
|
|
def _outside_call_impl(*args, outside_computation,
                       arg_treedef,
                       flat_args_aval,
                       result_treedef,
                       flat_results_aval, call_with_device,
                       **params):
  """Implementation rule for outside_call_p.

  Unless `inline_host_callback()` is true, the primitive is applied through
  `xla.apply_primitive`. Otherwise `outside_computation` is invoked right
  here (with the first of `api.devices()` as the device, if
  `call_with_device`) and its flattened result is returned, after checking
  that it has the expected pytree structure.
  """
  if not inline_host_callback():
    return xla.apply_primitive(outside_call_p, *args,
                               outside_computation=outside_computation,
                               arg_treedef=arg_treedef,
                               flat_args_aval=flat_args_aval,
                               result_treedef=result_treedef,
                               flat_results_aval=flat_results_aval,
                               call_with_device=call_with_device,
                               **params)

  arg = arg_treedef.unflatten(args)
  if call_with_device:
    res = outside_computation(arg, device=api.devices()[0])
  else:
    res = outside_computation(arg)
  flat_results, result_treedef_actual = api.tree_flatten(res)
  assert result_treedef_actual == result_treedef, f"expected {result_treedef} but found {result_treedef_actual}"
  return flat_results


outside_call_p.def_impl(_outside_call_impl)
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
def _outside_call_abstract_eval(*args_a: pe.AbstractValue,
|
|
|
|
flat_results_aval, **params) -> Sequence[pe.AbstractValue]:
|
|
|
|
if "has_token_" in params and params["has_token_"]:
|
|
|
|
assert len(args_a) >= 2 and args_a[-1] is core.abstract_token and args_a[-2] is core.abstract_token
|
|
|
|
return flat_results_aval + (core.abstract_token, core.abstract_token)
|
|
|
|
else:
|
|
|
|
return flat_results_aval
|
|
|
|
|
|
|
|
|
|
|
|
# Install the abstract evaluation rule for outside_call_p.
outside_call_p.def_abstract_eval(_outside_call_abstract_eval)
|
|
|
|
|
|
|
|
|
|
|
|
def _outside_call_translation_rule(
    comp: XlaComputationBuilder, *args_op: XlaOp,
    outside_computation,
    arg_treedef,
    flat_args_aval,
    result_treedef=None,
    flat_results_aval=None,
    has_token_=False,
    call_with_device=False):
  """XLA translation rule for outside_call_p.

  Registers a consumer that runs `outside_computation`, and adds an outfeed
  of all operands except the two trailing tokens. If results are expected
  (`flat_results_aval` is non-empty), an infeed ordered after the outfeed is
  added as well, and its values are the results. The returned tuple ends
  with the two updated tokens.
  """
  # The two current tokens are expected as the last operands; _rewrite_jaxpr
  # is the one that puts them there.
  assert has_token_
  current_token = args_op[-2]
  current_itoken = args_op[-1]
  # TODO: expose shape.is_token
  assert not comp.get_shape(current_token).is_array(), (
      "The last two arguments must be tokens")
  assert not comp.get_shape(current_itoken).is_array(), (
      "The last two arguments must be tokens")

  args_to_outfeed = args_op[:-2]
  host_func = functools.partial(_outside_call_consumer,
                                outside_computation, result_treedef,
                                flat_results_aval, call_with_device)
  # The consumer has no transforms and is always invoked with the device.
  consumer_id = _register_consumer(
      _ConsumerCallable(host_func, (), True, arg_treedef))
  next_token = _outfeed_receiver.receiver.add_outfeed(comp, current_token,
                                                      consumer_id,
                                                      args_to_outfeed)
  if not flat_results_aval:
    # No results are expected: the second token passes through unchanged.
    return xops.Tuple(comp, [next_token, current_itoken])

  # Make the infeed depend on the outfeed.
  after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
  results_and_token = xla.translations[lax.infeed_p](comp, after_outfeed_itoken,
                                                     shapes=flat_results_aval, partitions=None)
  nr_results = len(flat_results_aval)
  # The element following the results is the token produced by the infeed.
  next_itoken = xops.GetTupleElement(results_and_token, nr_results)
  results = [xops.GetTupleElement(results_and_token, i) for i in range(nr_results)]
  return xops.Tuple(comp, results + [next_token, next_itoken])


xla.translations[outside_call_p] = _outside_call_translation_rule
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
####
|
|
|
|
#### Jaxpr rewriting logic to thread the tokens through stateful primitives.
|
|
|
|
####
|
|
|
|
|
2020-05-24 10:50:07 +03:00
|
|
|
|
2020-09-18 10:07:13 -07:00
|
|
|
def _rewrite_closed_jaxpr(
    cjaxpr: core.ClosedJaxpr, has_input_token: bool,
    has_output_token: bool) -> core.ClosedJaxpr:
  """Rewrites a ClosedJaxpr to thread the token, if needed.

  The enclosed Jaxpr is rewritten by `_rewrite_jaxpr`; the constants are
  kept as they are.
  """
  rewritten = _rewrite_jaxpr(cjaxpr.jaxpr, has_input_token, has_output_token)
  return core.ClosedJaxpr(rewritten, cjaxpr.consts)
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Rewrite a Jaxpr to thread the token, if needed.

  Args:
    jaxpr: the Jaxpr to rewrite.
    has_input_token: whether two token variables are appended to the inputs.
      If not, and the Jaxpr uses outfeed, the tokens are created inside the
      Jaxpr, depending on all its inputs.
    has_output_token: whether the last two token variables are appended to
      the outputs. This requires `has_input_token`.

  Returns:
    The same `jaxpr` if it has no input token and does not use outfeed;
    otherwise a new Jaxpr in which the equations that use outfeed have been
    rewritten by `_rewrite_eqn`.
  """
  assert has_input_token or not has_output_token

  if not (has_input_token or xla.jaxpr_uses_outfeed(jaxpr)):
    return jaxpr

  mk_new_var = core.gensym([jaxpr])

  eqns: List[core.JaxprEqn] = []
  # These two variables hold the incoming tokens.
  last_token_var = mk_new_var(core.abstract_token)
  last_itoken_var = mk_new_var(core.abstract_token)
  if has_input_token:
    invars = jaxpr.invars + [last_token_var, last_itoken_var]
  else:
    invars = jaxpr.invars
    # We need tokens but none is given in input; make them depend on all invars
    for fresh_token_var in (last_token_var, last_itoken_var):
      eqns.append(
          core.new_jaxpr_eqn(jaxpr.invars, [fresh_token_var],
                             lax.create_token_p, {}, source_info_util.current()))

  for eqn in jaxpr.eqns:
    if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      eqns.append(eqn)
      continue
    # The rewritten equation consumes the last tokens and defines new ones.
    output_token_var = mk_new_var(core.abstract_token)
    output_itoken_var = mk_new_var(core.abstract_token)
    _rewrite_eqn(eqn, eqns, last_token_var, output_token_var, last_itoken_var, output_itoken_var, mk_new_var)
    last_token_var = output_token_var
    last_itoken_var = output_itoken_var

  if has_output_token:
    outvars = jaxpr.outvars + [last_token_var, last_itoken_var]
  else:
    outvars = jaxpr.outvars + []
  return core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
|
|
|
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 input_itoken_var: core.Var, output_itoken_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
  """Rewrites one equation so that it threads the outfeed tokens.

  This is only called for equations whose primitive uses outfeed. The
  rewritten equation(s) are appended to `eqns`; nothing is returned.

  Args:
    eqn: the equation to rewrite.
    eqns: the list of equations built so far; it is extended in place.
    input_token_var: the variable holding the current token.
    output_token_var: the variable in which the resulting token must end up.
    input_itoken_var: the variable holding the current itoken.
    output_itoken_var: the variable in which the resulting itoken must end up.
    mk_new_var: a generator of fresh variables; it is only forwarded to
      `_rewrite_while_outfeed_cond`.

  Raises:
    NotImplementedError: if there is no rewrite rule for `eqn.primitive`.
  """
  prim = eqn.primitive
  tokens_in = [input_token_var, input_itoken_var]
  tokens_out = [output_token_var, output_itoken_var]

  def emit(invars, outvars, params):
    # The rewritten equation keeps the primitive and source info of `eqn`.
    eqns.append(
        core.new_jaxpr_eqn(invars, outvars, prim, params, eqn.source_info))

  def unreachable_thunk():
    assert False, "Should not be reached"

  if prim is id_tap_p:
    assert "has_token_" not in eqn.params
    emit(eqn.invars + [input_token_var],
         eqn.outvars + [output_token_var],
         dict(eqn.params, has_token_=True))
    # The tap equation only takes the token; the itoken is passed through
    # a separate `id_p` equation.
    eqns.append(
        core.new_jaxpr_eqn([input_itoken_var], [output_itoken_var], id_p, {},
                           eqn.source_info))
  elif prim is outside_call_p:
    assert "has_token_" not in eqn.params
    emit(eqn.invars + tokens_in,
         eqn.outvars + tokens_out,
         dict(eqn.params, has_token_=True))
  elif prim is lax.while_p:
    cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      # A condition that uses outfeed needs the dedicated rewrite.
      _rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
                                  input_itoken_var, output_itoken_var,
                                  mk_new_var)
    else:
      emit(eqn.invars + tokens_in,
           eqn.outvars + tokens_out,
           dict(eqn.params,
                body_jaxpr=_rewrite_closed_jaxpr(body_jaxpr, True, True),
                cond_jaxpr=_rewrite_closed_jaxpr(cond_jaxpr, True, False)))
  elif prim is lax.cond_p:
    branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
    index, *operands = eqn.invars
    emit([index, *operands, *tokens_in],
         eqn.outvars + tokens_out,
         dict(eqn.params,
              branches=tuple(_rewrite_closed_jaxpr(branch, True, True)
                             for branch in branches),
              linear=(*linear, False, False)))
  elif prim is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
         "unroll"])
    # The two tokens are inserted right at the end of the carry.
    carry_end = num_consts + num_carry
    scan_invars = (eqn.invars[0:carry_end] + tokens_in +
                   eqn.invars[carry_end:])
    new_jaxpr = _rewrite_closed_jaxpr(carry_jaxpr, True, True)
    # The rewrite has put the tokens last; move them to the end of the carry,
    # for both the inputs and the outputs of the scan body.
    body_invars = new_jaxpr.jaxpr.invars
    new_jaxpr.jaxpr.invars = (body_invars[0:carry_end] + body_invars[-2:] +
                              body_invars[carry_end:-2])
    body_outvars = new_jaxpr.jaxpr.outvars
    new_jaxpr.jaxpr.outvars = (body_outvars[0:num_carry] + body_outvars[-2:] +
                               body_outvars[num_carry:-2])
    emit(scan_invars,
         # The output tokens are at the end of the carry result.
         eqn.outvars[0:num_carry] + tokens_out + eqn.outvars[num_carry:],
         dict(eqn.params,
              jaxpr=new_jaxpr,
              num_carry=num_carry + 2,
              linear=(linear[0:carry_end] + (False, False) +
                      linear[carry_end:])))
  elif prim is xla.xla_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    emit(eqn.invars + tokens_in,
         eqn.outvars + tokens_out,
         dict(eqn.params,
              call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
              donated_invars=eqn.params["donated_invars"] + (False, False)))
  elif prim is pxla.xla_pmap_p:
    # We broadcast the input token into an array of tokens
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    emit(eqn.invars + tokens_in,
         eqn.outvars + tokens_out,
         dict(eqn.params,
              call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
              donated_invars=eqn.params["donated_invars"] + (False, False),
              # Sharding/unsharding of tokens in pmap_translation are special
              # cased to just pass-through the token
              in_axes=eqn.params["in_axes"] + (0, 0),
              out_axes=eqn.params["out_axes"] + (0, 0)))
  elif prim is pe.remat_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    emit(eqn.invars + tokens_in,
         eqn.outvars + tokens_out,
         dict(eqn.params,
              call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)))
  elif prim is custom_derivatives.custom_jvp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    emit(eqn.invars + tokens_in,
         eqn.outvars + tokens_out,
         dict(eqn.params,
              fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
              jvp_jaxpr_thunk=unreachable_thunk))
  elif prim is custom_derivatives.custom_vjp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    emit([*eqn.invars, *tokens_in],
         eqn.outvars + tokens_out,
         dict(eqn.params,
              fun_jaxpr=_rewrite_closed_jaxpr(fun_jaxpr, True, True),
              fwd_jaxpr_thunk=unreachable_thunk,
              # The following are illegal values for the parameters, they
              # should not be needed because this rewrite is just before
              # compilation to XLA, which does not use those parameters.
              bwd="illegal param",
              out_trees="illegal param"))
  elif prim is core.named_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    emit(eqn.invars + tokens_in,
         eqn.outvars + tokens_out,
         dict(eqn.params,
              call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)))
  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
|
|
|
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                                input_token_var: core.Var, output_token_var: core.Var,
                                input_itoken_var: core.Var, output_itoken_var: core.Var,
                                mk_new_var: Callable):
  """Rewrites a `while` equation whose condition uses outfeed.

  Two equations are appended to `eqns`:

    1. An `xla_call` ("cond_before") that evaluates the rewritten condition
       once on the initial carry:
         pred1, token1, itoken1 = rewrite(COND)(cond_consts, carry, token, itoken)
    2. A new `while` whose carry is `(pred, carry, token, itoken)`, with
         cond: "lambda pred, carry, token, itoken: pred"
         body: "lambda cond_consts, body_consts, pred, carry, token, itoken:
                  carry2, token2, itoken2 = rewrite(BODY)(body_consts, carry, token, itoken)
                  pred2, token3, itoken3 = rewrite(COND)(cond_consts, carry2, token2, itoken2)
                  (pred2, carry2, token3, itoken3)"

  Args:
    eqn: the `while` equation to rewrite.
    eqns: the list of equations built so far; it is extended in place.
    input_token_var: the variable holding the current token.
    output_token_var: the variable in which the resulting token must end up.
    input_itoken_var: the variable holding the current itoken.
    output_itoken_var: the variable in which the resulting itoken must end up.
    mk_new_var: a generator of fresh variables, called with an abstract value.
  """
  # NOTE: the order of the `mk_new_var` calls below is kept fixed, because it
  # determines which fresh variable is used where.
  def fresh_like(variables):
    return [mk_new_var(v.aval) for v in variables]

  def fresh_token():
    return mk_new_var(core.abstract_token)

  cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
      eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
  rewritten_cond = _rewrite_closed_jaxpr(cond_jaxpr, True, True)
  carry_in = eqn.invars[cond_nconsts + body_nconsts:]

  # Step 1: evaluate the rewritten condition before entering the loop.
  cond_before_outs = fresh_like(rewritten_cond.jaxpr.outvars)
  cond_before_eqn = core.new_jaxpr_eqn(
      eqn.invars[0:cond_nconsts] + carry_in + [input_token_var, input_itoken_var],
      cond_before_outs,
      xla.xla_call_p,
      dict(
          call_jaxpr=rewritten_cond.jaxpr,
          name="cond_before",
          donated_invars=(False,) * len(rewritten_cond.in_avals)),
      eqn.source_info)
  eqns.append(cond_before_eqn)

  # Step 2a: the new condition just returns the predicate it is carried.
  loop_cond_pred = mk_new_var(cond_jaxpr.out_avals[0])
  loop_cond_invars = ([loop_cond_pred] +
                      fresh_like(carry_in) +
                      [fresh_token(), fresh_token()])
  loop_cond = core.ClosedJaxpr(
      core.Jaxpr([], loop_cond_invars, [loop_cond_pred], []), [])

  # Step 2b: the new body runs the rewritten body, then the rewritten cond.
  rewritten_body = _rewrite_closed_jaxpr(body_jaxpr, True, True)
  in_cond_consts = fresh_like(eqn.invars[0:cond_nconsts])
  in_body_consts = fresh_like(
      eqn.invars[cond_nconsts:cond_nconsts + body_nconsts])
  in_pred = mk_new_var(cond_jaxpr.out_avals[0])
  in_carry = fresh_like(carry_in)
  in_token = fresh_token()
  in_itoken = fresh_token()

  carry2 = fresh_like(carry_in)
  token2 = fresh_token()
  itoken2 = fresh_token()
  pred2 = mk_new_var(cond_jaxpr.out_avals[0])
  token3 = fresh_token()
  itoken3 = fresh_token()

  run_body_eqn = core.new_jaxpr_eqn(
      in_body_consts + in_carry + [in_token, in_itoken],
      carry2 + [token2, itoken2],
      xla.xla_call_p,
      dict(
          call_jaxpr=rewritten_body.jaxpr,
          name="body",
          donated_invars=(False,) * len(rewritten_body.in_avals)),
      eqn.source_info)
  run_cond_eqn = core.new_jaxpr_eqn(
      in_cond_consts + carry2 + [token2, itoken2],
      [pred2, token3, itoken3],
      xla.xla_call_p,
      dict(
          call_jaxpr=rewritten_cond.jaxpr,
          name="cond_body",
          donated_invars=(False,) * len(rewritten_cond.in_avals)),
      eqn.source_info)
  loop_body = core.ClosedJaxpr(
      core.Jaxpr([],
                 (in_cond_consts + in_body_consts + [in_pred] +
                  in_carry + [in_token, in_itoken]),
                 ([pred2] + carry2 + [token3, itoken3]),
                 [run_body_eqn, run_cond_eqn]),
      [])

  # Step 2c: the new `while`. All the original constants become constants of
  # the new body, and the new condition has no constants.
  pred_out = mk_new_var(cond_jaxpr.out_avals[0])
  while_invars = (eqn.invars[0:cond_nconsts + body_nconsts] +
                  [cond_before_outs[0]] + carry_in + cond_before_outs[1:])
  while_outvars = [pred_out] + eqn.outvars + [output_token_var, output_itoken_var]
  eqns.append(
      core.new_jaxpr_eqn(
          while_invars,
          while_outvars,
          lax.while_p,
          dict(
              cond_jaxpr=loop_cond,
              cond_nconsts=0,
              body_jaxpr=loop_body,
              body_nconsts=cond_nconsts + body_nconsts),
          eqn.source_info))
|
2020-05-24 10:50:07 +03:00
|
|
|
|
|
|
|
|
2020-10-17 11:15:51 +03:00
|
|
|
# An identity primitive simplifies the rewriting: it passes all of its
# operands through unchanged, as multiple results.
def _id_p_passthrough(*args):
  """Returns the operands unchanged; used as impl and abstract eval of `id_p`."""
  return args


def _id_p_translation(c, *args):
  """XLA translation rule for `id_p`: a tuple of the operands."""
  return xops.Tuple(c, args)


def _outfeed_rewriter(j):
  """Rewrites a top-level Jaxpr, with no input and no output tokens."""
  return _rewrite_jaxpr(j, False, False)


id_p = core.Primitive("id")
id_p.multiple_results = True
id_p.def_impl(_id_p_passthrough)
id_p.def_abstract_eval(_id_p_passthrough)
xla.translations[id_p] = _id_p_translation

xla.outfeed_rewriter = _outfeed_rewriter
|
|
|
|
|
|
|
|
|
2021-01-05 10:51:32 +02:00
|
|
|
class CallbackException(Exception):
  """Signals that some callback function had exceptions.

  Raised by :func:`barrier_wait`.
  See module documentation for details.
  """


# The old name of the exception, kept for backwards compatibility.
TapFunctionException = CallbackException
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
# For now there is a single outfeed receiver, shared by the whole module.
class _OutfeedReceiverData:
  """Holds the state of the outfeed receiver."""
  # The running receiver, or None if it has not been started (or was stopped).
  receiver: Any
  lock: threading.Lock
  # The last exception raised by a callback, with its formatted traceback.
  last_callback_exception: Optional[Tuple[Exception, str]]
  clients: Tuple[XlaLocalClient, ...]
  devices: Tuple[XlaDevice, ...]
  # The registered consumers, indexed in both directions.
  consumer_registry: Dict[_ConsumerCallable, int]
  consumer_registry_by_id: Dict[int, _ConsumerCallable]

  def __init__(self):
    self._reset_runtime()  # The receiver is initialized lazily, when needed.
    self.lock = threading.Lock()
    self.last_callback_exception = None
    # The consumer registries must be live for the lifetime of the program,
    # because we may have cached compilations that embed consumer ids, and we
    # do not want the id reused for other shapes.
    self.consumer_registry = {}
    self.consumer_registry_by_id = {}

  def _reset_runtime(self):
    """Forgets the receiver, and the clients and devices it listens on."""
    self.receiver = None
    self.clients = ()
    self.devices = ()

  def stop(self):
    """Wait for all pending outfeeds and stop the receiver."""
    # Dropping the reference lets GC trigger the destructor of the receiver.
    # The consumer registries are deliberately not cleared.
    self._reset_runtime()


_outfeed_receiver = _OutfeedReceiverData()
|
|
|
|
|
|
|
|
|
|
|
|
# This function is called from C++; it must not allow exceptions through.
def _outfeed_receiver_callback(device, consumer_id, arrays):
  """Dispatches the arrays received from an outfeed to the consumer.

  Any exception, from the consumer lookup or from the consumer itself, is
  logged and stored in `_outfeed_receiver.last_callback_exception` instead of
  being propagated to the caller.

  Args:
    device: the device on which the outfeed was received.
    consumer_id: the id under which the consumer was registered in
      `_outfeed_receiver.consumer_registry_by_id`.
    arrays: the received arrays; passed to `consumer.invoke`.
  """
  # logging.vlog(
  #   2, f"Outfeed received on device {device} for consumer {consumer_id} " +
  #   (" ".join([f"({a.dtype}{a.shape})" for a in arrays])))
  try:
    # The lookup and the sanity check are inside the `try`, so that even a
    # failed assertion is postponed rather than raised into the C++ caller.
    consumer = _outfeed_receiver.consumer_registry_by_id.get(consumer_id)
    assert consumer is not None, "We should have crashed in the runtime"
    consumer.invoke(arrays, device)
  except Exception as e:
    formatted_e = traceback.format_exc()
    logging.error("Postponing exception raised in callback function: %s", formatted_e)
    _outfeed_receiver.last_callback_exception = (e, formatted_e)
|
2020-07-04 18:12:58 +03:00
|
|
|
|
|
|
|
|
|
|
|
def _initialize_outfeed_receiver(
    clients: Optional[List[XlaLocalClient]] = None,
    max_callback_queue_size_bytes: int = int(256 * 1e6)):
  """Creates and starts the outfeed receiver, unless it is already running.

  This function is called lazily only when we compile an id_tap.

  Args:
    clients: the list of clients (backends) on whose devices to listen on.
      By default, all the local backends except the "interpreter" ones.
    max_callback_queue_size_bytes: an optional integer to bound the maximum
      size of arrays in the callback queue. When this limit is reached the
      device listener pauses.

  Raises:
    NotImplementedError: if `xla_extension` has no `outfeed_receiver`.
  """
  try:
    receiver_module = xla_extension.outfeed_receiver
  except AttributeError as err:
    raise NotImplementedError(
        "id_tap works only with jaxlib version 0.1.51 and higher") from err

  with _outfeed_receiver.lock:
    if _outfeed_receiver.receiver is not None:
      return  # Already started.

    if clients is None:
      # By default, listen on all devices of all backends, but drop the
      # interpreter clients.
      local_backends = xla_client._get_local_backends().values()  # type: ignore[protected-class]
      clients = tuple(  # type: ignore
          backend for backend in local_backends
          if backend.platform != "interpreter")
    devices = list(
        itertools.chain.from_iterable(backend.devices() for backend in clients))
    _outfeed_receiver.clients = clients  # type: ignore[assignment]
    _outfeed_receiver.devices = devices  # type: ignore[assignment]

    device_names = [str(d) for d in devices]
    logging.vlog(
        2, f"Starting outfeed_receiver for {device_names}. "
        f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
    _outfeed_receiver.receiver = receiver_module.start(
        _outfeed_receiver_callback, tuple(clients),
        max_callback_queue_size_bytes)

    def exit_handler():
      # Prevent logging usage during compilation, gives errors under pytest
      xla._on_exit = True
      barrier_wait("at_exit")

    # At exit, we wait as long as we have callbacks.
    atexit.register(exit_handler)
|
|
|
|
|
|
|
|
|
2020-09-25 15:28:23 +03:00
|
|
|
def barrier_wait(logging_name: Optional[str] = None):
  """Blocks the calling thread until all current outfeed is processed.

  Waits until all outfeed from computations already running on all devices
  has been received and processed by the Python callbacks. Raises
  :class:`CallbackException` if there were exceptions while processing the
  callbacks.

  This works by enqueueing a special tap computation to all devices to which
  we are listening for outfeed. Once all those tap computations are done, we
  return from barrier_wait. If the outfeed receiver has not been started,
  this returns immediately.

  Note: If any of the devices are busy and cannot accept new computations,
  this will deadlock.

  Args:
    logging_name: an optional string that will be used in the logging statements
      for this invocation. See `Debugging` in the module documentation.

  Raises:
    CallbackException: if a callback exception was recorded; the recorded
      exception is cleared and chained as the cause.
  """
  logging_name = logging_name or ""
  tag = f"barrier_wait[{logging_name}]"
  logging.vlog(2, f"{tag}: start")
  if not _outfeed_receiver.receiver:
    logging.vlog(2, f"{tag}: receiver not started")
    return

  taps_done = threading.Condition(lock=threading.Lock())
  # The number of devices whose barrier tap has not run yet; it is only
  # accessed while holding the lock of `taps_done`.
  pending = len(_outfeed_receiver.devices)

  def barrier_tap(device_index, _):
    nonlocal pending
    logging.vlog(
        2, f"{tag}: at barrier_tap for device {_outfeed_receiver.devices[device_index]} "
        f". Thread {threading.current_thread()}")
    with taps_done:
      pending -= 1
      logging.vlog(2, f"{tag}: still waiting for {pending} barrier_tap")
      taps_done.notify()

  # Enqueue one tap per device; the tapped value is the index of the device.
  for device_index, device in enumerate(_outfeed_receiver.devices):
    logging.vlog(2, f"{tag}: enqueueing barrier on device {device}")
    index_on_device = api.device_put(device_index, device=device)
    api.jit(lambda x: id_tap(barrier_tap, x), device=device)(index_on_device)

  logging.vlog(2, f"{tag}: waiting for callbacks")
  with taps_done:
    taps_done.wait_for(lambda: pending == 0)
  logging.vlog(2, f"{tag}: done")

  if _outfeed_receiver.last_callback_exception is None:
    return
  # Report the recorded callback exception once, then forget it.
  cause, formatted_cause = _outfeed_receiver.last_callback_exception
  _outfeed_receiver.last_callback_exception = None
  raise CallbackException(
      "There were exceptions during callback processing. "
      f"Last one was: {formatted_cause}") from cause
|
2020-10-17 11:15:51 +03:00
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
|
|
|
def stop_outfeed_receiver():
  """Stops the outfeed receiver runtime.

  This delegates to `_outfeed_receiver.stop()`, which is documented to wait
  for all outfeeds from computations already running on all devices before
  stopping the runtime. The runtime will be restarted next time you use a tap
  function.

  You should not need this function, unless you want to start using
  lax.outfeed directly after having used host callbacks.
  """
  _outfeed_receiver.stop()
|