This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 fo more information.
* Refactored host_callback to use the C++ runtime.
* The new runtime makes it unnecessary to start the outfeed_receiver
in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
I am trying to solve this by (a) making host_callback_test stop the
outfeed receiver on finish and infeed_test on start, and (b)
telling pytest-xdist to run all the tests from one file into
a single worker.
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
* Remove usage of xla_client.{Computation,ComputationBuilder}.
ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the the C++ XLA API. It dates back from when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.
Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.