357 Commits

Author SHA1 Message Date
Clemens Giuliani
330e69dd30 Add support for ROCm 2020-12-06 22:55:53 +01:00
Peter Hawkins
4a774978a2 Add test for chkstk_darwin symbol to jaxlib Mac builds.
We don't know why some builds produce this and others do not, but we can at least test for it to prevent bad releases.
2020-12-03 14:05:20 -05:00
Peter Hawkins
fca46ade35 Remove duplicate code to copy cusolver_kernels in wheel build.
PiperOrigin-RevId: 344840695
2020-11-30 10:32:13 -08:00
Cloud Han
ea340eed93 use strict action env on windows to avoid constant full rebuilding 2020-11-24 23:44:26 +08:00
Cloud Han
2146cea157 use shutil.copyfile to avoid readonly pyd file 2020-11-24 00:05:42 +08:00
Peter Hawkins
c06ead6b04 Change jaxlib build rules to build a wheel, rather than writing output to the source directory. 2020-11-20 11:47:00 -05:00
Peter Hawkins
1fcbd2f083 Add -Wno-stringop-truncation to build flags on Linux.
Works around https://github.com/tensorflow/tensorflow/issues/39467 for gcc 10+ builds.
2020-11-19 22:50:08 -05:00
Peter Hawkins
4bb5dca779 Fix build.py to work on Linux once again.
* strip DOS end-of-line characters from build.py for consistency with the rest of the source tree.
* use shutil.copy() instead of shutil.copyfile(). On Unix systems we must preserve execute permissions.
* add code to explicitly delete and recreate the target directory.
* Move build/jaxlib/__init_py to jaxlib/__init__.py and have the script move it into position, so the output directory for the jaxlib is an empty directory that the script creates.
2020-11-19 13:14:34 -05:00
Cloud Han
a6acce58e0 Build on Windows
1. Build on Windows

2. Fix OverflowError

    When calling `key = random.PRNGKey(0)` OverflowError: Python int too
    large to convert to C long for casting value 4294967295 (0xFFFFFFFF)
    from python int to int32.

3. fix file path in regex of errors_test

4. handle ValueError of os.path.commonpath
2020-11-19 23:33:06 +08:00
Peter Hawkins
b222e5e05f Update bazel version to 3.1.0. 2020-11-17 15:50:30 -05:00
Skye Wanderman-Milne
5c89170dc1 jaxlib: Fix Python 3.9 build and drop CUDA 10.0 support
The TF/XLA build no longer works with CUDA 10.0. This could potentially be fixed, but isn't easy or officially supported.
2020-11-12 18:31:35 +00:00
Peter Hawkins
2c6f932e0e Add Python 3.9 support to jaxlib build. 2020-11-09 16:25:58 -05:00
Peter Hawkins
dadf732f03 Fix manylinux2010 compliance of GPU wheels.
Use GCC_HOST_COMPILER_PATH to point to the devtoolset compiler.
Use auditwheel to verify manylinux2010 compliance.
2020-11-09 14:35:45 -05:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Skye Wanderman-Milne
01e113d740 Update jaxlib build scripts to build CUDA 11.1 wheels. 2020-10-13 21:31:43 +00:00
Skye Wanderman-Milne
cacb01753a Use local version identifiers to distribute cuda jaxlib wheels.
This change:

* Updates our jaxlib build scripts to add `+cudaXXX` to the wheel
  version, where XXX is the CUDA version number (e.g. `110`). nocuda
  builds remain unchanged and do not have this extra identifier.

* Adds `generate_release_index.py`, which writes an html page that pip
  can use to find the cuda wheels. (I based this format off of
  wheel PyTorch's index).

* Updates the README to use the new local version identifier + wheel
  index.

The end result is that the command to install cuda wheels is now much
simpler.

I manually made copies of the latest jaxlib 0.1.55 wheels that have
the local version identifiers, so the new installation commands
already work (as well as the old ones, until the next jaxlib release
using the new tooling).

Fow now, I put the html index to the GCP bucket with the wheels. We
can move it to a prettier URL if/when we have one.
2020-10-09 13:47:54 -07:00
Srijan Saurav
40e20242db
Fix code quality issues (#4302)
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
2020-09-17 09:21:18 -07:00
Benjamin Chetioui
58a117fe0d
Modifies eig_p and related operations to take advantage of the new jaxlib geev API (#4266)
* Add options to compute L/R eigenvectors in geev.

The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.

Reformulate eig-related operations based on the new geev API.

* Addressed hawkinsp's comments from google/jax#3882.

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-09-15 11:45:15 +03:00
Qiao Zhang
adb344880b
Reorder nocuda/cuda build to fail early. (#4243) 2020-09-09 17:23:43 -07:00
Qiao Zhang
0b04439f11
Update install_cuda script to specify cublas. (#4240) 2020-09-09 17:16:58 -07:00
Qiao Zhang
3cf7336753
Fix Dockerfile wheel installation issues. (#4232) 2020-09-08 21:28:34 -07:00
Skye Wanderman-Milne
cf2d15d4bb
jaxlib build fixes. (#4066)
1. `wheel.pep425tags` has been removed as of
   https://github.com/pypa/setuptools/pull/1829. Use the new
   `packaging.tags` instead.

2. Add `--allow-downgrades` to cuda install command. I'm not sure this
   is always necessary, but I ran into it, I'm guessing due to a cached
   docker image.
2020-09-08 18:23:42 -07:00
Peter Hawkins
8e166adcbd
Unbreak jaxlib build. (#4098) 2020-08-18 21:24:41 -04:00
Jean-Baptiste Lespiau
2ab6b42a45
Use pytree defined in tensorflow. (#4087)
It also adds some tests on the scalar C++ conversion.
2020-08-18 08:58:43 +03:00
George Necula
c7aff1da06
Revert "Use pytree from xla_client. (#4063)" (#4081)
This reverts commit d8de6b61411179dcd2f63d7639bbcd69b30ac15f.

Tryting to revert because it seems that this produces test
failures in Google.
2020-08-17 12:53:18 +03:00
Jean-Baptiste Lespiau
d8de6b6141
Use pytree from xla_client. (#4063) 2020-08-14 11:44:03 -04:00
George Necula
0b99ca896d
[jax2tf] Disable the CI tests for jax2tf. (#4019)
We do this to see if this reduces the incidence of errors fetching
the tf-nightly package. These tests are being run when we import
the code in Google.
2020-08-11 12:39:54 +03:00
Matthew Johnson
64ec4443be another attempt at fixing github ci 2020-07-30 10:17:00 -07:00
Matthew Johnson
420e8422a4 attempt to fix CI by updating jax2tf test dep 2020-07-30 08:08:48 -07:00
Jake Vanderplas
2796032e38
Tweak Dockerfile to prevent build failure and add TODO (#3838) 2020-07-23 13:08:06 -07:00
Peter Hawkins
b943b31b22
Add jax.image.resize. (#3703)
* Add jax.image.resize.

This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.

While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
2020-07-10 09:57:59 -04:00
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test on start, and (b)
  telling pytest-xdist to run all the tests from one file into
  a single worker.
2020-07-04 18:12:58 +03:00
George Necula
166e795d63
Updated minimum jaxlib to 0.1.51 (#3642) 2020-07-02 18:53:58 +03:00
George Necula
448c635a9e
[jax2tf] Update the tf-nightly version to 2020-07-01 (#3635) 2020-07-02 10:21:23 +03:00
Peter Hawkins
8f86b139fe
Update docker script for CUDA 11. (#3571) 2020-06-26 11:20:51 -04:00
Peter Hawkins
a141cc6e8d
Make CUDA wheels manylinux2010 compliant, add CUDA 11, drop CUDA 9.2 (#3555)
* Use dynamic loading to locate CUDA libraries in jaxlib.

This should allow jaxlib CUDA wheels to be manylinux2010 compliant.

* Tag CUDA jaxlib wheels as manylinux2010.

Drop support for CUDA 9.2, add support for CUDA 11.0.

* Reorder CUDA imports.
2020-06-25 14:37:14 -04:00
Peter Hawkins
f036f5ddb0
Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation war… (#3543)
* Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation warnings.

* Pin a newer tf-nightly to fix jax2tf tests for NumPy 1.19.0
2020-06-24 15:19:00 -04:00
Peter Hawkins
3290e16a9a
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.

Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
   ...:    z = jax.numpy.cos(x)
   ...:    z = z * jax.numpy.tanh(y)
   ...:    return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00
Peter Hawkins
58f8329a40
Switch CI builds from Travis to Github actions (#3409) 2020-06-11 17:10:56 -04:00
George Necula
ddf079d8f3
Minor improvements to the script to build macos wheels (#3013) 2020-05-11 20:17:26 +03:00
Peter Hawkins
821193b2b7
Explicitly build specific CUDA capabilities. (#2722)
We choose the same set as TensorFlow (minus 3.7, which TF is apparently considering dropping anyway).

This avoids a slow PTX -> SASS compilation on first time startup.
2020-04-15 10:57:53 -04:00
Peter Hawkins
cbdf9a5a43
Drop support for Python 3.5. (#2445) 2020-03-18 10:54:28 -04:00
Peter Hawkins
219d503e71
Don't show progress bar in build script if output is not a terminal. (#2429) 2020-03-16 11:01:08 -04:00
Peter Hawkins
80abdf0c53
Unbreak build and update XLA. (#2289)
* raise minimum Bazel version to 2.0.0 to match TensorFlow.
* set --experimental_repo_remote_exec since it is required by the TF build.
* bump TF/XLA version.
* use the --config=short_logs trick from TF to suppress build warnings.
2020-02-22 09:45:24 -08:00
Peter Hawkins
b6e8341176
Improve developer documentation. (#2247)
Add Python version test to build.py.
2020-02-17 11:24:03 -08:00
Skye Wanderman-Milne
bf91ebf67a
Return error number in build.py on bad bazel version. (#2218)
This prevents our build scripts from continuing on error.
2020-02-12 09:57:54 -08:00
Skye Wanderman-Milne
96a65de6c8
Try downloading bazel before using pre-installed bazel. (#2217)
This ensures we're using the right bazel version.
2020-02-12 09:49:33 -08:00
Peter Hawkins
fe041c7590 Set minimum Bazel version to 1.2.1. 2020-02-03 10:13:51 -05:00
Ruizhe Zhao
8c7fc3919d
Upgrade bazel from 0.29.1 to 1.2.1 (#2137) 2020-02-03 10:12:40 -05:00
Skye Wanderman-Milne
409d057f76
Build CUDA 10.2 jaxlibs. (#2121)
Also adds install_cuda.sh script that sets appropriate nccl and cuDNN versions.
2020-01-29 10:47:17 -08:00