The semantics of mentioning `device` or `backend` on `pjit` is the same as doing a `device_put` i.e. no matter which device the arg is on, reshard it to the device mentioned.
PiperOrigin-RevId: 495437165
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).
I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.
PiperOrigin-RevId: 404405186
We often forget to put the per-test-case decorators, resulting in test
failures in cases not covered by github CI (e.g. Cloud TPU
tests). This change filters the "experimental feature" warnings by
default.
* Make infeed_test and host_callback_test independent.
* the infeed_test will stop the outfeed receiver
* Remove the use of --dist=loadfile.
* Prevent logging on exit
* Refactored host_callback to use the C++ runtime.
* The new runtime makes it unnecessary to start the outfeed_receiver
in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
I am trying to solve this by (a) making host_callback_test stop the
outfeed receiver on finish and infeed_test on start, and (b)
telling pytest-xdist to run all the tests from one file into
a single worker.
* Initial import of jax2tf into JAX core
Renamed jax2tf.convert to jax_to_tf.
Added Travis test support.
Added OSS build configuration.
* Added support for squeeze
* Make pytest run over JAX tests warning clean, and error on warnings.
Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.
Also fix crashes on Mac related to a preexisting linear algebra bug.
* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.