61 Commits

Author SHA1 Message Date
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
f33ce0d844 Warn if importing jaxlib on Mac ARM machines.
We can remove this warning when Mac ARM has CI testing.
2021-07-13 09:24:48 -04:00
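The import-time warning described above could be sketched as follows (hypothetical Python; the actual jaxlib check may differ, and the `system`/`machine` parameters exist only to make the sketch testable):

```python
import platform
import warnings

def maybe_warn_mac_arm(system=None, machine=None):
    """Warn when importing on an Apple Silicon Mac (sketch, not jaxlib's code)."""
    system = system or platform.system()
    machine = machine or platform.machine()
    if system == "Darwin" and machine == "arm64":
        warnings.warn("jaxlib on Mac ARM machines is not yet covered by CI "
                      "testing; you may encounter problems.")
        return True
    return False
```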
Peter Hawkins
d658108d36 Fix type errors with current mypy and NumPy.
Enable type stubs for jaxlib.

Fix a nondeterminism problem in jax2tf tests.
2021-06-24 10:51:06 -04:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid changing external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
Peter Hawkins
7db0c56a22 [JAX] Change how JAX manages XLA platforms.
* Combine the concepts of "platform" and "backend". The main upshot of this is that the tpu_driver backend requires users to write `jit(..., backend="tpu_driver")` if mixing CPU and TPU execution; however, I doubt users are writing that, because mixing CPU and tpu_driver did not work before.
* Initialize all platforms at startup, rather than lazily initializing platforms on demand. This makes it easy to do things like "list the available platforms".
* Don't use two levels of caching. Cache backends only in xla_bridge.py, not xla_client.py.

PiperOrigin-RevId: 376883261
2021-06-01 11:44:31 -07:00
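The eager initialization and single-level caching described above can be sketched like this (hypothetical names, not the actual xla_bridge API):

```python
# One registry, populated eagerly, cached in exactly one place.
_backends = {}

def register_backend_factory(platform_name, factory):
    # Eager initialization at registration time makes
    # "list the available platforms" trivial.
    _backends[platform_name] = factory()

def backends():
    return sorted(_backends)

def get_backend(platform_name):
    try:
        return _backends[platform_name]
    except KeyError:
        raise RuntimeError(
            f"Unknown backend {platform_name!r}; available: {backends()}")
```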
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist on the user's CPU, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
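The idea of the guard can be illustrated in Python (the real guard is a C extension precisely so it runs before any guarded instructions execute; the names and required feature here are hypothetical):

```python
REQUIRED_FEATURES = ("avx",)  # hypothetical requirement

def check_cpu_features(available, required=REQUIRED_FEATURES):
    """Fail fast with a clear error instead of crashing with an illegal
    instruction later."""
    missing = [f for f in required if f not in available]
    if missing:
        raise RuntimeError(
            "This CPU lacks required features: " + ", ".join(missing))
```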
Peter Hawkins
c983d3c660 Bundle libdevice.10.bc with jaxlib wheels.
libdevice.10.bc is a redistributable part of the CUDA SDK.

This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
2021-04-29 10:26:03 -04:00
Jake VanderPlas
0d4bcde7ca Add experimental/sparse_ops & cusparse wrappers in jaxlib
PiperOrigin-RevId: 368663407
2021-04-15 10:11:10 -07:00
Andreas Hoenselaar
a19098d462 Reimplement as JAX Primitive 2021-04-03 14:11:36 -07:00
Peter Hawkins
99daf13652 Fix flake8 error. 2021-03-24 15:28:01 -04:00
Peter Hawkins
90e28a1b8e Remove deprecated compatibility code for jaxlib < 0.1.64. 2021-03-24 14:32:23 -04:00
Jake VanderPlas
f9a4162551 Specify minimum jaxlib version in a single location 2021-03-22 16:14:41 -07:00
Peter Hawkins
23756a040b [JAX] Refactor handling of JIT interpreter state in jax_jit API.
Create separate holder objects for global and thread-local state, and move enable_x64 and disable_jit context into the holder objects.

Expose the global and per-thread state objects to Python via pybind11.

Refactoring only; no functional changes intended.

PiperOrigin-RevId: 363510449
2021-03-17 14:39:34 -07:00
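The holder-object scheme above can be modeled in Python (a sketch of the concept; the real holders are C++ objects exposed via pybind11):

```python
import threading

class GlobalJitState:
    """Process-wide defaults."""
    def __init__(self):
        self.disable_jit = False
        self.enable_x64 = False

class ThreadLocalJitState(threading.local):
    """Per-thread overrides; None means "defer to the global value"."""
    def __init__(self):
        self.disable_jit = None
        self.enable_x64 = None

global_state = GlobalJitState()
thread_local_state = ThreadLocalJitState()

def jit_is_disabled():
    override = thread_local_state.disable_jit
    return global_state.disable_jit if override is None else override
```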
Jake VanderPlas
d6408a4e6a Add extras_require to setup.py 2021-03-16 13:23:46 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Peter Hawkins
ee53eeb541 Add helpful message when importing jaxlib fails. 2021-03-15 21:07:59 -04:00
Jean-Baptiste Lespiau
18343817c8 Use the C++ object for the Sharding specification. 2021-02-12 16:02:58 +01:00
Jake VanderPlas
5e7be4a61f Cleanup: remove obsolete jaxlib version checks 2021-02-04 15:13:39 -08:00
Peter Hawkins
13f3819054 Update README.md for jaxlib 0.1.60.
Bump jaxlib version to 0.1.61 and update changelog.

Change jaxlib numpy version limit to >=1.16 for next release. Releases older than 1.16 are deprecated per NEP 29. Reenable NumPy 1.20.

Bump minimum jaxlib version to 0.1.60.
2021-02-03 20:44:01 -05:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
Skye Wanderman-Milne
7c2454e969 Update jaxlib version, minimum jaxlib version, readme, and changelog.
Bumping the min jaxlib version to support https://github.com/google/jax/pull/5213.
2021-01-15 12:56:08 -08:00
Peter Hawkins
f1df104fe9 Remove some old jaxlib compatibility code. 2021-01-14 14:21:01 -05:00
jax authors
9c2aed32ef Add # pytype: disable=import-error to a couple of import statements to allow
--cpu=ppc builds (the imported modules aren't currently linked into jaxlib when building for ppc).

PiperOrigin-RevId: 351648541
2021-01-13 13:03:17 -08:00
Jean-Baptiste Lespiau
b39dbe1846 Add a version number for the XLA code targeting jaxlib.
There is likely a better long-term solution, but this is an easy incremental
improvement for now.

PiperOrigin-RevId: 350206440
2021-01-05 13:24:21 -08:00
jax authors
a8518769a2 Merge pull request #5115 from inailuig:rocm-gpukernels
PiperOrigin-RevId: 348077827
2020-12-17 14:01:04 -08:00
Clemens Giuliani
4981c53ac1 Add BLAS and LAPACK gpu kernels for ROCm 2020-12-16 16:00:17 +01:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant).

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
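The "execution time [ms]" figures come from the benchmark in #2952; a minimal timing harness in the same spirit (not the benchmark itself) could be:

```python
import time
import numpy as np

def time_fft_ms(x, n_iter=10):
    """Average wall-clock time of np.fft.fft over n_iter runs, in milliseconds."""
    start = time.perf_counter()
    for _ in range(n_iter):
        np.fft.fft(x)
    return (time.perf_counter() - start) / n_iter * 1e3
```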
Benjamin Chetioui
58a117fe0d
Modifies eig_p and related operations to take advantage of the new jaxlib geev API (#4266)
* Add options to compute L/R eigenvectors in geev.

The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.

Reformulate eig-related operations based on the new geev API.

* Addressed hawkinsp's comments from google/jax#3882.

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-09-15 11:45:15 +03:00
Jean-Baptiste Lespiau
9ca1020472
Add a fast C++ jit codepath. (#4089)
This starts a C++ jit codepath to speed up dispatch time.
Tracing is not supported yet.

Supported features:
- scalar, numpy array and DeviceArray argument support:
  - integer, floats, boolean, and complex scalars arguments are supported.
  - The jax_enable_x64 flag will be used at object-creation time to cast scalars and numpy arrays.
  - The Jax `weak_type` attribute for arguments is supported (DeviceArray and scalars).
- The donate_argnums argument.
- Use an XLA tuple for more than 100 arguments

Unsupported features:
- jax._cpp_jit on methods, e.g.
    @functools.partial(jax.jit, static_argnums=0)
    def _compute_log_data(self, ...)
      ...
  This is currently not supported by the C++ codepath, because "self" won't be automatically added.
- disable_jit.
2020-08-19 12:39:25 -04:00
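The speedup comes from skipping Python-level dispatch work on cache hits. A toy Python model of the signature-keyed cache (illustration only; the real codepath is C++ and keys on shape, dtype, and weak_type as listed above):

```python
import functools

def toy_fast_jit(fn):
    """Cache a "compiled" callable per argument signature (toy model)."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args):
        # A real signature would include shape/dtype/weak_type; for scalars
        # the Python type suffices for this sketch.
        key = tuple(type(a) for a in args)
        if key not in cache:
            cache[key] = fn  # stand-in for "trace and compile once"
        return cache[key](*args)

    wrapper.cache = cache
    return wrapper
```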
Jean-Baptiste Lespiau
2ab6b42a45
Use pytree defined in tensorflow. (#4087)
It also adds some tests on the scalar C++ conversion.
2020-08-18 08:58:43 +03:00
George Necula
c7aff1da06
Revert "Use pytree from xla_client. (#4063)" (#4081)
This reverts commit d8de6b61411179dcd2f63d7639bbcd69b30ac15f.

Trying to revert because it seems that this produces test
failures in Google.
2020-08-17 12:53:18 +03:00
Jake Vanderplas
8923bab50d
fixes for pytype (#4068) 2020-08-14 12:53:02 -07:00
Jean-Baptiste Lespiau
d8de6b6141
Use pytree from xla_client. (#4063) 2020-08-14 11:44:03 -04:00
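A pytree, as used in these commits, flattens nested containers into a list of leaves plus a structure descriptor, and can unflatten back. A minimal pure-Python sketch for lists, tuples, and dicts (the real implementation is the C++ pytree in xla_client; this ignores edge cases such as None leaves):

```python
def tree_flatten(tree):
    """Return (leaves, structure) for nested lists/tuples/dicts."""
    if isinstance(tree, (list, tuple)):
        leaves, structs = [], []
        for child in tree:
            ls, s = tree_flatten(child)
            leaves.extend(ls)
            structs.append(s)
        return leaves, (type(tree), structs)
    if isinstance(tree, dict):
        keys = sorted(tree)
        leaves, structs = [], []
        for k in keys:
            ls, s = tree_flatten(tree[k])
            leaves.extend(ls)
            structs.append(s)
        return leaves, (dict, keys, structs)
    return [tree], None  # a leaf

def tree_unflatten(structure, leaves):
    leaves = iter(leaves)

    def build(s):
        if s is None:
            return next(leaves)
        if s[0] is dict:
            _, keys, structs = s
            return {k: build(cs) for k, cs in zip(keys, structs)}
        typ, structs = s
        return typ(build(cs) for cs in structs)

    return build(structure)
```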
George Necula
166e795d63
Updated minimum jaxlib to 0.1.51 (#3642) 2020-07-02 18:53:58 +03:00
Peter Hawkins
3290e16a9a
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.

Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
   ...:    z = jax.numpy.cos(x)
   ...:    z = z * jax.numpy.tanh(y)
   ...:    return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00
Jake Vanderplas
a63b9cc256
Cleanup: deflake interpreters, lib, nn, third_party, and tools (#3327) 2020-06-04 15:27:48 -07:00
Skye Wanderman-Milne
0d97c3ba01
Import tpu_driver after xla_client (#3064)
This is a workaround until we build a new jaxlib with f462867806
2020-05-12 11:05:03 -07:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
Matthew Johnson
6239e59415 bump min jaxlib version (thanks @hawkinsp) 2020-04-21 19:01:19 -07:00
Peter Hawkins
bfbd0b800f
Move tuple_arguments onto Compile() instead of Execute(). (#2559)
Update minimum jaxlib version to 0.1.43.
2020-03-31 17:09:14 -04:00
Peter Hawkins
ec9513fa29
Advertise jaxlib 0.1.41. (#2432)
Bump minimum jaxlib version to 0.1.41.
2020-03-16 16:10:26 -04:00
Ram Rachum
f3f0abb53e
Fix exception causes all over the codebase (#2376)
Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-03-09 16:06:12 -04:00
Peter Hawkins
1006758b94
Bump minimum jaxlib version to 0.1.40. (#2360) 2020-03-05 09:25:16 -05:00
Peter Hawkins
991324f8df
Increase minimum jaxlib version to 0.1.38. (#2120) 2020-01-29 14:16:58 -05:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Skye Wanderman-Milne
b459b6098a
Make pmap properly replicate closed-over constants. (#1847)
With this change, a value `x` can be replicated `nrep` times as follows:

```python
pmap(lambda _: x)(np.arange(nrep))
```

This will broadcast `x` into a ShardedDeviceArray suitable for passing into another pmap with the same input shape.

If `x` will be passed into a pmap with `devices` or a nested pmap, the replication pmap(s) should follow that structure. For example:

```python
x = pmap(pmap(lambda _: x))(np.ones((2, 4)))

pmap(pmap(lambda i: i**2 + x))(np.ones((2, 4)))
```
2019-12-17 16:22:55 -08:00
George Necula
227a91220b Update minimum jaxlib to 0.1.36.
This is needed in part to pull in the new Device.platform from TensorFlow.
See #1764.
2019-11-26 08:20:19 +01:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggressively inlines, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
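For reference, the 20-round ThreeFry-2x32 block function itself is small; a pure-Python illustration following the structure JAX uses (this is an explanatory sketch, not the CUDA kernel):

```python
M32 = 0xFFFFFFFF
ROTATIONS = (13, 15, 26, 6, 17, 29, 16, 24)

def _rotl(x, d):
    """Rotate a 32-bit word left by d bits."""
    return ((x << d) | (x >> (32 - d))) & M32

def threefry_2x32(key, counter):
    """20-round ThreeFry-2x32: maps a (key, counter) pair of uint32 pairs
    to a pair of uint32 outputs."""
    ks = [key[0], key[1], key[0] ^ key[1] ^ 0x1BD11BDA]
    x = [(counter[0] + ks[0]) & M32, (counter[1] + ks[1]) & M32]
    for r in range(5):  # 5 groups of 4 rounds = 20 rounds
        rots = ROTATIONS[0:4] if r % 2 == 0 else ROTATIONS[4:8]
        for rot in rots:
            # Threefish mix: add, rotate, xor.
            x[0] = (x[0] + x[1]) & M32
            x[1] = _rotl(x[1], rot) ^ x[0]
        # Key injection after every 4 rounds.
        x[0] = (x[0] + ks[(r + 1) % 3]) & M32
        x[1] = (x[1] + ks[(r + 2) % 3] + r + 1) & M32
    return tuple(x)
```

Each round is invertible, so for a fixed key the block function permutes the counter space: distinct counters always yield distinct outputs.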
android
3f0c1cd9dd Add TPU Driver as JAX backend for high-performance access to Google Cloud TPU hardware. (#1675) 2019-11-14 14:00:08 -08:00
Peter Hawkins
ce5b8670f3
Delete XRT references from jax. (#1588) 2019-10-29 11:26:48 -04:00