Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.
Fix build failure with dangling matplotlib reference.
PiperOrigin-RevId: 465562141
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.
To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```
Issue #7323
PiperOrigin-RevId: 458551208
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
Currently, all XLA side-effect ops inside a sharded computation must have
explicit sharding. This includes the outfeed and infeed used by host_callback.
The implementation here uses AssignDevice sharding for both the outfeed and the
infeed. This means that before outfeed, the devices will do an all_gather and
the first device will make the outfeed. The host callback will receive a single
outfeed with the entire array, and is supposed to return the entire array. This
gets sent to the same device that issued to outfeed, which is responsible to
send the respective slices to the other participating devices.
PiperOrigin-RevId: 370711606
xmap can now handle real devices, so there's no point in maintaining the
simulated codepaths. Also, remove single-dimensional gmap as it will
have to be superseeded by a more xmap-friendly alternative.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)
For the benchmark in #2952 on my workstation:
Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```
After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```
Fixes https://github.com/google/jax/issues/2952
PiperOrigin-RevId: 338743753