[host_callback] Increase number of threads for callback processing.

Previously there was one thread per device for receiving the outfeed from
devices, but there was a single global thread that was calling into the Python
callbacks. This meant that if one of the callbacks was slow, it was blocking
processing of all other callbacks.

One situation where this created difficulties was when one wanted to break a
host_callback into two operations: a quick one to enqueue work on a threadpool,
and a subsequent slow one to wait for and retrieve the result. The first slow
callback would block all other callbacks, including quick ones that only
enqueue work, so the opportunity to start that slow work early was lost.

With this change there is a separate queue of outfeeds for each device and a
separate thread per device to call into Python. This allows for concurrency
between callbacks from different devices, although the callbacks from one
device are still sequential. If the programmer wants more concurrency, they can
use a threadpool. Making callbacks more concurrent by default is tricky, because
the Python callbacks for one device might then be observed out of order.
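
The threading model described above can be illustrated with a toy model in
plain Python. This is only a sketch under stated assumptions: the class
`PerDeviceCallbackRunner` and its methods are hypothetical and are not the
actual outfeed receiver implementation. It shows one FIFO queue and one
worker thread per device, so a slow callback on one device does not delay
another device, while callbacks for the same device stay in order.

```python
import queue
import threading
import time


class PerDeviceCallbackRunner:
    """Toy model: one FIFO queue and one worker thread per device."""

    def __init__(self, num_devices):
        self._queues = [queue.Queue() for _ in range(num_devices)]
        self._threads = [
            threading.Thread(target=self._worker, args=(q,), daemon=True)
            for q in self._queues
        ]
        for t in self._threads:
            t.start()

    def _worker(self, q):
        # Callbacks for one device run sequentially, in arrival order.
        while True:
            item = q.get()
            try:
                if item is None:       # sentinel: stop this worker
                    return
                callback, arg = item
                callback(arg)
            finally:
                q.task_done()

    def submit(self, device, callback, arg):
        self._queues[device].put((callback, arg))

    def barrier(self):
        # Wait until every queued callback has been processed.
        for q in self._queues:
            q.join()

    def shutdown(self):
        for q in self._queues:
            q.put(None)
        for t in self._threads:
            t.join()


def demo():
    runner = PerDeviceCallbackRunner(num_devices=2)
    seen = {0: [], 1: []}

    def slow(arg):                     # slow callback on device 0
        time.sleep(0.01)
        seen[0].append(arg)

    def fast(arg):                     # quick callback on device 1
        seen[1].append(arg)

    for i in range(3):
        runner.submit(0, slow, i)
        runner.submit(1, fast, i)
    runner.barrier()
    runner.shutdown()
    return seen
```

In `demo()` the callbacks of device 1 do not wait for the sleeping callbacks
of device 0, yet each device's list is filled in submission order.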

PiperOrigin-RevId: 385493070
George Necula 2021-07-19 00:17:38 -07:00 committed by jax authors
parent 58522fd8a1
commit a21683605d
2 changed files with 13 additions and 6 deletions

@@ -32,6 +32,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
Please upgrade to a supported Python version.
* Breaking changes:
+* The host_callback mechanism now uses one thread per local device for
+making the calls to the Python callbacks. Previously there was a single
+thread for all devices. This means that callbacks from different devices
+may now be interleaved. The callbacks corresponding to one device will
+still be called in sequence.
## jaxlib 0.1.69 (July 9 2021)
* Fix bugs in TFRT CPU backend that result in incorrect results.

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Primitives for calling from JAX accelerator code to Python functions on the host.
+"""Primitives for calling Python functions on the host from JAX accelerator code.
**Experimental: please give feedback, and expect changes.**
@@ -338,11 +338,11 @@ runtime (one thread per device). The runtime maintains a buffer of
configurable size (see the flag ``--jax_host_callback_max_queue_byte_size``).
When the buffer is full, all the receiving threads are paused
which eventually pauses the computation on devices. The runtime has one
-additional thread that invokes the Python user functions with the received data.
-If the processing of the callbacks is slow, it may actually lead to the runtime
-buffer filling up, and eventually pausing the computation on the devices
-when they need to send something. For more details on the outfeed receiver
-runtime mechanism see
+additional thread for each device to invoke the Python user functions with the
+received data. If the processing of the callbacks is slow, it may actually
+lead to the runtime buffer filling up, and eventually pausing the computation
+on the devices when they need to send something.
+For more details on the outfeed receiver runtime mechanism see
`runtime code
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.