From 4f3011f3204dc9b36ed4d9de883ea31c1a6bad40 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 4 Jul 2020 18:12:58 +0300 Subject: [PATCH] Refactored host_callback to use the C++ runtime. (#3644) * Refactored host_callback to use the C++ runtime. * The new runtime makes it unnecessary to start the outfeed_receiver in the user's code * We don't need msgpack anymore * There is an interaction between host_callback and using lax.outfeed. I am trying to solve this by (a) making host_callback_test stop the outfeed receiver on finish and infeed_test on start, and (b) telling pytest-xdist to run all the tests from one file into a single worker. --- .github/workflows/ci-build.yaml | 4 +- build/test-requirements.txt | 1 - docs/requirements.txt | 3 +- jax/api.py | 2 +- jax/experimental/host_callback.py | 1008 +++++++++++++++-------------- jax/interpreters/pxla.py | 8 +- jax/interpreters/xla.py | 27 +- pytest.ini | 7 +- tests/host_callback_test.py | 566 +++++++++------- tests/infeed_test.py | 3 + 10 files changed, 866 insertions(+), 763 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index cf6ecc2b7..d66748090 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -101,5 +101,5 @@ jobs: pip install -r docs/requirements.txt - name: Test documentation run: | - pytest docs - pytest --doctest-modules jax/api.py + pytest -n 1 docs + pytest -n 1 --doctest-modules jax/api.py diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 8c3b38bc3..98379c97a 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -1,6 +1,5 @@ flake8 jaxlib==0.1.51 -msgpack mypy==0.770 pytest-benchmark pytest-xdist diff --git a/docs/requirements.txt b/docs/requirements.txt index b3cd58f74..3562b06e6 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,12 +5,11 @@ ipykernel nbsphinx sphinx-autodoc-typehints myst-parser[sphinx] -# For host_callback.py -msgpack # The next packages are for notebooks matplotlib sklearn # For CI tests. pytest +pytest-xdist # Must install jax itself for notebook execution to work . diff --git a/jax/api.py b/jax/api.py index a6e96f9ff..ed50c5b16 100644 --- a/jax/api.py +++ b/jax/api.py @@ -354,7 +354,7 @@ def xla_computation(fun: Callable, jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals, instantiate=instantiate_const_outputs, stage_out=True) - jaxpr, _ = xla.apply_outfeed_rewriter(jaxpr) + jaxpr = xla.apply_outfeed_rewriter(jaxpr) axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr)) c = xb.make_computation_builder('xla_computation_{}'.format(fun_name)) xla_consts = map(partial(xb.constant, c), consts) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index a63ec3bb7..ad755d87b 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -11,57 +11,67 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Primitives for calling user-defined Python functions -on the host, even from compiled and transformed code. +"""Primitives for calling from accelerators to Python functions on the host. **Experimental: please give feedback, and expect changes.** -Host callbacks work even for code executed on accelerators and -even for code under JAX transformations. A few examples:: +This module introduces the host callback functions :func:`id_tap` and +:func:`id_print`, which behave like the identity function but have the +side-effect of sending the arguments from the accelerator to the host and +invoking a user-specified Python function (for :func:`id_tap`) or printing the +arguments on the host (for :func:`id_print`). A few examples:: - # calls func(2x) and returns 2x - y = id_tap(func, x * 2) - # calls func((2x, 3x)) and returns (2x, 3x) - y, z = id_tap(func, (x * 2, x * 3)) # The argument can be a pytree - # calls func(2x) and returns y - y = id_tap(func, x * 2, result=y) - # calls func(2x, what='x') and returns 2x - y = id_tap(func, x * 2, what='x') - # calls func(dict(x=x, y=y), what='foo') and returns dict(x=x, y=y) - x, y = id_tap(func, dict(x=x, y=y), what='a dict') + # call func(2x) on host and return 2x + y = id_tap(func, 2 * x) + # call func((2x, 3x)) and return (2x, 3x) + y, z = id_tap(func, (2 * x, 3 * x)) # The argument can be a pytree + # call func(2x) and return y + y = id_tap(func, 2 * x, result=y) + # call func(2x, what='activation') and return 2x + y = id_tap(func, 2 * x, what='activation') + # call func(dict(x=x, y=y), what='data') and return dict(x=x, y=y) + x, y = id_tap(func, dict(x=x, y=y), what='data') -The order of execution is by data dependency: after all the arguments are -computed and before the result is used. At least one of the returned values -must be used in the rest of the computation, or else this operation has no effect. +The order of execution of the tap functions is constrained by data dependency: +the arguments are sent after all the arguments are computed and before the +result of the call is used. **At least one of the returned values must be +used in the rest of the computation, or else this operation has no effect.** +The host tap functions will be executed for each device in the order in which +the send operations were performed on the device. -**At the moment**, in order to use the callback primitives one must wrap any -code that uses them with an :func:`outfeed_receiver` (an error is -raised otherwise):: +The data from the devices is received by separate threads managed by the JAX +runtime (one thread per device). The runtime maintains a buffer of +configurable size. When the buffer is full, all the receiving threads are paused +which eventually pauses the computation on devices. The runtime has one +additional thread that invokes the Python user functions with the received data. +If the processing of the callbacks is slow, it may actually lead to the runtime +buffer filling up, and eventually pausing the computation on the devices +when they need to send something. For more details on the runtime mechanism see +`runtime code +`_. - with outfeed_receiver(): - id_print(x) - jax.jit(func_with_taps)(x) +Exceptions from the user-defined tap functions are logged along with their +stack traces, but the receiving threads are not stopped. -The printing and the tap functions execute in separate threads that are started -by :func:`outfeed_receiver`. There is one thread per device. This ensures -that the outfeed from a certain device will come in order. Exceptions from -the user-define tap functions are printed along with the traceback, but the -outfeed listening does not stop until the body of the -:func:`outfeed_receiver` terminates and all the outfeeds are received. -At that point, a ``TapFunctionException`` is raised if there was an exception +In order to pause the execution until all data from computations already +started on devices has arrived and has been processed, use :func:`barrier_wait`. +This will also raise :exc:`TapFunctionException` if any exception had occurred in one of the tap functions. -**We intend to implement an alternative outfeed receiving mechanism that will -not require the user to start the `outfeed_receiver`.** - The current implementation uses the outfeed mechanism provided by XLA. The mechanism itself is quite primitive in the sense that a receiver must know exactly the shape of each incoming packet, and how many packets are expected. This makes it hard to use for multiple kinds of data in the same computation, -and it is practically impossible to use it under conditionals. +and it is practically impossible to use it under conditionals or in loops +of non-constant iteration count. Furthermore, code that uses the outfeed +mechanism directly cannot be transformed by JAX. All these limitations are +addressed by the host callback functions. The tapping API introduced here +makes it easy to share the outfeed mechanism for multiple purposes, while +supporting all transformations. -The tapping API introduced here makes it easy to share the outfeed mechanism -for multiple purposes. Instead of using it directly, just use :func:`id_tap`. +**Note that after you have used the host callback functions, you cannot +use lax.outfeed directly**. You may want to :func:`stop_outfeed_receiver` +if you later need to use lax.outfeed. We describe the behaviour under transformations in the context of the following function definition:: @@ -71,10 +81,12 @@ following function definition:: _, y = id_print(x, y, what="x,x^2") return y * x -During JAX transformations the special parameters ``transforms`` is extended -with a dictionary, containing the key ``name`` holding the name of the -transformation and additional keys holding transformation parameters, if -applicable. +During JAX transformations the special parameter ``transforms`` is added to +contain a list of transformation descriptors. Each descriptor is a dictionary +containing the key ``name`` holding the name of the transformation and +additional keys holding transformation parameters, if applicable. This +parameter is passed to the tap function (or printed), in addition to +user-defined parameters. For :func:`jax.vmap` the arguments are batched, and ``transforms`` is extended with transformation name ``batch`` and ``batch_dims`` set to the the tuple of @@ -82,7 +94,8 @@ batched dimensions (one entry per argument, ``None`` denotes an argument that was broadcast):: jax.vmap(power3)(np.arange(3.)) - # what=x,x^2 transforms=({name=batch, batch_dims=(0, 0)}): ([0, 1, 2], [0, 1, 4]) + # what=x,x^2 transforms=({name=batch, batch_dims=(0, 0)}): ([0, 1, 2], [0, 1, + 4]) For :func:`jax.jvp` there will be two callbacks, one with the values of the primals and one with the tangents:: @@ -97,8 +110,10 @@ the values of the primals from the forward pass, if those values are needed for the backward pass:: jax.grad(power3)(3.) - # what=x,x^2: (3., 9.) # from forward pass, since y is needed in backward pass - # what=x,x^2 transforms=({name=jvp}, {name=transpose}): (0., 3.) # from backward pass, adjoints of _, y + # what=x,x^2: (3., 9.) # from forward pass, since y is needed in backward + pass + # what=x,x^2 transforms=({name=jvp}, {name=transpose}): (0., 3.) # from + backward pass, adjoints of _, y See documentation for :func:`id_tap` and :func:`id_print`. For usage example, see tests/host_callback_test.py. @@ -107,23 +122,20 @@ Still to do: * Performance tests. * Add flags for logging. * Add unit tests with mocks. - * Improve the ergonomics of starting the consumer loop. Currently, when - invoking jit-ed code, one must start a consumer loop. This is not needed - when invoking code that does not involve jit. There is an error when - attempting to start a compiled computation without starting the outfeed - receiver. Perhaps we can put the receiver threads in the runtime. * Explore a simpler API that uses Python program-order, instead of data dependency-order. * Explore implementation with outside compilation. + * Explore an extended API that allows the host function to return + values to the accelerator computation. """ -from concurrent import futures -from contextlib import contextmanager +from absl import logging +import atexit +import contextlib import itertools from jax import api from jax import core -from jax import dtypes from jax import lax from jax.lib import pytree from jax.interpreters import ad, xla, batching, masking @@ -132,54 +144,43 @@ from jax import pprint_util as ppu from jax import source_info_util from jax import util from jaxlib import xla_client -from jaxlib import version as jaxlib_version +from jaxlib import xla_extension + -import logging -import msgpack # type: ignore import numpy as np +import threading import traceback -from typing import (Any, Callable, Dict, List, Optional, NamedTuple, - Sequence, Tuple, cast) +from typing import (Any, Callable, Dict, List, Optional, NamedTuple, Sequence, + Tuple, cast) +import warnings xops = xla_client._xla.ops # TODO(necula): fix mypy errors if I define the type aliases below XlaOp = Any # xla_extension.XlaOp -XlaShape = Any # xla_client.Shape +XlaShape = Any # xla_client.Shape XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder XlaDevice = Any # xla_client.Device - -# TODO: add a flag -_LOGGING = True - -# Starting with 0.1.46 the outfeed is now on the Device -# TODO(necula): remove once we fix XLA -_jaxlib_version = tuple(int(x)for x in jaxlib_version.__version__.split('.')) -if _jaxlib_version == (0, 1, 45): - _OUTFEED_MECHANISM = "xla_client" -elif _jaxlib_version >= (0, 1, 47): - _OUTFEED_MECHANISM = "device" -else: - _OUTFEED_MECHANISM = "none" +XlaLocalClient = Any # xla_extension.LocalClient -def id_tap(func: Callable, arg, *, - result=None, - **kwargs): - """Host-callback tap primitive, like identity function with a call to ``func``. +def id_tap(tap_func: Callable, arg, *, result=None, **kwargs): + """Host-callback tap primitive, like identity function with a call to ``tap_func``. **Experimental: please give feedback, and expect changes!** - ``id_tap`` behaves semantically like the identity function but has the side-effect + ``id_tap`` behaves semantically like the identity function but has the + side-effect that a user-defined Python function is called with the runtime values of the argument. Args: + * tap_func: the tap function to call. * arg: the argument passed to the tap function, can be a pytree of JAX types. * result: if given, specifies the return value of ``id_tap``. This value is - not passed to the tap function, and in fact is not sent from the device - to the host. If the ``result`` parameter is not specified then the return + not passed to the tap function, and in fact is not sent from the device to + the host. If the ``result`` parameter is not specified then the return value of ``id_tap`` is ``arg``. * kwargs: will be passed directly to the tap function. Can be anything that is hashable, these are kept in the host Python process until outfeeds are @@ -204,34 +205,37 @@ def id_tap(func: Callable, arg, *, :func:`outfeed_receiver`. For more details see the - `module documentation `_. + `module documentation + `_. """ - if _OUTFEED_MECHANISM == "none": - raise NotImplementedError("id_tap works only with jaxlib 0.1.47 and higher") - if func not in (_end_consumer, _unknown_testing_consumer): - api._check_callable(func) + _initialize_outfeed_receiver() # Lazy initialization + api._check_callable(tap_func) flat_args, arg_treedef = pytree.flatten(arg) - for arg in flat_args: api._check_arg(arg) + for arg in flat_args: + api._check_arg(arg) params = dict(kwargs) # we pass a copy of params to the primitive # See definition of id_tap_p for what parameters it takes - params["func"] = func - params["arg_treedef"] = arg_treedef + params["tap_func_"] = tap_func + params["arg_treedef_"] = arg_treedef + params["nr_tapped_args_"] = len(flat_args) if result is not None: flat_results, result_treedef = pytree.flatten(result) - for result in flat_results: api._check_arg(result) + for result in flat_results: + api._check_arg(result) all_args = flat_args + flat_results - params["nr_untapped"] = len(flat_results) + nr_results = len(flat_results) else: all_args = flat_args + nr_results = 0 flat_outs = id_tap_p.bind(*all_args, **params) # Returns all_args if result is not None: - return result_treedef.unflatten(flat_outs[-params["nr_untapped"]:]) # type: ignore[unsupported-operands] + flat_results = flat_outs[-nr_results:] # type: ignore[unsupported-operands] + return result_treedef.unflatten(flat_results) else: return arg_treedef.unflatten(flat_outs) -def id_print(arg, *, result=None, output_stream=None, threshold=None, - **kwargs): +def id_print(arg, *, result=None, output_stream=None, threshold=None, **kwargs): """Like :func:`id_tap` with a printing tap function. **Experimental: please give feedback, and expect changes!** @@ -245,12 +249,17 @@ def id_print(arg, *, result=None, output_stream=None, threshold=None, Additional keyword arguments: * ``output_stream`` if given then it will be used instead of the - built-in ``print``. The string will be passed as ``output_stream.write(s)``. + built-in ``print``. The string will be passed as + ``output_stream.write(s)``. * ``threshold`` is passed to ``numpy.array2string``. """ - return id_tap(_print_consumer, arg, - result=result, output_stream=output_stream, - threshold=threshold, **kwargs) + return id_tap( + _print_consumer, + arg, + result=result, + output_stream=output_stream, + threshold=threshold, + **kwargs) # A registry of outfeed consumers, used upon receiving outfeeds @@ -259,6 +268,7 @@ class _ConsumerCallable(NamedTuple): func: Callable kwargs: Tuple[Tuple[str, Any], ...] arg_treedef: Any + arg_shape: XlaShape # XlaShape implements __hash__. def unpack_kwargs(self): kwargs = dict(self.kwargs) @@ -266,6 +276,7 @@ class _ConsumerCallable(NamedTuple): if transforms is None: return kwargs else: + def unpack_transform(name, *params): if name == "batch": return dict(name=name, batch_dims=params[0]) @@ -274,59 +285,59 @@ class _ConsumerCallable(NamedTuple): else: assert not params, f"{name}, {params}" return dict(name=name) - return dict(kwargs, - transforms=tuple([unpack_transform(*t) for t in transforms])) -_consumer_registry: Dict[_ConsumerCallable, int] = dict() -_consumer_registry_by_id: Dict[int, _ConsumerCallable] = dict() + return dict( + kwargs, transforms=tuple([unpack_transform(*t) for t in transforms])) def _register_consumer(cons: _ConsumerCallable) -> int: """Registers a tap function, cache by hash of cons.""" - cons_id = _consumer_registry.get(cons) + cons_id = _outfeed_receiver.consumer_registry.get(cons) if cons_id is not None: return cons_id - cons_id = hash(cons) - _consumer_registry[cons] = cons_id - _consumer_registry_by_id[cons_id] = cons + cons_id = hash(cons) & 0xFFFFFFFC # pybind11 has trouble here with large ints + cons_id += 1 # Reserve the consumer ID 0 + assert cons_id not in _outfeed_receiver.consumer_registry, ( + "consumer id collision") + _outfeed_receiver.consumer_registry[cons] = cons_id + _outfeed_receiver.consumer_registry_by_id[cons_id] = cons return cons_id -def _print_consumer(arg, *, output_stream=None, - threshold=1024, **kwargs): + +def _print_consumer(arg, *, output_stream=None, threshold=1024, **kwargs): """The consumer for id_print. We provide this as a simple tapping function for printing. - This is **experimental** - and may not want to add many features to it; it should be easy for the user - to roll their own printing function. + This is **experimental** and may not want to add many features to it; + it should be easy for the user to roll their own printing function. Args: - output_stream: a function whose `write` method is called with the strings - to be output. + output_stream: a function whose `write` method is called with the strings to + be output. threshold: the value of numpy.array2string threshold parameter. **kwargs: all other keyword args are printed before printing `arg`. """ + def emit_str(s: str): if output_stream is not None: output_stream.write(s + "\n") else: print(s) - kv_pairs = " ".join([f"{k}: {v}" - for k, v in sorted(kwargs.items()) - if k not in ("consumer_id", "nr_untapped")]) + + kv_pairs = " ".join([ + f"{k}: {v}" for k, v in sorted(kwargs.items()) + ]) if kv_pairs: emit_str(kv_pairs) def pp_val(arg) -> ppu.PrettyPrint: if isinstance(arg, (tuple, list)): - return (ppu.pp('[ ') >> - ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(' ]')) + return ( + ppu.pp("[ ") >> ppu.vcat([pp_val(e) for e in arg]) >> ppu.pp(" ]")) elif isinstance(arg, dict): - return (ppu.pp('{ ') >> - ppu.vcat( - [ppu.pp(f"{k}=") >> pp_val(v) - for k, v in sorted(arg.items())]) >> - ppu.pp(' }')) + return (ppu.pp("{ ") >> ppu.vcat([ + ppu.pp(f"{k}=") >> pp_val(v) for k, v in sorted(arg.items()) + ]) >> ppu.pp(" }")) elif isinstance(arg, np.ndarray): return ppu.pp(np.array2string(arg, threshold=threshold)) else: @@ -335,36 +346,49 @@ def _print_consumer(arg, *, output_stream=None, emit_str(str(pp_val(arg))) -"""The id_tap primitive acts like the identity function. It has a number of -positional arguments and parameters: - * func: the actual (Python) function to invoke with the positional arguments - and the parameters. - * nr_untapped: how many positional arguments (from the tail) should not be - passed to the tap function. - * arg_treedef: the treedef of the tapped positional arguments. +"""The id_tap_p primitive acts like the identity function. + +It has a number of positional arguments. The result of the primitive are +the positional arguments. + +The primitive has the following parameters: + * has_token_: a boolean, when True it means that the last positional argument + is the current token. In this case, the result of the primitive is + going to be the non-token positional arguments, along with the updated + token. The tokens and this parameter are added after all the JAX + transformations, just before staging XLA. + * nr_tapped_args_: how many positional arguments from the head should be + passed to the tap function. The remaining positional arguments are there + for data dependency, for implementing the "result" feature, and for + the current token. + * tapped_args_treedef_: the treedef of the tapped positional arguments. + * tap_func_: the actual (Python) function to invoke with the tapped positional + arguments (unflatted according to tapped_args_treedef_) and + the parameters that were passed to the id_tap function. * transforms: a tuple of the transformations that have been applied. Each element of the tuple is itself a tuple with the first element the name of the transform. The remaining elements depend on the transform. For example, for `batch`, the parameters are the dimensions that have been batched, and for `mask` the logical shapes. These are unpacked by _ConsumerCallable before passing to the user function. - - * the remaining parameters are passed to the tap function. + * the remaining parameters are from the user's invocation of the id_tap + API function and are passed to the tap function. """ id_tap_p = core.Primitive("id_tap") id_tap_p.multiple_results = True xla.outfeed_primitives.add(id_tap_p) -def _add_transform(params: Dict, name: str, - *transform_params) -> Dict: + +def _add_transform(params: Dict, name: str, *transform_params) -> Dict: """Adds the `transform` to the params["transforms"]. Uses a tuple representation internally, will be unpacked before the callback by _ConsumerCallable. """ new_transform = (name, *transform_params) - return dict(params, transforms=(params.get("transforms", ()) - + (new_transform,))) + return dict( + params, transforms=(params.get("transforms", ()) + (new_transform,))) + def _id_tap_impl(*arrays, **params): # We use the jitted-version of the primitive even for eager execution, both @@ -376,8 +400,10 @@ def _id_tap_impl(*arrays, **params): # different threads. return xla.apply_primitive(id_tap_p, *arrays, **params) + id_tap_p.def_impl(_id_tap_impl) + def _id_tap_abstract_eval(*args_a: pe.AbstractValue, **params) \ -> Sequence[pe.AbstractValue]: return args_a @@ -385,6 +411,7 @@ def _id_tap_abstract_eval(*args_a: pe.AbstractValue, **params) \ id_tap_p.def_abstract_eval(_id_tap_abstract_eval) + # TODO(necula): there must be a better way to do this. # The AttributeError is for regular values, the KeyError is for ConcreteArray def _instantiate_zeros(arg, tan): @@ -400,27 +427,32 @@ def _instantiate_zeros(arg, tan): return ad.zeros_like_jaxval(arg) -def _id_tap_jvp_rule(primals, tangents, *, func, nr_untapped=0, **params): +def _id_tap_jvp_rule(primals, tangents, **params): # Put primals through id_tap separately, so that partial evaluation # can do its job when they are known (for grad) - out_primals = id_tap_p.bind(*primals, func=func, nr_untapped=nr_untapped, **params) + out_primals = id_tap_p.bind( + *primals, **params) # Add one primal output as untapped, to create data dependency. tangent_zeros = tuple(map(_instantiate_zeros, primals, tangents)) - out_tangents_extra = id_tap_p.bind(*tangent_zeros, out_primals[0], - func=func, nr_untapped=nr_untapped + 1, - **_add_transform(params, "jvp")) + out_tangents_extra = id_tap_p.bind( + *tangent_zeros, + out_primals[0], + **_add_transform(params, "jvp")) return tuple(out_primals), tuple(out_tangents_extra[:-1]) + ad.primitive_jvps[id_tap_p] = _id_tap_jvp_rule -def _id_tap_transpose_rule(cts, *args, func=None, nr_untapped=0, **params): +def _id_tap_transpose_rule(cts, *args, **params): assert len(cts) == len(args) cts_zeros = tuple(map(_instantiate_zeros, args, cts)) - ct_args = id_tap_p.bind(*cts_zeros, func=func, nr_untapped=nr_untapped, - **_add_transform(params, "transpose")) + ct_args = id_tap_p.bind( + *cts_zeros, + **_add_transform(params, "transpose")) return ct_args + ad.primitive_transposes[id_tap_p] = _id_tap_transpose_rule @@ -429,51 +461,51 @@ def _id_tap_batching_rule(batched_args, batch_dims, **params): res = id_tap_p.bind(*batched_args, **new_params) return res, batch_dims + batching.primitive_batchers[id_tap_p] = _id_tap_batching_rule # def _id_tap_shape_rule(*operands, **params): # return tuple([op.shape for op in operands]) -# TODO(necula): these disappeared -# masking.shape_rules[id_tap_p] = _id_tap_shape_rule # type: ignore[module-attr] def _id_tap_masking_rule(operands, operands_logical_shapes, **params): - new_params = _add_transform(params, "mask", - operands_logical_shapes) + new_params = _add_transform(params, "mask", operands_logical_shapes) return id_tap_p.bind(*operands, **new_params) + masking.masking_rules[id_tap_p] = _id_tap_masking_rule #### #### XLA compilation #### #### -# Special consumer to mark the end of outfeed stream for a device -_end_consumer = 0 -_unknown_testing_consumer = 1 # for testing error cases -def _id_tap_translation_rule_outfeed(comp: XlaComputationBuilder, - *args_op: XlaOp, func=None, - nr_untapped=0, arg_treedef=None, - **params): - params = dict(params) - if func in (_end_consumer, _unknown_testing_consumer): - params["consumer_id"] = func - else: - params["consumer_id"] = _register_consumer( - _ConsumerCallable(func, tuple(sorted(params.items())), arg_treedef)) +def _id_tap_translation_rule(comp: XlaComputationBuilder, + *args_op: XlaOp, + tap_func_=None, + nr_tapped_args_, + arg_treedef_=None, + has_token_=False, + **params): # We expect the current token at the end, inserted by _rewrite_jaxpr. + assert has_token_ current_token = args_op[-1] - assert not comp.get_shape(current_token).is_array(), "The last argument must be a token" + assert not comp.get_shape(current_token).is_array(), ( + "The last argument must be a token") - nr_args_to_emit = len(args_op) - nr_untapped - 1 - next_token = _emit_outfeed(comp, current_token, - args_op[0:nr_args_to_emit], params["consumer_id"]) + args_to_outfeed = args_op[0:nr_tapped_args_] + consumer_id = _register_consumer( + _ConsumerCallable(tap_func_, tuple(sorted(params.items())), arg_treedef_, + comp.get_shape(xops.Tuple(comp, args_to_outfeed)))) + next_token = _outfeed_receiver.receiver.add_outfeed(comp, current_token, + consumer_id, + args_to_outfeed) results = (args_op[:-1] + (next_token,)) return xops.Tuple(comp, results) -xla.translations[id_tap_p] = _id_tap_translation_rule_outfeed + +xla.translations[id_tap_p] = _id_tap_translation_rule #### #### Jaxpr rewriting logic to thread the tokens through stateful primitives. @@ -486,25 +518,22 @@ def _mk_typed_jaxpr(jaxpr: core.Jaxpr, literals: Sequence) -> core.TypedJaxpr: tuple(map(lambda v: v.aval, jaxpr.invars)), tuple(map(lambda v: v.aval, jaxpr.outvars))) -def _rewrite_typed_jaxpr(tjaxpr: core.TypedJaxpr, - has_input_token: bool, - has_output_token: bool) -> Tuple[core.TypedJaxpr, bool]: - """Rewrites a TypedJaxpr to thread the token, if needed. - Returns the rewritten Jaxpr, and whether it uses outfeed.""" - new_jaxpr, uses_outfeed = _rewrite_jaxpr(tjaxpr.jaxpr, has_input_token, - has_output_token) - return _mk_typed_jaxpr(new_jaxpr, tjaxpr.literals), uses_outfeed +def _rewrite_typed_jaxpr( + tjaxpr: core.TypedJaxpr, has_input_token: bool, + has_output_token: bool) -> core.TypedJaxpr: + """Rewrites a TypedJaxpr to thread the token, if needed.""" + new_jaxpr = _rewrite_jaxpr(tjaxpr.jaxpr, has_input_token, has_output_token) + return _mk_typed_jaxpr(new_jaxpr, tjaxpr.literals) -def _rewrite_jaxpr(jaxpr: core.Jaxpr, - has_input_token: bool, - has_output_token: bool) -> Tuple[core.Jaxpr, bool]: +def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool, + has_output_token: bool) -> core.Jaxpr: """Rewrite a Jaxpr to thread the token, if needed.""" assert has_input_token or not has_output_token if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr): - return (jaxpr, False) + return jaxpr mk_new_var = core.gensym([jaxpr]) @@ -514,9 +543,9 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, invars = jaxpr.invars + [last_token_var] else: invars = jaxpr.invars - eqns.append(core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var], - lax.create_token_p, {}, - source_info_util.current())) + eqns.append( + core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var], + lax.create_token_p, {}, source_info_util.current())) for eqn in jaxpr.eqns: if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params): @@ -528,130 +557,148 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, outvars = jaxpr.outvars + ([last_token_var] if has_output_token else []) new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns) - return (new_jaxpr, True) + return new_jaxpr -def _rewrite_eqn(eqn: core.JaxprEqn, - eqns: List[core.JaxprEqn], - input_token_var: core.Var, - output_token_var: core.Var, + +def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], + input_token_var: core.Var, output_token_var: core.Var, mk_new_var: Callable[[core.AbstractValue], core.Var]): - """Rewrite an `eqn` and append equations to `eqns`. Assume that the - current token is in `input_token_var` and the resulting token must end in - `output_token_var`.""" + """Rewrite an `eqn` and append equations to `eqns`. + + Assume that the current token is in `input_token_var` and the resulting + token must end in `output_token_var`. + """ if eqn.primitive is id_tap_p: - eqns.append(core.new_jaxpr_eqn(eqn.invars + [input_token_var], - eqn.outvars + [output_token_var], - eqn.primitive, eqn.params, eqn.source_info)) + assert "has_token_" not in eqn.params + eqns.append( + core.new_jaxpr_eqn(eqn.invars + [input_token_var], + eqn.outvars + [output_token_var], eqn.primitive, + dict(eqn.params, has_token_=True), + eqn.source_info)) elif eqn.primitive is lax.while_p: - cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict( - eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) + cond_jaxpr, _, body_jaxpr, _ = util.split_dict( + eqn.params, + ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr): _rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var, mk_new_var) return - eqns.append(core.new_jaxpr_eqn( - eqn.invars + [input_token_var], - eqn.outvars + [output_token_var], - eqn.primitive, - dict(eqn.params, - body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0], - cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)[0]), - eqn.source_info)) + eqns.append( + core.new_jaxpr_eqn( + eqn.invars + [input_token_var], eqn.outvars + [output_token_var], + eqn.primitive, + dict( + eqn.params, + body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True), + cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, + False)), eqn.source_info)) elif eqn.primitive is lax.cond_p: branches, linear = util.split_dict(eqn.params, ["branches", "linear"]) index, *operands = eqn.invars new_invars = [index, *operands, input_token_var] - eqns.append(core.new_jaxpr_eqn( - new_invars, eqn.outvars + [output_token_var], - eqn.primitive, - dict(eqn.params, - branches=tuple( - _rewrite_typed_jaxpr(jaxpr, True, True)[0] - for jaxpr in branches), - linear=(*linear, False)), - eqn.source_info)) + eqns.append( + core.new_jaxpr_eqn( + new_invars, eqn.outvars + [output_token_var], eqn.primitive, + dict( + eqn.params, + branches=tuple( + _rewrite_typed_jaxpr(jaxpr, True, True) + for jaxpr in branches), + linear=(*linear, False)), eqn.source_info)) elif eqn.primitive is lax.scan_p: num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict( - eqn.params, ["num_consts", "num_carry", "jaxpr", "linear", - "reverse", "length"]) + eqn.params, + ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length"]) # We add the token right at the end of carry nr_const_and_carry = num_consts + num_carry - new_invars = eqn.invars[0:nr_const_and_carry] + [input_token_var] + eqn.invars[nr_const_and_carry:] - new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0] + new_invars = eqn.invars[0:nr_const_and_carry] + [ + input_token_var + ] + eqn.invars[nr_const_and_carry:] + new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True) # The rewrite has put the token at end, it has to be at end of carry new_jaxpr_invars = new_jaxpr.jaxpr.invars - new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] + - [new_jaxpr_invars[-1]] + - new_jaxpr_invars[nr_const_and_carry:-1]) + new_jaxpr_invars = ( + new_jaxpr_invars[0:nr_const_and_carry] + [new_jaxpr_invars[-1]] + + new_jaxpr_invars[nr_const_and_carry:-1]) new_jaxpr.jaxpr.invars = new_jaxpr_invars new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars] new_jaxpr_outvars = new_jaxpr.jaxpr.outvars - new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] + - [new_jaxpr_outvars[-1]] + - new_jaxpr_outvars[num_carry:-1]) + new_jaxpr_outvars = ( + new_jaxpr_outvars[0:num_carry] + [new_jaxpr_outvars[-1]] + + new_jaxpr_outvars[num_carry:-1]) new_jaxpr.jaxpr.outvars = new_jaxpr_outvars new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars] - eqns.append(core.new_jaxpr_eqn( - new_invars, - # Output token is at the end of carry result - eqn.outvars[0:num_carry] + [output_token_var] + eqn.outvars[num_carry:], - eqn.primitive, - dict(eqn.params, - jaxpr=new_jaxpr, - num_carry=num_carry + 1, - linear=linear + (False,)), - eqn.source_info)) + eqns.append( + core.new_jaxpr_eqn( + new_invars, + # Output token is at the end of carry result + eqn.outvars[0:num_carry] + [output_token_var] + + eqn.outvars[num_carry:], + eqn.primitive, + dict( + eqn.params, + jaxpr=new_jaxpr, + num_carry=num_carry + 1, + linear=linear + (False,)), + eqn.source_info)) elif eqn.primitive is xla.xla_call_p: call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) - eqns.append(core.new_jaxpr_eqn( - eqn.invars + [input_token_var], - eqn.outvars + [output_token_var], - eqn.primitive, - dict(eqn.params, call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0]), - eqn.source_info)) + eqns.append( + core.new_jaxpr_eqn( + eqn.invars + [input_token_var], eqn.outvars + [output_token_var], + eqn.primitive, + dict( + eqn.params, + call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, + True)), eqn.source_info)) else: raise NotImplementedError(f"outfeed rewrite {eqn.primitive}") -def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, - eqns: List[core.JaxprEqn], + +def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], input_token_var: core.Var, output_token_var: core.Var, mk_new_var: Callable): """Rewrite a while whose cond has outfeed""" cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict( - eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) - transformed_cond_jaxpr, _ = _rewrite_typed_jaxpr(cond_jaxpr, True, True) - carry_invars = eqn.invars[cond_nconsts+body_nconsts:] + eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"]) + transformed_cond_jaxpr = _rewrite_typed_jaxpr(cond_jaxpr, True, True) + carry_invars = eqn.invars[cond_nconsts + body_nconsts:] # pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token) - pred1_and_token1 = [mk_new_var(ov.aval) - for ov in transformed_cond_jaxpr.jaxpr.outvars] - eqns.append(core.new_jaxpr_eqn( - eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var], - pred1_and_token1, - xla.xla_call_p, - dict(call_jaxpr=transformed_cond_jaxpr.jaxpr, - name="cond_before", - donated_invars=(False,) * (cond_nconsts + len(carry_invars) + 1)), - eqn.source_info)) + pred1_and_token1 = [ + mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars + ] + eqns.append( + core.new_jaxpr_eqn( + eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var], + pred1_and_token1, xla.xla_call_p, + dict( + call_jaxpr=transformed_cond_jaxpr.jaxpr, + name="cond_before", + donated_invars=(False,) * (cond_nconsts + len(carry_invars) + 1)), + eqn.source_info)) # Make a new cond "lambda pred, carry, token: pred" new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0]) - new_cond_invars = ( - [new_cond_pred_invar] + - [mk_new_var(cv.aval) for cv in carry_invars] + - [mk_new_var(core.abstract_token)]) - new_cond_jaxpr = _mk_typed_jaxpr(core.Jaxpr([], new_cond_invars, - [new_cond_pred_invar], []), - []) - # Make a new body: "lambda cond_constvars, body_constvars, pred, carry, token: - # carry2, token2 = rewrite(BODY)(body_constvars, carry, token) - # pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2) - # (pred2, carry2, token3) - transformed_body_jaxpr, _ = _rewrite_typed_jaxpr(body_jaxpr, True, True) - new_body_invars_cond_constvars = [mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]] - new_body_invars_body_constvars = [mk_new_var(v.aval) - for v in eqn.invars[cond_nconsts:cond_nconsts+body_nconsts]] + new_cond_invars = ([new_cond_pred_invar] + + [mk_new_var(cv.aval) for cv in carry_invars] + + [mk_new_var(core.abstract_token)]) + new_cond_jaxpr = _mk_typed_jaxpr( + core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), []) + # Make a new body: + # "lambda cond_constvars, body_constvars, pred, carry, token: + # carry2, token2 = rewrite(BODY)(body_constvars, carry, token) + # pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2) + # (pred2, carry2, token3) + transformed_body_jaxpr = _rewrite_typed_jaxpr(body_jaxpr, True, True) + new_body_invars_cond_constvars = [ + mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts] + ] + new_body_invars_body_constvars = [ + mk_new_var(v.aval) + for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts] + ] new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0]) new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars] new_body_invars_token = mk_new_var(core.abstract_token) @@ -662,153 +709,49 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, new_body_token3 = mk_new_var(core.abstract_token) new_body_eqns = [ - core.new_jaxpr_eqn( - new_body_invars_body_constvars + - new_body_invars_carry + [new_body_invars_token], - new_body_carry2 + [new_body_token2], - xla.xla_call_p, - dict(call_jaxpr=transformed_body_jaxpr.jaxpr, - name="body", - donated_invars=(False,) * (len(new_body_invars_body_constvars) + - len(new_body_invars_carry) + - 1 + len(new_body_carry2) + 1)), - eqn.source_info), - core.new_jaxpr_eqn( - new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2], - [new_body_pred2, new_body_token3], - xla.xla_call_p, - dict(call_jaxpr=transformed_cond_jaxpr.jaxpr, - name="cond_body", - donated_invars=(False,) * (len(new_body_invars_cond_constvars) + - len(new_body_carry2) + 1 + 2)), - eqn.source_info) + core.new_jaxpr_eqn( + new_body_invars_body_constvars + new_body_invars_carry + + [new_body_invars_token], new_body_carry2 + [new_body_token2], + xla.xla_call_p, + dict( + call_jaxpr=transformed_body_jaxpr.jaxpr, + name="body", + donated_invars=(False,) * + (len(new_body_invars_body_constvars) + + len(new_body_invars_carry) + 1 + len(new_body_carry2) + 1)), + eqn.source_info), + core.new_jaxpr_eqn( + new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2], + [new_body_pred2, new_body_token3], xla.xla_call_p, + dict( + call_jaxpr=transformed_cond_jaxpr.jaxpr, + name="cond_body", + donated_invars=(False,) * (len(new_body_invars_cond_constvars) + + len(new_body_carry2) + 1 + 2)), + eqn.source_info) ] new_body_jaxpr = _mk_typed_jaxpr( - core.Jaxpr([], - (new_body_invars_cond_constvars + new_body_invars_body_constvars + - [new_body_invars_pred] + new_body_invars_carry +[new_body_invars_token]), - ([new_body_pred2] + new_body_carry2 + [new_body_token3]), - new_body_eqns), []) + core.Jaxpr([], (new_body_invars_cond_constvars + + new_body_invars_body_constvars + [new_body_invars_pred] + + new_body_invars_carry + [new_body_invars_token]), + ([new_body_pred2] + new_body_carry2 + [new_body_token3]), + new_body_eqns), []) pred_out = mk_new_var(cond_jaxpr.out_avals[0]) - eqns.append(core.new_jaxpr_eqn( - (eqn.invars[0:cond_nconsts+body_nconsts] + [pred1_and_token1[0]] + - carry_invars + [pred1_and_token1[1]]), - ([pred_out] + eqn.outvars + [output_token_var]), - lax.while_p, - dict(cond_jaxpr=new_cond_jaxpr, cond_nconsts=0, - body_jaxpr=new_body_jaxpr, body_nconsts=cond_nconsts + body_nconsts), - eqn.source_info) - ) + eqns.append( + core.new_jaxpr_eqn( + (eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] + + carry_invars + [pred1_and_token1[1]]), + ([pred_out] + eqn.outvars + [output_token_var]), lax.while_p, + dict( + cond_jaxpr=new_cond_jaxpr, + cond_nconsts=0, + body_jaxpr=new_body_jaxpr, + body_nconsts=cond_nconsts + body_nconsts), eqn.source_info)) xla.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False) -# The data on the outfeed follows a protocol that allows multiplexing the -# outfeed among multiple consumers, and communicates in-stream shape and -# type of the data. -# Each batch of array data is preceeded by a header message, of type -# uint32[_OUTFEED_HEADER_LENGTH]: -# [0]: special header value (271828) -# [1, 2]: a consumer id (64-bits, big-endian encoding as uint32[2]). The -# consumer id encodes the tap function (by hash), the -# descriptor of the arrays to be outfed, and the kwargs (a sorted tuple -# of keys and values). -# [3]: the metadata length in bytes. The metadata is a msgpack-encoded value of type: -# [ (type_code, (d0, d1, ...)), ...] # for each array, element type code -# # and the dimensions. -# padded with 0s to _OUTFEED_HEADER_LENGTH -# -# -_OUTFEED_HEADER_LENGTH = 32 # In uint32 words -_OUTFEED_HEADER_START = 271828 # [0] -# consumer_id [1, 2] -# metadata_length in bytes [3] -_OUTFEED_HEADER_METADATA_LENGTH = 4 * (_OUTFEED_HEADER_LENGTH - 4) - -_CODE_TO_DTYPE = { - 0: np.dtype(np.int8), - 1: np.dtype(np.int16), - 2: np.dtype(np.int32), - 3: np.dtype(np.int64), - 4: np.dtype(np.uint8), - 5: np.dtype(np.uint16), - 6: np.dtype(np.uint32), - 7: np.dtype(np.uint64), - 8: np.dtype(np.bool_), - 9: np.dtype(np.float16), - 10: np.dtype(np.float32), - 11: np.dtype(np.float64), - 12: np.dtype(dtypes.bfloat16), -} -_DTYPE_STR_TO_CODE = dict([(str(d), c) for c, d in _CODE_TO_DTYPE.items()]) - - -def _emit_outfeed(comp: XlaComputationBuilder, token: XlaOp, - arrays: Sequence[XlaOp], consumer_id: int) -> XlaOp: - """Emits the arrays to the outfeed for the current device.""" - arrays_shape = [comp.get_shape(a) for a in arrays] - def _array_shape_to_tuple(a_shape: XlaShape): - # (element_type_code, (d0, d1, ..., dn)) - return (_DTYPE_STR_TO_CODE[str(np.dtype(a_shape.element_type()))], - a_shape.dimensions()) - metadata = msgpack.dumps(tuple(map(_array_shape_to_tuple, arrays_shape))) - metadata_len = len(metadata) - if metadata_len > _OUTFEED_HEADER_METADATA_LENGTH: - # TODO(necula): configurable metadata length - raise ValueError("Outfeed metadata too long") - metadata += b" " * (((metadata_len + 3) // 4) * 4 - metadata_len) # pad - header = ((_OUTFEED_HEADER_START, - (consumer_id >> 32) & 0xffffffff, (consumer_id & 0xffffffff), - metadata_len) + - tuple([int.from_bytes(metadata[i:i+4], byteorder="big") - for i in range(0, _OUTFEED_HEADER_METADATA_LENGTH, 4)])) - header += (0,) * (_OUTFEED_HEADER_LENGTH - len(header)) - data = xops.ConstantLiteral(comp, np.array(header, dtype=np.uint32)) - token = xops.OutfeedWithToken(data, token, comp.get_shape(data)) - - # Now send the arrays, all at once - entire_shape = xla_client.Shape.tuple_shape(arrays_shape) - token = xops.OutfeedWithToken(xops.Tuple(comp, arrays), token, entire_shape) - return token - -def _receive_outfeed(device: XlaDevice, receiver_name: str - ) -> Tuple[int, List]: - """Receives a set of arrays on the outfeed for the specificied device. - Args: - receiver_name: a name used for debugging and logging - Returns: a tuple with the consumer_id, the arrays received, and - a kwargs dictionary that was passed to _emit_outfeed. - """ - header_shape = xla_client.Shape.array_shape(np.dtype(np.uint32), - (_OUTFEED_HEADER_LENGTH,)) - - def _get_data(data_shape: XlaShape, device: XlaDevice) -> XlaShape: - if _OUTFEED_MECHANISM == "device": - return device.transfer_from_outfeed(data_shape) - else: - return xla_client.transfer_from_outfeed(data_shape, device) - - header = _get_data(header_shape, device) - if header[0] != _OUTFEED_HEADER_START: - raise ValueError(f"Read unexpected outfeed header {header[0]} [{receiver_name}]") - logging.info(f"[{receiver_name}:{device}] Outfeed read header: {header}") - consumer_id = (header[1] << 32) + header[2] - metadata_length = header[3] - assert metadata_length <= _OUTFEED_HEADER_METADATA_LENGTH - metadatas = [int(header[i]).to_bytes(4, byteorder="big") - for i in range(4, 4 + (metadata_length + 3) // 4)] - metadata = b"".join(metadatas)[:metadata_length] - array_descriptors = msgpack.unpackb(metadata) - arrays_shape = [xla_client.Shape.array_shape(_CODE_TO_DTYPE[a_descr[0]], - a_descr[1]) - for a_descr in array_descriptors] - entire_shape = xla_client.Shape.tuple_shape(arrays_shape) - arrays = _get_data(entire_shape, device) - logging.info(f"[{receiver_name}:{device}] Outfeed read data of shape " - ",".join([f"{data.dtype}{data.shape}" for data in arrays])) - return (consumer_id, arrays) - class TapFunctionException(Exception): """Signals that some tap function had exceptions. @@ -817,98 +760,181 @@ class TapFunctionException(Exception): """ pass -_outfeed_receiver_started = False -@contextmanager -def outfeed_receiver(*, - timeout_sec=10, - backends: Optional[Sequence[str]] = None, - devices: Optional[Sequence[XlaDevice]] = None, - receiver_name=""): - # TODO: better timeout management. - """Starts receivers for the :func:`id_tap` outfeed from several devices. - The receivers will run in a threadpool. The tapped functions will be invoked - in those threads. If a tap function raises an exception, an error is - printed, but the receiving continues until the body of the context manager - terminates and all outfeeds from all devices have been received. Only then - will a :exc:`TapFunctionException` be raised. - - Args: - backends: (optional) sequence of backend names for which to listen. - Will listen to all devices on those backends. By default, listed to - all devices on all known backends. - devices: (optional) sequence of devices to listed to. At most one - of `backends` or `devices` must be given. - receiver_name: (optional) a name to use with debug logging - Usage:: - - with outfeed_receiver(): - jax.jit(func)(args) - ... - jax.pmap(another_func)(args) - - The ``outfeed_receiver`` must be started outside any jitted computation. +@contextlib.contextmanager +def outfeed_receiver(): + """Implements a barrier after a block of code. + DEPRECATED: + This function is not necessary anymore, it is here for backwards compatiblity. + At the moment it implements a ``barrier_wait`` after the body of the + context manager finishes. """ - if not devices: - backends = backends or xla_client._get_local_backends().keys() - devices = tuple(itertools.chain( - *[api.local_devices(api.host_id(backend), backend) - for backend in backends if backend != "interpreter"])) - else: - if backends: - raise ValueError("At most one of `devices` or `backends` must be given.") - executor = futures.ThreadPoolExecutor( - thread_name_prefix=f"outfeed_receiver_{receiver_name}", - max_workers=len(devices)) - - count_tap_exceptions = 0 - def device_receiver_loop(device: XlaDevice) -> XlaDevice: - """Polls the outfeed for a device in a loop.""" - nonlocal count_tap_exceptions - while (True): - consumer_id, arrays = _receive_outfeed(device, receiver_name) - if _LOGGING: - logging.info(f"[{receiver_name}:{device}] Outfeed received for consumer {consumer_id} " + - (" ".join([f"({a.dtype}{a.shape})" for a in arrays]))) - if consumer_id == _end_consumer: - assert not arrays - if _LOGGING: - logging.info(f"[{receiver_name}:{device}] Outfeed received END_OUTFEED") - return device - consumer = _consumer_registry_by_id.get(consumer_id) - if consumer is None: - logging.error("Ignoring received outfeed for unknown tap consumer") - count_tap_exceptions += 1 - continue # We need to read the entire outfeed - try: - arg = api.tree_unflatten(consumer.arg_treedef, arrays) - consumer.func(arg, **consumer.unpack_kwargs()) # type: ignore[attribute-error] - except Exception as e: - logging.error(f"Postponing exception raised in tap function: {str(e)}\n{traceback.format_exc()}") - count_tap_exceptions += 1 - # We continue for now, we need to keep reading the outfeed - - receiver_futures = [executor.submit(device_receiver_loop, d) for d in devices] - # Register a callback to raise errors if any. These exception come from - # bugs in our code, not from the tap functions. - for rf in receiver_futures: - rf.add_done_callback(lambda rf: rf.result()) - global _outfeed_receiver_started - if _outfeed_receiver_started: - raise ValueError("At most one outfeed_receiver can be running at once.") - _outfeed_receiver_started = True - xla.can_execute_outfeed_computations = True + warnings.warn( + "outfeed_receiver is unnecessary and deprecated. In the latest " + "version the outfeer receiver mechanism is started automatically. Use " + "barrier_wait if instead you want to wait for outfeeds after " + "a computation", DeprecationWarning) + _initialize_outfeed_receiver() + # We will deprecate the outfeed_receiver context manager, but for now + # we just turn it into a barrier. try: yield finally: - for d in devices: # Signal the end of printing - api.jit(lambda x: id_tap(_end_consumer, None, result=x), device=d)(0) # type: ignore[arg-type] - xla.can_execute_outfeed_computations = False - _outfeed_receiver_started = False - for f in futures.as_completed(receiver_futures, timeout=timeout_sec): - finished_device = f.result() # Throw exceptions here - if _LOGGING: - logging.info(f"[{receiver_name}:{finished_device} Outfeed receiver finished") - if count_tap_exceptions > 0: - raise TapFunctionException + # We put a barrier, which will also raise the TapFunctionException + barrier_wait() + + +# For now we keep a single outfeed receiver +class _OutfeedReceiverData: + """Keep track of the outfeed receiver data.""" + receiver: Any + lock: threading.Lock + num_tap_exceptions: int + clients: Tuple[XlaLocalClient, ...] + devices: Tuple[XlaDevice, ...] + consumer_registry: Dict[_ConsumerCallable, int] + consumer_registry_by_id: Dict[int, _ConsumerCallable] + + def __init__(self): + self.receiver = None # Initialize lazily, when first needed + self.lock = threading.Lock() + self.num_tap_exceptions = 0 + self.clients = () + self.devices = () + # The consumer registries must be live for the lifetime of the program, + # because we may have cached compilations that embed consumer ids, and we + # do not want the id reused for other shapes. + self.consumer_registry = dict() + self.consumer_registry_by_id = dict() + + def stop(self): + """Wait for all pending outfeeds and stop the receiver.""" + self.receiver = None # GC will trigger the destructor + self.clients = () + self.devices = () + # Do not clear the consumer registries. + + +_outfeed_receiver = _OutfeedReceiverData() + + +# This function is called from C++; it must not allow exceptions through. +def _outfeed_receiver_callback(device, consumer_id, arrays): + #logging.vlog( + # 2, f"Outfeed received on device {device} for consumer {consumer_id} " + + # (" ".join([f"({a.dtype}{a.shape})" for a in arrays]))) + consumer = _outfeed_receiver.consumer_registry_by_id.get(consumer_id) + assert consumer is not None, "We should have crashed in the runtime" + try: + arg = api.tree_unflatten(consumer.arg_treedef, arrays) + consumer.func(arg, + **consumer.unpack_kwargs()) # type: ignore[attribute-error] + except Exception as e: + logging.error("Postponing exception raised in tap function: %s\n%s", str(e), + traceback.format_exc()) + _outfeed_receiver.num_tap_exceptions += 1 + return + + +def _initialize_outfeed_receiver( + clients: Optional[List[XlaLocalClient]] = None, + max_callback_queue_size_bytes: int = int(256 * 1e6)): + """Creates and starts the outfeed_receiver. + + This function is called lazily only when we compile an id_tap. + + Args: + * clients: the list of clients (backends) on whose devices to listen on. + * max_callback_queue_size_bytes: an optional integer to bound the maximum + size of arrays in the callback queue. When this limit is reached the + device listener pauses. + """ + try: + outfeed_receiver_module = xla_extension.outfeed_receiver + except AttributeError: + raise NotImplementedError( + "id_tap works only with jaxlib version 0.1.51 and higher") + + with _outfeed_receiver.lock: + if _outfeed_receiver.receiver is not None: + return + + if clients is None: + # By default, all devices on all backends + clients = xla_client._get_local_backends().values() # type: ignore[protected-class] + # Drop the interpreter clients + clients = tuple([c for c in clients if c.platform != "interpreter"]) # type: ignore + devices = list(itertools.chain(*[backend.devices() for backend in clients])) + _outfeed_receiver.clients = clients # type: ignore[assignment] + _outfeed_receiver.devices = devices # type: ignore[assignment] + logging.vlog( + 2, f"Starting outfeed_receiver for {[str(d) for d in devices]}. " + f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}") + _outfeed_receiver.receiver = outfeed_receiver_module.start( + _outfeed_receiver_callback, tuple(clients), + max_callback_queue_size_bytes) + + def exit_handler(): + logging.vlog(2, "Barrier wait atexit") + barrier_wait() + + atexit.register(exit_handler) # We wait as long as we have callbacks + + +def barrier_wait(): + """Blocks the calling thread until all current outfeed is processed. + + Waits until all outfeed from computations already running on all devices + has been received and processed by the Python callbacks. Raises + TapFunctionException if there were exceptions while processing the callbacks. + + This works by enqueueing a special tap computation to all devices to which + we are listening for outfeed. Once all those tap computations are done, we + return from barrier_wait. + + Note: If any of the devices are busy and cannot accept new computations, + this will deadlock. + """ + logging.vlog(2, "barrier_wait: start") + if not _outfeed_receiver.receiver: + logging.vlog(2, "barrier_wait: receiver not started") + return + + lock = threading.Lock() + cv = threading.Condition(lock=lock) + num_at_large = len(_outfeed_receiver.devices) # Protected by lock + + def barrier_tap(dev_idx): + nonlocal num_at_large + logging.vlog( + 2, f"barrier_wait: thread {threading.current_thread()} for " + f"device {_outfeed_receiver.devices[dev_idx]} at barrier_tap") + with lock: + num_at_large -= 1 + cv.notify() + + for d_idx, d in enumerate(_outfeed_receiver.devices): + logging.vlog(2, f"barrier_wait: enqueueing barrier on device {d}") + x_on_dev = api.device_put(d_idx, device=d) + api.jit(lambda x: id_tap(barrier_tap, x), device=d)(x_on_dev) + logging.vlog(2, "barrier_wait: waiting for calblacks") + with lock: + cv.wait_for(lambda: num_at_large == 0) + logging.vlog(2, "Done barrier_wait") + if _outfeed_receiver.num_tap_exceptions > 0: + _outfeed_receiver.num_tap_exceptions = 0 + raise TapFunctionException( + "There were exceptions during id_tap processing.") + +def stop_outfeed_receiver(): + """Stops the outfeed receiver runtime. + + This waits for all outfeeds from computations already running on all devices, + and then stops the outfeed receiver runtime. The runtime will be restarted + next time you use a tap function. + + It should not be necessary to use this function, unless you want to start + using lax.outfeed directly after having used host callbacks. + """ + _outfeed_receiver.stop() diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index e89714b82..b7f430f22 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -709,7 +709,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size, jaxpr, out_pvals, consts = pe.trace_to_jaxpr( dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) jaxpr.invars = jaxpr.invars[1:] # ignore dummy - jaxpr, uses_outfeed = xla.apply_outfeed_rewriter(jaxpr) + jaxpr = xla.apply_outfeed_rewriter(jaxpr) out_pvs, out_consts = unzip2(out_pvals) @@ -862,7 +862,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size, num_partitions, out_parts, out_pvals, compiled.local_devices(), backend) - return partial(execute_replicated, compiled, uses_outfeed, backend, handle_args, + return partial(execute_replicated, compiled, backend, handle_args, handle_outs) multi_host_supported_collectives: Set[core.Primitive] = set() @@ -1105,9 +1105,7 @@ def partitioned_sharding_spec(num_partitions: int, replication_factors=[]) -def execute_replicated(compiled, - uses_outfeed, backend, in_handler, out_handler, *args): - xla.check_before_outfeed_execution(uses_outfeed) +def execute_replicated(compiled, backend, in_handler, out_handler, *args): input_bufs = in_handler(args) out_bufs = compiled.execute_on_local_devices(list(input_bufs)) return out_handler(out_bufs) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 93098da28..e84d13cc1 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -174,12 +174,12 @@ pytype_aval_mappings.update( # We can optionally set a Jaxpr rewriter that can be applied just before # compilation. This mechanism is used for compiling id_tap, we can # remove it once we bring the id_tap implementation into the core. -outfeed_rewriter: Optional[Callable[[core.Jaxpr], Tuple[core.Jaxpr, bool]]] = None -def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, bool]: +outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None +def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr: if outfeed_rewriter is not None: return outfeed_rewriter(jaxpr) else: - return jaxpr, False + return jaxpr outfeed_primitives: Set[core.Primitive] = set() def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool: @@ -207,13 +207,6 @@ def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool: return True return False -# TODO(necula): remove this when we start the outfeed receiver automatically. -can_execute_outfeed_computations: bool = False # Set by outfeed_receiver -def check_before_outfeed_execution(uses_outfeed: bool): - if uses_outfeed and not can_execute_outfeed_computations: - raise ValueError("Attempting to execute compiled code using outfeed, " - "but outfeed_receiver is not started.") - ### op-by-op execution def arg_spec(x): @@ -607,7 +600,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar jaxpr, pvals, consts = pe.trace_to_jaxpr( fun, pvals, instantiate=False, stage_out=True, bottom=True) map(prefetch, it.chain(consts, jaxpr_literals(jaxpr))) - jaxpr, uses_outfeed = apply_outfeed_rewriter(jaxpr) + jaxpr = apply_outfeed_rewriter(jaxpr) nreps = jaxpr_replicas(jaxpr) device = _xla_callable_device(nreps, backend, device, arg_devices) @@ -667,9 +660,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar options.parameter_is_tupled_arguments = tuple_args compiled = backend.compile(built, compile_options=options) if nreps == 1: - return partial(_execute_compiled, compiled, uses_outfeed, result_handlers) + return partial(_execute_compiled, compiled, result_handlers) else: - return partial(_execute_replicated, compiled, uses_outfeed, result_handlers) + return partial(_execute_replicated, compiled, result_handlers) def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args): """Configures input/output "must" aliasing based on `donated_args`.""" @@ -767,18 +760,14 @@ def _pval_to_result_handler(device, pval): else: return aval_to_result_handler(device, pv) -def _execute_compiled(compiled: XlaExecutable, uses_outfeed: bool, - handlers, *args): - check_before_outfeed_execution(uses_outfeed) +def _execute_compiled(compiled: XlaExecutable, handlers, *args): device, = compiled.local_devices() input_bufs = [device_put(x, device) for x in args if x is not token] out_bufs = compiled.execute(input_bufs) if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs) return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)] -def _execute_replicated(compiled: XlaExecutable, uses_outfeed: bool, - handlers, *args): - check_before_outfeed_execution(uses_outfeed) +def _execute_replicated(compiled: XlaExecutable, handlers, *args): input_bufs = [ [device_put(x, device) for x in args if x is not token] for device in compiled.local_devices()] diff --git a/pytest.ini b/pytest.ini index 2376a4ffb..8153e9f56 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,9 +4,14 @@ filterwarnings = ignore:No GPU/TPU found, falling back to CPU.:UserWarning ignore:Explicitly requested dtype.*is not available.*:UserWarning ignore:jax.experimental.vectorize is deprecated.*:FutureWarning + ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning # The rest are for experimental/jax_to_tf ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning ignore:can't resolve package from __spec__ or __package__:ImportWarning ignore:Using or importing the ABCs.*:DeprecationWarning doctest_optionflags = NUMBER NORMALIZE_WHITESPACE -addopts = --doctest-glob="*.rst" +addopts = --doctest-glob="*.rst" --dist=loadfile +# --dist=loadfile ensure that all the tests in one file are sent to the same runner. This is useful +# for host_callback_test which start and then stop on teardown the C++ outfeed receiver +# runtime. If we do not stop the receiver, other tests that use outfeed are going to fail. + diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 30417c3f4..c6f08fb51 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -16,11 +16,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from functools import partial +import functools import logging -import numpy as np import os import re +import threading +import time from typing import Callable, Sequence from unittest import SkipTest @@ -34,27 +35,30 @@ from jax import test_util as jtu from jax.config import config from jax.experimental import host_callback as hcb from jax.lib import xla_bridge - +import numpy as np config.parse_flags_with_absl() FLAGS = config.FLAGS + def skip_if_jit_not_enabled(): if os.getenv("JAX_ENABLE_JIT_PRINT", "false") == "false": raise SkipTest("print jit not enabled yet; use JAX_ENABLE_JIT_PRINT env.") + def supported_dtypes(): return sorted(jtu.supported_dtypes(), key=lambda x: np.dtype(x).name) + class _TestingOutputStream(object): """Use as `output_stream` for tests.""" def __init__(self): self._output = [] - self.testMethodName = None + self.test_method_name = None def write(self, what: str) -> None: - print(f"output_stream[{self.testMethodName}]: {what}", end="") + print(f"output_stream[{self.test_method_name}]: {what}", end="") self._output.append(what) @property @@ -80,14 +84,16 @@ def fun1(a): def fun1_equiv(a): # Numerical equivalent of fun` return (a * 2.)**2 -def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, expected: str, what: str): + +def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, + expected: str, what: str): """A variant that preprocesses the string to eliminate non-determinism in - floating point values, and several uninteresting id_tap primitive params.""" + floating point values, and several uninteresting id_tap primitive params. + """ # Sometimes we get floating points in the output; we round them def repl_floats(match_group): matched = match_group.group(0) if matched == ".": return matched - # TODO: why can't we use here np.around? x = np.around(float(matched), decimals=2) return f"{x:.2f}" what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what) @@ -98,23 +104,29 @@ def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, expected: str, what: str) def repl_func(match_group): matched = match_group.group(0) if "function _print_consumer" in matched: - return "func=_print" + return "tap_func_=_print" else: return "..." - what = re.sub(r"func=(.*)", repl_func, what) + what = re.sub(r"tap_func_=(.*)", repl_func, what) tst.assertMultiLineStrippedEqual(expected, what) + class HostCallbackTest(jtu.JaxTestCase): def setUp(self): testing_stream.reset() - testing_stream.testMethodName = self._testMethodName + testing_stream.test_method_name = self._testMethodName self.old_flags = os.getenv("XLA_FLAGS", "") def tearDown(self) -> None: if os.getenv("XLA_FLAGS") != self.old_flags: os.environ["XLA_FLAGS"] = self.old_flags xla_bridge.get_backend.cache_clear() + hcb.barrier_wait() + + @classmethod + def tearDownClass(cls): + hcb.stop_outfeed_receiver() def helper_set_devices(self, nr_devices): flags_str = os.getenv("XLA_FLAGS", "") @@ -135,8 +147,8 @@ class HostCallbackTest(jtu.JaxTestCase): # TODO: renable jaxpr golden tests when changing host_callback #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(fun1)(5.))) - with hcb.outfeed_receiver(): - self.assertAllClose((5. * 2.) ** 2, fun1(5.)) + self.assertAllClose((5. * 2.) ** 2, fun1(5.)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ what: a * 2 10.00 @@ -150,8 +162,9 @@ what: y * 3 return x1 + y1 #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(func2)(3.))) - with hcb.outfeed_receiver(): - self.assertEqual(3. * (2. + 3.), func2(3.)) + self.assertEqual(3. * (2. + 3.), func2(3.)) + hcb.barrier_wait() + assertMultiLineStrippedEqual(self, """ [ 6.00 9.00 ]""", testing_stream.output) @@ -162,8 +175,8 @@ what: y * 3 res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream) return res["a"] + res["b"] - with hcb.outfeed_receiver(): - self.assertEqual(3. * (2. + 3.), func2(3.)) + self.assertEqual(3. * (2. + 3.), func2(3.)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ { a=6.00 b=9.00 }""", testing_stream.output) @@ -175,8 +188,8 @@ what: y * 3 output_stream=testing_stream) return x1 - with hcb.outfeed_receiver(): - self.assertEqual(3. * 4., func2(3.)) + self.assertEqual(3. * 4., func2(3.)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ [ 6.00 9.00 ]""", testing_stream.output) @@ -194,8 +207,8 @@ what: y * 3 return x3 with self.assertRaises(hcb.TapFunctionException): - with hcb.outfeed_receiver(): - _ = func(0) + func(0) + hcb.barrier_wait() # We should have received everything before the error assertMultiLineStrippedEqual(self, """ @@ -208,11 +221,8 @@ what: x3 def test_jit_simple(self): jit_fun1 = api.jit(lambda x: 3. * hcb.id_print( 2. * x, what="here", output_stream=testing_stream)) - - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - res = jit_fun1(5.) - - self.assertAllClose(6. * 5., res) + self.assertAllClose(6. * 5., jit_fun1(5.)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ what: here 10.00""", testing_stream.output) @@ -224,8 +234,8 @@ what: here #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(api.jit(func))(5))) - with hcb.outfeed_receiver(): - self.assertAllClose(5, api.jit(func)(5)) + self.assertAllClose(5, api.jit(func)(5)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ 42""", testing_stream.output) testing_stream.reset() @@ -239,9 +249,9 @@ what: here api.make_jaxpr(func)(1)) logging.info("%s: %s", self._testMethodName, api.xla_computation(func)(1).as_hlo_text()) + self.assertEqual(2, api.jit(func)(1)) + hcb.barrier_wait() - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertEqual(2, api.jit(func)(1)) assertMultiLineStrippedEqual(self, """ where: 1 1 @@ -256,10 +266,9 @@ where: 2 x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) return x2 - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertEqual(2, api.jit(func)(1)) - self.assertEqual(11, api.jit(func)(10)) - + self.assertEqual(2, api.jit(func)(1)) + self.assertEqual(11, api.jit(func)(10)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ where: 1 1 @@ -280,8 +289,8 @@ where: 2 x3 = api.jit(func_nested)(x1) return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream) - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertEqual(3, api.jit(func)(1)) + self.assertEqual(3, api.jit(func)(1)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ where: 1 1 @@ -300,9 +309,9 @@ where: 3 x2 = hcb.id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream) return x2 - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - for d in devices: - self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id)) + for d in devices: + self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id)) + hcb.barrier_wait() logging.info(f"{self._testMethodName}: found output {testing_stream.output}") self.assertEqual(len(devices), len(re.findall(r"111", testing_stream.output))) self.assertEqual(len(devices), len(re.findall(r"112", testing_stream.output))) @@ -332,13 +341,12 @@ where: 3 self.assertEqual(func(5, what), a) transform = api.jit if with_jit else lambda f: f - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - for what in ("pair_1_x", "pair_x_2x", "dict"): - self.assertEqual(func(10, what), - transform(lambda x: hcb.id_tap(tap_func, func(x, what), - result=func(x * 2, what), - what=what))(5)) - # Wait for receivers to be done + for what in ("pair_1_x", "pair_x_2x", "dict"): + self.assertEqual(func(10, what), + transform(lambda x: hcb.id_tap(tap_func, func(x, what), + result=func(x * 2, what), + what=what))(5)) + hcb.barrier_wait() # Wait for receivers to be done self.assertEqual(3, tap_count) @parameterized.named_parameters( @@ -354,15 +362,17 @@ where: 3 x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) x4 = lax.cond(x % 2 == 0, - lambda x: hcb.id_print(x, where="cond_t", output_stream=testing_stream), - lambda x: hcb.id_print(-1, where="cond_f", result=x, output_stream=testing_stream), + lambda x: hcb.id_print(x, where="cond_t", + output_stream=testing_stream), + lambda x: hcb.id_print(-1, where="cond_f", result=x, + output_stream=testing_stream), x2 + 1) x5 = hcb.id_print(x4 + 1, where="end", output_stream=testing_stream) return x5 transform = api.jit if with_jit else lambda f: f - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertEqual(4, transform(func)(1)) + self.assertEqual(4, transform(func)(1)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ where: 1 1 @@ -376,9 +386,8 @@ where: end @parameterized.named_parameters( jtu.cases_from_list( - dict( - testcase_name=f"_with_jit_{with_jit}", - with_jit=with_jit) + dict(testcase_name=f"_with_jit_{with_jit}", + with_jit=with_jit) for with_jit in [True, False])) def test_while_cond(self, with_jit=False): def func(x): @@ -398,8 +407,8 @@ where: end return res transform = api.jit if with_jit else lambda f: f - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertEqual(4, transform(func)(1)) + self.assertEqual(4, transform(func)(1)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ where: 1 1 @@ -434,8 +443,8 @@ where: end res = hcb.id_print(x10, where="3", output_stream=testing_stream) return res - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertEqual(3, api.jit(func)(1)) + self.assertEqual(3, api.jit(func)(1)) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ where: w_p @@ -475,11 +484,11 @@ where: 3 res = hcb.id_print(x10, where="10", output_stream=testing_stream) return res - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - if with_jit: - func = api.jit(func) - res = func(1) - self.assertAllClose(jnp.array([1, 2, 3]), res) + if with_jit: + func = api.jit(func) + res = func(1) + self.assertAllClose(jnp.array([1, 2, 3]), res) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ where: 1 1 @@ -529,25 +538,23 @@ where: 10 xs, a_new_test="************", testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}")) - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - _ = jit_fun1(args) - # self.assertAllClose(args, res) + + res = jit_fun1(args) + self.assertAllClose(args, res) def test_jit_large(self): arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1)) - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - api.jit(hcb.id_print)(arg) + api.jit(hcb.id_print)(arg) def test_jit_several_together(self): arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5)) - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(arg, jnp.ones(100, dtype=jnp.int32)) + api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(arg, jnp.ones(100, dtype=jnp.int32)) def test_jit_interleaving(self): # Several jit's without data dependencies; they may interfere count = 0 # Count tap invocations nr_arrays = 5 - def tap_func(arg, **kwargs): + def tap_func(arg, **_): nonlocal count assert len(arg) == nr_arrays count += 1 @@ -556,12 +563,13 @@ where: 10 for i in range(count): x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)], i=i)[-1] return x - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - x = jnp.array(1, dtype=np.int32) - res = 0 - for i in range(10): - # No dependencies between the jit invocations - res += api.jit(lambda x: func(x, 10))(x) + + x = jnp.array(1, dtype=np.int32) + res = 0 + for _ in range(10): + # No dependencies between the jit invocations + res += api.jit(lambda x: func(x, 10))(x) + hcb.barrier_wait() self.assertEqual(100, count) def test_jit_tap_exception(self): @@ -574,9 +582,10 @@ where: 10 x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) return x3 + res = api.jit(func)(0) # No error yet with self.assertRaises(hcb.TapFunctionException): - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - res = api.jit(func)(0) + hcb.barrier_wait() + # Even though the receiver thread raised, the main thread should still # return 3. self.assertEqual(3, res) @@ -588,49 +597,20 @@ what: x3 3""", testing_stream.output) testing_stream.reset() - def test_jit_unknown_tap(self): - # Simulate an unknown tap function - def func(x): - x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream) - x2 = hcb.id_tap(hcb._unknown_testing_consumer, x1 + 1, what="err") - x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) - return x3 - - with self.assertRaises(hcb.TapFunctionException): - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - res = api.jit(func)(0) - # Even though the receiver thread raised, the main thread should still - # return 3. - self.assertEqual(3, res) - # We should have received all others - assertMultiLineStrippedEqual(self, """ -what: x1 -1 -what: x3 -3""", testing_stream.output) - testing_stream.reset() - - # On CPU and GPU the device code blocks - # On GPU it seems that there is a 5 min timeout? - # On TPU the client does not block, but messes up the rest somehow - @jtu.skip_on_devices("cpu", "gpu", "tpu") - def test_jit_receiver_ends_prematurely(self): - # Simulate an unknown tap function - def func(x): - x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream) - x2 = hcb.id_tap(hcb._end_consumer, result=x1 + 1) # Will end the consumer loop - x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) - return x3 - - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - _ = api.jit(func)(0) - - assert False # It seems that the previous jit blocks above - - def test_jit_error_no_consumer(self): - # Check for errors if starting jit without a consumer active - with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"): - api.jit(lambda x: hcb.id_print(x))(0) + def test_jit_nested_cond_no_print(self): + """A nested conditional, without any prints""" + raise SkipTest("skip this") + @api.jit + def cfun(x): + return lax.cond( + lax.lt(x, 2), + lambda x: x, + lambda x: lax.cond(x < 5, + 3, lambda x: x, + 4, lambda y: y), + x) + print(self._testMethodName, api.xla_computation(cfun)(1).as_hlo_text()) + cfun(1) def test_while(self): """Executing while, even without JIT uses compiled code""" @@ -641,8 +621,8 @@ what: x3 lambda c: c[1] < 5, lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1), (x, 1)) - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - func(y) + func(y) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ 1 2 @@ -650,27 +630,13 @@ what: x3 4""", testing_stream.output) testing_stream.reset() - def test_while_error_no_receiver(self): - """Executing while needs the receiver""" - y = jnp.ones(5) # captured const - def func(x): - return lax.while_loop( - lambda c: c[1] < 5, - lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1), - (x, 1)) - - with self.assertRaisesRegex(ValueError, ".*outfeed_receiver.*not started"): - func(y).block_until_ready() - - def test_jvp(self): jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,)) - #assertMultiLineStrippedEqual(self, "", - # str(api.make_jaxpr(jvp_fun1)(jnp.float32(5.), jnp.float32(0.1)))) - with hcb.outfeed_receiver(): - res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1)) + #assertMultiLineStrippedEqual(self, "") + res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1)) self.assertAllClose(100., res_primals, check_dtypes=False) self.assertAllClose(4., res_tangents, check_dtypes=False) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ what: a * 2 10.00 @@ -685,23 +651,24 @@ transforms: ({'name': 'jvp'},) what: y * 3 def test_grad_primal_unused(self): # The output of id_print is not needed for backwards pass def func(x): - return 2. * hcb.id_print(x * 3., what="x * 3", output_stream=testing_stream) + return 2. * hcb.id_print(x * 3., what="x * 3", + output_stream=testing_stream) grad_func = api.grad(func) - with hcb.outfeed_receiver(): - assertMultiLineStrippedEqual(self, """ + jaxpr = str(api.make_jaxpr(grad_func)(5.)) + # Just making the Jaxpr invokes the id_print once + hcb.barrier_wait() + assertMultiLineStrippedEqual(self, """ { lambda ; a. let - in (6.00,) }""", str(api.make_jaxpr(grad_func)(5.))) - - # Just making the Jaxpr invokes the id_print once + in (6.00,) }""", jaxpr) assertMultiLineStrippedEqual(self, """ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3 2.00""", testing_stream.output) testing_stream.reset() - with hcb.outfeed_receiver(): - res_grad = grad_func(jnp.float32(5.)) + res_grad = grad_func(jnp.float32(5.)) + hcb.barrier_wait() self.assertAllClose(6., res_grad, check_dtypes=False) assertMultiLineStrippedEqual(self, """ @@ -714,13 +681,14 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3 def test_grad_simple(self): def func(x): y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) - return x * hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream) + return x * hcb.id_print(y * 3., what="y * 3", + output_stream=testing_stream) grad_func = api.grad(func) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.))) - with hcb.outfeed_receiver(): - res_grad = grad_func(jnp.float32(5.)) + res_grad = grad_func(jnp.float32(5.)) self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ what: x * 2 10.00 @@ -738,18 +706,19 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 return x * (y * 3.) grad_func = api.grad(api.grad(func)) - with hcb.outfeed_receiver(): - _ = api.make_jaxpr(grad_func)(5.) - # Just making the Jaxpr invokes the id_print twiceonce - assertMultiLineStrippedEqual(self, """ + # Just making the Jaxpr invokes the id_print twice + _ = api.make_jaxpr(grad_func)(5.) + hcb.barrier_wait() + assertMultiLineStrippedEqual(self, """ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 3.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 2.00""", testing_stream.output) - testing_stream.reset() - res_grad = grad_func(jnp.float32(5.)) + testing_stream.reset() + res_grad = grad_func(jnp.float32(5.)) self.assertAllClose(12., res_grad, check_dtypes=False) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ what: x * 2 10.00 @@ -765,8 +734,8 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 vmap_fun1 = api.vmap(fun1) vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_fun1)(vargs))) - with hcb.outfeed_receiver(): - _ = vmap_fun1(vargs) + vmap_fun1(vargs) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2 [ 8.00 10.00] @@ -784,8 +753,8 @@ transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3 vmap_func = api.vmap(func) vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_func)(vargs))) - with hcb.outfeed_receiver(): - _ = vmap_func(vargs) + _ = vmap_func(vargs) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ transforms: ({'name': 'batch', 'batch_dims': (None, 0)},) [ 3.00 @@ -804,8 +773,8 @@ transforms: ({'name': 'batch', 'batch_dims': (None, 0)},) xv = jnp.arange(5, dtype=np.int32) yv = jnp.arange(3, dtype=np.int32) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(sum_all)(xv, yv))) - with hcb.outfeed_receiver(): - _ = sum_all(xv, yv) + _ = sum_all(xv, yv) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)}) [[0 1 2 3 4] @@ -827,9 +796,9 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dim return res inputs = np.arange(5, dtype=np.int32) - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs), - check_dtypes=False) + self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs), + check_dtypes=False) + hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1 [0 1 2 3 4] @@ -856,9 +825,9 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3 return res inputs = np.arange(5, dtype=np.int32) - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs), - check_dtypes=False) + res = api.jit(api.vmap(func))(inputs) + hcb.barrier_wait() + self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False) assertMultiLineStrippedEqual(self, """ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1 [0 1 2 3 4] @@ -880,21 +849,15 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3 vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32) pmap_fun1 = api.pmap(fun1, axis_name="i") - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - res = pmap_fun1(vargs) + res = pmap_fun1(vargs) + hcb.barrier_wait() expected_res = jnp.stack([fun1_equiv(2. + a) for a in range(api.local_device_count())]) self.assertAllClose(expected_res, res, check_dtypes=False) - def test_pmap_error_no_receiver(self): - # Check for errors if starting jit without a consumer active - vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32) - with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"): - api.pmap(lambda x: hcb.id_print(x))(vargs) - def test_mask(self): # TODO(necula) raise SkipTest("masking has regressed") - @partial(api.mask, in_shapes=['n'], out_shape='') + @functools.partial(api.mask, in_shapes=['n'], out_shape='') def padded_sum(x): return jnp.sum(hcb.id_print(x, what="x", output_stream=testing_stream)) args = [jnp.arange(4)], dict(n=np.int64(2)) @@ -916,21 +879,133 @@ logical_shapes: [(2,)] transforms: ('mask',) what: x """, testing_stream.output) testing_stream.reset() + def test_outfeed_receiver(self): + """Test the deprecated outfeed_receiver""" + with hcb.outfeed_receiver(): + self.assertAllClose((5. * 2.) ** 2, fun1(5.), check_dtypes=True) + assertMultiLineStrippedEqual(self, """ +what: a * 2 +10.00 +what: y * 3 +30.00""", testing_stream.output) + testing_stream.reset() + + + def test_callback_delay(self): + hcb.callback_extra = lambda dev: time.sleep(1) + + def func(x): + for i in range(5): + x = hcb.id_print(x * i, what="x times i") + return x + + api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) + + def test_callback_delay_barrier(self): + hcb.callback_extra = lambda dev: time.sleep(2) + + def func(x): + for i in range(1, 4): + x = hcb.id_print(x * i, what="x times i", output_stream=testing_stream) + return x + + api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) + # Wait for the results + hcb.barrier_wait() + expected = """ +what: x times i +[[0. 1. 2.] + [3. 4. 5.]] +what: x times i +[[ 0. 2. 4.] + [ 6. 8. 10.]] +what: x times i +[[ 0. 6. 12.] + [18. 24. 30.]]""" + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + testing_stream.reset() + # Call again + api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) + hcb.barrier_wait() + self.assertMultiLineStrippedEqual(expected, testing_stream.output) + + + def test_multiple_barriers(self): + """Call barrier_wait concurrently.""" + + def pause_tap(*args, **kwargs): + logging.info("pause_tap waiting") + time.sleep(2) + logging.info("pause_tap done") + + def long_run(x): + return hcb.id_tap(pause_tap, x) + + api.jit(long_run)(5.) + + def try_barrier(idx): + logging.info(f"Starting test barrier {idx}") + hcb.barrier_wait() + logging.info(f"Finished test barrier {idx}") + + threads = [ + threading.Thread( + name=f"barrier_{idx}", target=try_barrier, args=(idx,)) + for idx in range(3) + ] + [t.start() for t in threads] + [t.join() for t in threads] + + def test_error_bad_consumer_id(self): + """Try to use reserved consumer ID 0. + + Check that we get the proper error from the runtime.""" + comp = xla_bridge.make_computation_builder(self._testMethodName) + token = hcb.xops.CreateToken(comp) + with self.assertRaisesRegex(RuntimeError, + "Consumer ID cannot be a reserved value: 0"): + hcb._outfeed_receiver.receiver.add_outfeed( + comp, token, 0, + [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))]) + + def test_error_different_shapes(self): + """Try to register different shapes for the same consumer ID.""" + comp = xla_bridge.make_computation_builder(self._testMethodName) + token = hcb.xops.CreateToken(comp) + hcb._outfeed_receiver.receiver.add_outfeed( + comp, token, 123, + [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))]) + with self.assertRaisesRegex( + RuntimeError, ".*does not match previous shape element_type.*"): + hcb._outfeed_receiver.receiver.add_outfeed( + comp, token, 123, + [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.int32))]) + with self.assertRaisesRegex( + RuntimeError, ".*does not match previous shape element_type.*"): + hcb._outfeed_receiver.receiver.add_outfeed( + comp, token, 123, + [xla_bridge.constant(comp, np.zeros((2,), dtype=np.float32))]) + + class OutfeedRewriterTest(jtu.JaxTestCase): + def assertRewrite(self, expected: str, func: Callable, args: Sequence, has_input_token=True, has_output_token=True): """Check that the rewrite of func(*args) matches expected.""" - _ = api.make_jaxpr(func)(*args) + jaxpr = api.make_jaxpr(func)(*args) # TODO: re-enable when we change the host_callback rewriter - #assertMultiLineStrippedEqual(self, expected, - # str(hcb._rewrite_typed_jaxpr(jaxpr, has_input_token, has_output_token)[0])) + #rewritten = hcb._rewrite_typed_jaxpr(jaxpr, + # has_input_token, has_output_token) + #assertMultiLineStrippedEqual(self, expected, str(rewritten)) + del jaxpr def test_no_outfeed(self): self.assertRewrite(""" { lambda ; a. let b = mul a a c = add a b - in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, has_output_token=False) + in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, + has_output_token=False) self.assertRewrite(""" { lambda ; a d. let b = mul a a @@ -946,9 +1021,11 @@ class OutfeedRewriterTest(jtu.JaxTestCase): self.assertRewrite(""" { lambda ; a d. let b = add a a - c e = id_tap[ arg_treedef=* - func=_print - ] b d + c e = id_tap[ arg_treedef_=* + has_token_=True + nr_tapped_args_=1 + tap_func_=_print + ] b d in (c, e) }""", lambda x: hcb.id_print(x + x), [0]) def test_cond(self): @@ -962,8 +1039,10 @@ class OutfeedRewriterTest(jtu.JaxTestCase): d = convert_element_type[ new_dtype=int32 old_dtype=bool ] c g h j = cond[ branches=( { lambda ; f_ e a b c g. - let d h = id_tap[ arg_treedef=* - func=_print + let d h = id_tap[ arg_treedef_=* + has_token_=True + nr_tapped_args_=1 + tap_func_=_print ] c g in (d, e, h) } { lambda ; d g_ a b c h. @@ -980,15 +1059,13 @@ class OutfeedRewriterTest(jtu.JaxTestCase): return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond), lambda c: (ct_body, hcb.id_print(c[1]) + 1.), (x, np.float32(1.))) - # TODO: we should not need to start a receiver here!!! I believe this is - # because of the partial evaluation of while, which calls impl, which - # uses JIT. - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertRewrite(""" + self.assertRewrite(""" { lambda b c ; a f. let d e g = while[ body_jaxpr={ lambda ; c a b f. - let d g = id_tap[ arg_treedef=* - func=_print + let d g = id_tap[ arg_treedef_=* + has_token_=True + nr_tapped_args_=1 + tap_func_=_print ] b f e = add d 1.00 in (c, e, g) } @@ -1011,43 +1088,47 @@ class OutfeedRewriterTest(jtu.JaxTestCase): lambda c: (ct_body, hcb.id_print(c[1]) + 1), (x, 1)) - # TODO: we should not need to start a receiver here!!! I believe this is - # because of the partial evaluation of while, which calls impl, which - # uses JIT. - with hcb.outfeed_receiver(receiver_name=self._testMethodName): - self.assertRewrite(""" -{ lambda b c ; a f. - let h i = xla_call[ call_jaxpr={ lambda ; c a b g. - let d e h = id_tap[ arg_treedef=* - func=_print - nr_untapped=1 - ] c b g - f = lt e 5 - in (f, h) } - name=cond_before ] b a 1 f - y d e g = while[ body_jaxpr={ lambda ; n o p q r s. - let t u v = xla_call[ call_jaxpr={ lambda ; c a b f. - let d g = id_tap[ arg_treedef=* - func=_print - ] b f - e = add d 1 - in (c, e, g) } - name=body ] o q r s - w x = xla_call[ call_jaxpr={ lambda ; c a b g. - let d e h = id_tap[ arg_treedef=* - func=_print - nr_untapped=1 - ] c b g - f = lt e 5 - in (f, h) } - name=cond_body ] n t u v - in (w, t, u, x) } - body_nconsts=2 - cond_jaxpr={ lambda ; j k l m. - let - in (j,) } - cond_nconsts=0 ] b c h a 1 i - in (d, 5, g) }""", func, [ct_body]) + self.assertRewrite(""" +{ lambda b c ; a e. + let g h = xla_call[ call_jaxpr={ lambda ; c a b f. + let _ d g = id_tap[ arg_treedef_=* + has_token_=True + nr_tapped_args_=1 + tap_func_=_print + ] c b f + e = lt d 5 + in (e, g) } + donated_invars=(False, False, False, False) + name=cond_before ] b a 1 e + x d _ f = + while[ body_jaxpr={ lambda ; m n o p q r. + let s t u = xla_call[ call_jaxpr={ lambda ; c a b f. + let d g = id_tap[ arg_treedef_=* + has_token_=True + nr_tapped_args_=1 + tap_func_=_print + ] b f + e = add d 1 + in (c, e, g) } + donated_invars=(False, False, False, False, False, False, False) + name=body ] n p q r + v w = xla_call[ call_jaxpr={ lambda ; c a b f. + let _ d g = id_tap[ arg_treedef_=* + has_token_=True + nr_tapped_args_=1 + tap_func_=_print + ] c b f + e = lt d 5 + in (e, g) } + donated_invars=(False, False, False, False, False, False) + name=cond_body ] m s t u + in (v, s, t, w) } + body_nconsts=2 + cond_jaxpr={ lambda ; i j k l. + let + in (i,) } + cond_nconsts=0 ] b c g a 1 h + in (d, 5, f) }""", func, [ct_body]) def test_scan(self): y = jnp.ones(5) # captured const @@ -1055,16 +1136,19 @@ class OutfeedRewriterTest(jtu.JaxTestCase): return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x) self.assertRewrite(""" { lambda b ; a f. - let c d g e = scan[ jaxpr={ lambda ; f a b g c. - let d e h = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*]) - func=_print - ] a b g - in (d, e, h, f) } - length=5 - linear=(False, False, False, False, False) - num_carry=3 - num_consts=1 - reverse=False ] b 1 2 f a + let c d g e = + scan[ jaxpr={ lambda ; f a b g c. + let d e h = id_tap[ arg_treedef_=PyTreeDef(tuple, [*,*]) + has_token_=True + nr_tapped_args_=2 + tap_func_=_print + ] a b g + in (d, e, h, f) } + length=5 + linear=(False, False, False, False, False) + num_carry=3 + num_consts=1 + reverse=False ] b 1 2 f a in (c, d, e, g) }""", func, [y]) diff --git a/tests/infeed_test.py b/tests/infeed_test.py index f7b8093ca..81a41e391 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -19,6 +19,7 @@ from absl.testing import absltest import jax from jax import lax, numpy as np from jax.config import config +from jax.experimental import host_callback as hcb from jax.lib import xla_client import jax.test_util as jtu import numpy as onp @@ -47,6 +48,7 @@ class InfeedTest(jtu.JaxTestCase): self.assertAllClose(f(x), x + y + z) def testInfeedThenOutfeed(self): + hcb.stop_outfeed_receiver() @jax.jit def f(x): token = lax.create_token(x) @@ -67,6 +69,7 @@ class InfeedTest(jtu.JaxTestCase): self.assertAllClose(out, y + onp.float32(1)) def testInfeedThenOutfeedInALoop(self): + hcb.stop_outfeed_receiver() def doubler(_, token): y, token = lax.infeed( token, shape=jax.ShapedArray((3, 4), np.float32))