* Combine the concepts of "platform" and "backend". The main upshot is that the tpu_driver backend now requires users to write `jit(..., backend="tpu_driver")` when mixing CPU and TPU execution (see the sketch after this list); however, I doubt users are writing that, because mixing CPU and tpu_driver did not work before.
* Initialize all platforms at startup, rather than lazily initializing platforms on demand. This makes it easy to do things like "list the available platforms".
* Don't use two levels of caching. Cache backends only in xla_bridge.py, not xla_client.py.
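A hedged sketch of the `backend` keyword and device listing mentioned above; the function and array contents are illustrative only:
```python
import jax
import jax.numpy as jnp

def double(x):
    return x * 2

# Compiled for the default backend chosen at startup (e.g. TPU if present).
double_default = jax.jit(double)

# Explicitly pinned to the CPU backend when mixing CPU and accelerator work.
double_cpu = jax.jit(double, backend="cpu")

print(jax.devices())                  # enumerate available devices
print(double_default(jnp.arange(4)))
print(double_cpu(jnp.arange(4)))
```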
PiperOrigin-RevId: 376883261
To make sure that the CPU feature guard runs first, before any other code that may use instructions that do not exist on the host CPU, use a separate C extension module.
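A minimal sketch of the intended import order, assuming the guard ships as `jaxlib.cpu_feature_guard` with a `check_cpu_features()` entry point:
```python
# The guard module is compiled without optional ISA extensions, so importing
# it cannot itself trip over missing instructions.
from jaxlib import cpu_feature_guard

# Fail with a clear error if the CPU lacks features jaxlib was built for
# (e.g. AVX), rather than crashing with an illegal-instruction fault later.
cpu_feature_guard.check_cpu_features()

# Only now load modules that may use those instructions.
from jaxlib import xla_client  # noqa: E402
```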
Fixes https://github.com/google/jax/issues/6671
PiperOrigin-RevId: 374683190
libdevice.10.bc is a redistributable part of the CUDA SDK.
This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
Create separate holder objects for global and thread-local state, and move enable_x64 and disable_jit context into the holder objects.
Expose the global and per-thread state objects to Python via pybind11.
Refactoring only; no functional changes intended.
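A rough Python sketch of the split; the class and attribute names are illustrative, not the actual pybind11-exposed objects:
```python
import threading
from dataclasses import dataclass
from typing import Optional

@dataclass
class GlobalState:
    # Process-wide defaults, shared by all threads.
    enable_x64: bool = False
    disable_jit: bool = False

class ThreadLocalState(threading.local):
    # Per-thread overrides; None means "defer to the global value".
    enable_x64: Optional[bool] = None
    disable_jit: Optional[bool] = None

global_state = GlobalState()
thread_local_state = ThreadLocalState()

def jit_is_disabled() -> bool:
    # Thread-local overrides take precedence over the global default.
    override = thread_local_state.disable_jit
    return global_state.disable_jit if override is None else override
```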
PiperOrigin-RevId: 363510449
Bump jaxlib version to 0.1.61 and update changelog.
Change the jaxlib NumPy version requirement to >=1.16 for the next release. NumPy releases older than 1.16 are deprecated per NEP 29. Re-enable NumPy 1.20.
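A hedged sketch of how such a version floor is usually expressed in packaging metadata; the surrounding setup.py fields are illustrative only:
```python
from setuptools import setup

setup(
    name="jaxlib",
    # Only the version floor matters here; other metadata is omitted.
    install_requires=["numpy>=1.16"],
)
```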
Bump minimum jaxlib version to 0.1.60.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant).
For the benchmark in #2952 on my workstation:
Before:
```
907.3490574884647
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 37.088446617126465
jax fft execution time [ms]: 74.93342399597168
```
After:
```
907.3490574884647
max: 1.9057386696477137e-12
mean: 3.9326737908882566e-13
min: 0.0
numpy fft execution time [ms]: 37.756404876708984
jax fft execution time [ms]: 28.128278255462646
```
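A hedged sketch of the kind of accuracy comparison against NumPy that produces numbers like the max/mean/min above; the array size and dtype are illustrative, and the original benchmark's details may differ:
```python
import numpy as np
import jax.numpy as jnp

x = np.random.randn(1 << 16).astype(np.float32)
reference = np.fft.fft(x)               # NumPy's PocketFFT (C variant)
result = np.asarray(jnp.fft.fft(x))     # JAX FFT

err = np.abs(result - reference)
print("max:", err.max(), "mean:", err.mean(), "min:", err.min())
```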
Fixes https://github.com/google/jax/issues/2952
PiperOrigin-RevId: 338743753
* Add options to compute L/R eigenvectors in geev.
The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.
Reformulate eig-related operations based on the new geev API (see the sketch after this list).
* Addressed hawkinsp's comments from google/jax#3882.
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
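A hedged illustration of the user-facing eig path this change feeds into; the matrix is arbitrary and the check is only a sanity test (on CPU, where `jnp.linalg.eig` is backed by LAPACK geev):
```python
import jax.numpy as jnp

a = jnp.array([[0.0, 1.0],
               [-2.0, -3.0]])
# Eigenvalues and right eigenvectors.
w, v = jnp.linalg.eig(a)
# Column j of v satisfies a @ v[:, j] ≈ w[j] * v[:, j].
print(jnp.allclose(a @ v, v * w))
```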
This starts a C++ jit codepath to speed up dispatch time.
Tracing is not supported yet.
Supported features:
- scalar, numpy array, and DeviceArray argument support:
  - integer, float, boolean, and complex scalar arguments are supported.
  - The jax_enable_x64 flag will be used at object-creation time to cast scalars and numpy arrays.
  - The JAX `weak_type` attribute for arguments is supported (DeviceArray and scalars).
- The donate_argnums argument.
- Use an XLA tuple for more than 100 arguments
Unsupported features:
- jax._cpp_jit on methods, e.g.:
    @functools.partial(jax.jit, static_argnums=0)
    def _compute_log_data(self, ...):
        ...
  This is currently not supported by the C++ codepath, because `self` won't be automatically added.
- disable_jit.
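A hedged example of the call shape the new dispatch path targets: a free function taking numpy-array, DeviceArray, and Python-scalar positional arguments (the function itself is made up):
```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def scale_and_shift(x, scale, shift):
    return x * scale + shift

# Supported argument kinds: numpy arrays, DeviceArrays, and Python scalars.
scale_and_shift(np.arange(4.0), 2.0, 1)
scale_and_shift(jnp.ones(4), 3.0, True)
```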
With this change, a value `x` can be replicated `nrep` times as follows:
```python
pmap(lambda _: x)(np.arange(nrep))
```
This will broadcast `x` into a ShardedDeviceArray suitable for passing into another pmap with the same input shape.
If `x` will be passed into a pmap with `devices` or a nested pmap, the replication pmap(s) should follow that structure. For example:
```python
x = pmap(pmap(lambda _: x))(np.ones((2, 4)))
pmap(pmap(lambda i: i**2 + x))(np.ones((2, 4)))
```
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice, XLA inlines aggressively, which causes GPU compilation times to blow up when a program compiles potentially many copies of the PRNG kernel. As a workaround, we add a hand-written CUDA kernel, mostly to reduce compilation time.
When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.