135 Commits

Author SHA1 Message Date
Peter Hawkins
8e166adcbd
Unbreak jaxlib build. (#4098) 2020-08-18 21:24:41 -04:00
Jean-Baptiste Lespiau
2ab6b42a45
Use pytree defined in tensorflow. (#4087)
It also adds some tests for the scalar C++ conversion.
2020-08-18 08:58:43 +03:00
George Necula
c7aff1da06
Revert "Use pytree from xla_client. (#4063)" (#4081)
This reverts commit d8de6b61411179dcd2f63d7639bbcd69b30ac15f.

Trying to revert because it seems that this produces test
failures in Google.
2020-08-17 12:53:18 +03:00
Jean-Baptiste Lespiau
d8de6b6141
Use pytree from xla_client. (#4063) 2020-08-14 11:44:03 -04:00
George Necula
0b99ca896d
[jax2tf] Disable the CI tests for jax2tf. (#4019)
We do this to see whether it reduces the incidence of errors fetching
the tf-nightly package. These tests still run when we import
the code in Google.
2020-08-11 12:39:54 +03:00
Matthew Johnson
64ec4443be another attempt at fixing github ci 2020-07-30 10:17:00 -07:00
Matthew Johnson
420e8422a4 attempt to fix CI by updating jax2tf test dep 2020-07-30 08:08:48 -07:00
Jake Vanderplas
2796032e38
Tweak Dockerfile to prevent build failure and add TODO (#3838) 2020-07-23 13:08:06 -07:00
Peter Hawkins
b943b31b22
Add jax.image.resize. (#3703)
* Add jax.image.resize.

This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.

While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
2020-07-10 09:57:59 -04:00
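A minimal usage sketch of the API this commit adds (method names follow the jax.image documentation; the exact options available in this initial version are an assumption):
```
import jax.numpy as jnp
from jax import image

img = jnp.zeros((8, 8, 3))                               # toy 8x8 RGB image
out = image.resize(img, shape=(16, 16, 3), method="bilinear")
print(out.shape)                                         # (16, 16, 3)
```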
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test on start, and (b)
  telling pytest-xdist to run all the tests from one file in
  a single worker.
2020-07-04 18:12:58 +03:00
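A minimal sketch of the host_callback usage this refactoring targets; with the new C++ runtime, no outfeed receiver needs to be started explicitly around the call (the `what=` keyword is from the jax.experimental.host_callback docs and is an assumption for this revision):
```
import jax
from jax.experimental import host_callback as hcb

def f(x):
    y = x * 2.0
    y = hcb.id_print(y, what="doubled")  # prints the runtime value on the host
    return y + 1.0

print(jax.jit(f)(2.0))  # host prints the tapped value, then f returns 5.0
```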
George Necula
166e795d63
Updated minimum jaxlib to 0.1.51 (#3642) 2020-07-02 18:53:58 +03:00
George Necula
448c635a9e
[jax2tf] Update the tf-nightly version to 2020-07-01 (#3635) 2020-07-02 10:21:23 +03:00
Peter Hawkins
8f86b139fe
Update docker script for CUDA 11. (#3571) 2020-06-26 11:20:51 -04:00
Peter Hawkins
a141cc6e8d
Make CUDA wheels manylinux2010 compliant, add CUDA 11, drop CUDA 9.2 (#3555)
* Use dynamic loading to locate CUDA libraries in jaxlib.

This should allow jaxlib CUDA wheels to be manylinux2010 compliant.

* Tag CUDA jaxlib wheels as manylinux2010.

Drop support for CUDA 9.2, add support for CUDA 11.0.

* Reorder CUDA imports.
2020-06-25 14:37:14 -04:00
Peter Hawkins
f036f5ddb0
Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation war… (#3543)
* Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation warnings.

* Pin a newer tf-nightly to fix jax2tf tests for NumPy 1.19.0
2020-06-24 15:19:00 -04:00
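The general pattern behind such fixes, as a generic sketch rather than the actual diff: compare dtype to dtype, or use np.issubdtype, instead of comparing a dtype against a Python type:
```
import numpy as np

dt = np.dtype(np.float32)

# NumPy 1.19 warns on comparisons between a dtype and arbitrary non-dtype
# objects; normalizing both sides to np.dtype avoids the deprecation path.
same = dt == np.dtype("float32")

# For "is this floating point?" checks, prefer issubdtype over `dt == float`.
is_floating = np.issubdtype(dt, np.floating)
```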
Peter Hawkins
3290e16a9a
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.

Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
   ...:    z = jax.numpy.cos(x)
   ...:    z = z * jax.numpy.tanh(y)
   ...:    return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00
Peter Hawkins
58f8329a40
Switch CI builds from Travis to Github actions (#3409) 2020-06-11 17:10:56 -04:00
George Necula
ddf079d8f3
Minor improvements to the script to build macos wheels (#3013) 2020-05-11 20:17:26 +03:00
Peter Hawkins
821193b2b7
Explicitly build specific CUDA capabilities. (#2722)
We choose the same set as TensorFlow (minus compute capability 3.7, which TF is apparently considering dropping anyway).

This avoids a slow PTX -> SASS compilation on first time startup.
2020-04-15 10:57:53 -04:00
Peter Hawkins
cbdf9a5a43
Drop support for Python 3.5. (#2445) 2020-03-18 10:54:28 -04:00
Peter Hawkins
219d503e71
Don't show progress bar in build script if output is not a terminal. (#2429) 2020-03-16 11:01:08 -04:00
Peter Hawkins
80abdf0c53
Unbreak build and update XLA. (#2289)
* raise minimum Bazel version to 2.0.0 to match TensorFlow.
* set --experimental_repo_remote_exec since it is required by the TF build.
* bump TF/XLA version.
* use the --config=short_logs trick from TF to suppress build warnings.
2020-02-22 09:45:24 -08:00
Peter Hawkins
b6e8341176
Improve developer documentation. (#2247)
Add Python version test to build.py.
2020-02-17 11:24:03 -08:00
Skye Wanderman-Milne
bf91ebf67a
Return error number in build.py on bad bazel version. (#2218)
This prevents our build scripts from continuing on error.
2020-02-12 09:57:54 -08:00
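A hypothetical sketch of the technique (illustrative names and threshold, not the actual build.py code): parse the reported version and exit nonzero so a calling shell script stops on error:
```
import sys

MIN_BAZEL = (0, 26, 0)  # illustrative minimum, not the value used in build.py

def check_bazel_version(version: str) -> None:
    parts = tuple(int(p) for p in version.split(".")[:3])
    if parts < MIN_BAZEL:
        print(f"bazel {version} is too old", file=sys.stderr)
        sys.exit(1)  # nonzero exit status propagates to the calling script
```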
Skye Wanderman-Milne
96a65de6c8
Try downloading bazel before using pre-installed bazel. (#2217)
This ensures we're using the right bazel version.
2020-02-12 09:49:33 -08:00
Peter Hawkins
fe041c7590 Set minimum Bazel version to 1.2.1. 2020-02-03 10:13:51 -05:00
Ruizhe Zhao
8c7fc3919d
Upgrade bazel from 0.29.1 to 1.2.1 (#2137) 2020-02-03 10:12:40 -05:00
Skye Wanderman-Milne
409d057f76
Build CUDA 10.2 jaxlibs. (#2121)
Also adds an install_cuda.sh script that sets appropriate NCCL and cuDNN versions.
2020-01-29 10:47:17 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
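The change amounts to deleting lines like these, whose behavior is the Python 3 default (a generic sketch of the pattern, not the actual diff):
```
# No longer needed once Python 2 support is gone:
# from __future__ import absolute_import, division, print_function

print(3 / 2)  # true division (1.5) is already the default on Python 3
```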
Peter Hawkins
55f2d3be27 Update Jaxlib docker build.
* work around https://github.com/bazelbuild/bazel/issues/9254 by setting BAZEL_LINKLIBS=-lstdc++
* drop CUDA 9.0 support, since we use a batched kernel only present in CUDA 9.2 or later.
* drop Python 2.7 support.
2020-01-28 11:17:21 -05:00
Skye Wanderman-Milne
f1339cd0b0
Remove missing PPA in Dockerfile. (#2061)
This PPA has been removed by the owner: https://launchpad.net/~jonathonf/+archive/ubuntu/python-3.6
This causes `apt-get update` to fail when generating the Docker image. We don't seem to need this repository, so just remove it before calling `apt-get update`.
2020-01-23 17:24:23 -08:00
Skye Wanderman-Milne
b6dfb8bf18
Bump minimum bazel version to 0.26.0. (#2060)
Fixes #2044
2020-01-23 16:46:45 -08:00
Peter Hawkins
64bf55dc6f
Update XLA. (#1997)
Drop six dependency from jaxlib, since xla_client.py no longer uses six.
2020-01-14 11:05:54 -05:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
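A generic sketch of what removing six looks like (illustrative code, not the actual diff):
```
# Python 2/3 shims such as:
#   import six
#   if isinstance(name, six.string_types): ...
# collapse to plain Python 3 builtins:
def greet(name):
    if isinstance(name, str):  # six.string_types is just str on Python 3
        return "hello, " + name
```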
Peter Hawkins
ea91c96a9d
Specify a minimum Mac OS version in builds to avoid backward compatibility problems. (#1807) 2019-12-03 11:59:31 -05:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggressively inlines, which causes compilation times on GPU to blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel, mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00
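For orientation, this is the PRNG path the kernel accelerates: JAX's functional PRNG derives random bits from a ThreeFry2x32-based key (a usage sketch; the kernel itself is internal to jaxlib):
```
import jax

key = jax.random.PRNGKey(0)             # a ThreeFry2x32 key: two uint32 words
key, subkey = jax.random.split(key)     # splitting re-hashes the key
x = jax.random.normal(subkey, (1024,))  # random bits come from the ThreeFry hash
```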
Peter Hawkins
67038321f8
Revert support for building a non-GPU build with --config=cuda enabled. (#1757)
It turns out there are implicit CUDA dependencies inside the TF libraries used by JAX, so the attempt to disable GPU dependencies conditionally didn't work.
2019-11-24 13:06:10 -05:00
Skye Wanderman-Milne
39e17398ee
jaxlib build improvements (#1742) 2019-11-21 16:25:24 -08:00
Peter Hawkins
c601950b1b
Build Mac wheels for Python 3.8 with scipy 1.3.2. (#1739)
scipy 1.3.1 never had a Python 3.8 wheel.
2019-11-21 13:52:08 -05:00
Peter Hawkins
9b853a4255
Update XLA. (#1702)
Add support for building a CPU-only jaxlib with a CUDA-enabled toolchain.
2019-11-16 11:01:36 -05:00
Skye Wanderman-Milne
44ccca05df Revert "Update bazel min version to 0.26.0."
It turns out we can't build TF with this bazel version yet.

This reverts commit 807a1958bb2a276f012901ba6cd9226371099005.
2019-11-13 10:05:58 -08:00
Skye Wanderman-Milne
807a1958bb Update bazel min version to 0.26.0.
I think this was the first release including the --repo_env arg:
d7702b16eb
2019-11-13 09:59:38 -08:00
android
ddbdcfb9c9 Add TPU Driver to jaxlib (#1673) 2019-11-12 18:11:39 -08:00
Peter Hawkins
3e9ce2f69f
Use --repo_env instead of --action_env to configure Python and CUDA. (#1619)
--action_env variables are passed to every build action. This means that if the variable changes, the entire build cache is invalidated. By contrast, --repo_env variables are only passed to repository rules and don't affect every action. In principle this means that we should be able to rebuild JAX for different Python versions without rebuilding 99% of the C++ code.

Update the bazel release for the build script to 0.29.1 (same as TensorFlow).
2019-11-01 10:41:51 -04:00
Peter Hawkins
b3e4a1a850
Update jaxlib build scripts to build Python 3.8.0 wheels. (#1612) 2019-10-31 11:46:37 -04:00
Peter Hawkins
affa2dcca4
Increment jax and jaxlib versions. (#1603)
* Update XLA version to 7acd3bb9d7
* Remove XRT reference from jaxlib build.
2019-10-30 15:50:00 -04:00
Peter Hawkins
1428c11a2c Update Jaxlib version to 0.1.29.
Bump XLA version. Enable C++14 mode since it is required by the new XLA version.
2019-09-28 15:11:09 -04:00
Skye Wanderman-Milne
2c6f74dfbc Add CUDA 10.1 wheels. 2019-09-05 14:45:32 -07:00
Peter Hawkins
61713fe52e Update Docker build to produce manylinux2010 compliant wheels for non-cuda builds.
Previously we claimed our wheels were manylinux1 compliant, but they weren't.

Uses a cross-compilation toolchain from the TF folks that builds manylinux2010 compliant wheels from an Ubuntu 16.04 VM.

The CUDA wheels still aren't manylinux2010 compliant because they depend on CUDA libraries from the system.
2019-08-13 16:25:32 -04:00
Peter Hawkins
38b805720b Update Docker scripts to use a tmpfs for builds. Upgrade bazel. 2019-08-08 16:42:15 -04:00