8508 Commits

Author SHA1 Message Date
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
jax authors
6aa20d8f8f Merge pull request #7294 from hawkinsp:py36
PiperOrigin-RevId: 384994957
2021-07-15 13:19:23 -07:00
jax authors
e951ac8521 Merge pull request #7296 from apaszke:xmap-flag
PiperOrigin-RevId: 384993575
2021-07-15 13:12:38 -07:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
jax authors
efd37e9099 Merge pull request #7280 from tomhennigan:changelist/384668789
PiperOrigin-RevId: 384938599
2021-07-15 09:04:41 -07:00
Tom Hennigan
d6e56f2df9 Add source location and expression to error messages for CUDA API calls.
Before:

    jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: operation not supported

After:

    jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported
2021-07-15 15:42:46 +00:00
jax authors
ccfc6f6281 Merge pull request #7254 from hawkinsp:crosscompile
PiperOrigin-RevId: 384930242
2021-07-15 08:22:52 -07:00
Peter Hawkins
f5c61a892a Add support for cross-compiling jaxlib for Mac ARM. 2021-07-15 10:37:53 -04:00
Adam Paszke
e987f6f9fc Make maps.EXPERIMENTAL_SPMD_LOWERING into a jax.config flag
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.

Second attempt, this time without hardening against the flags being
registered too late due to delayed imports.
2021-07-15 14:18:52 +00:00
jax authors
25e44821dd Make maps.EXPERIMENTAL_SPMD_LOWERING into a jax.config flag
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.

PiperOrigin-RevId: 384902895
2021-07-15 05:07:09 -07:00
Adam Paszke
64510bd5b6 Add axis and tiled options to lax.all_gather.
This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.

PiperOrigin-RevId: 384897270
2021-07-15 04:22:36 -07:00
Adam Paszke
8bc6e7f1d5 Make maps.EXPERIMENTAL_SPMD_LOWERING into a jax.config flag
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.

PiperOrigin-RevId: 384892199
2021-07-15 03:37:30 -07:00
jax authors
4d026e06b1 Merge pull request #7255 from jakevdp:remove-broadcast-p
PiperOrigin-RevId: 384888218
2021-07-15 03:12:07 -07:00
jax authors
b9aaff4870 Parse absl flags in compilation_cache_test.py
PiperOrigin-RevId: 384744462
2021-07-14 11:26:22 -07:00
jax authors
2621b4b629 Merge pull request #7281 from apaszke:mapped-error-msg
PiperOrigin-RevId: 384707157
2021-07-14 08:43:54 -07:00
jax authors
4c6a06245b Merge pull request #7272 from jakevdp:jnp-rank-promotion
PiperOrigin-RevId: 384686748
2021-07-14 06:41:19 -07:00
Adam Paszke
1049e7205f Implement vmap rules for with_sharding_constraint
pjit already supports batching, so there's no need to hold off on that.

PiperOrigin-RevId: 384684263
2021-07-14 06:25:20 -07:00
Adam Paszke
1c1ec79edd Clarify the error message for out-of-bounds in_axes in pmap and vmap
Fixes #5201.
2021-07-14 12:11:06 +00:00
George Necula
79c8259e91 Add shape inference rule for XlaDynamicSlice
PiperOrigin-RevId: 384628638
2021-07-13 23:04:25 -07:00
jax authors
5488cf004e Merge pull request #7275 from jakevdp:sparse-xla-call
PiperOrigin-RevId: 384570089
2021-07-13 15:56:44 -07:00
jax authors
9907ae4516 Merge pull request #7207 from colemanliyah:compilation_cache
PiperOrigin-RevId: 384569690
2021-07-13 15:53:02 -07:00
Jake VanderPlas
1eb3b5f8d6 [sparse] support sparse arguments in xla_call 2021-07-13 15:23:14 -07:00
jax authors
035c6d5755 Merge pull request #7261 from skye:tpu_docker
PiperOrigin-RevId: 384559874
2021-07-13 15:03:08 -07:00
Liyah Coleman
052062f163 initialize_cache, get_executable, and put_executable functions 2021-07-13 21:50:36 +00:00
jax authors
c64c5a6116 Merge pull request #7268 from hawkinsp:platform
PiperOrigin-RevId: 384552186
2021-07-13 14:28:57 -07:00
jax authors
344d521847 Merge pull request #7267 from hawkinsp:nonccl
PiperOrigin-RevId: 384549897
2021-07-13 14:17:33 -07:00
jax authors
618ea255f9 Merge pull request #7274 from jakevdp:sparsify-xla-call
PiperOrigin-RevId: 384544540
2021-07-13 13:53:15 -07:00
Jake VanderPlas
9af8676341 [sparse] support dense xla_call within sparsify jaxpr interpreter 2021-07-13 13:31:21 -07:00
Jake VanderPlas
0ddcace9c6 lax_numpy_test: disable implicit rank promotion by default 2021-07-13 11:38:21 -07:00
jax authors
d4c5abc563 Merge pull request #7253 from zhangqiaorjc:test_matrix_latest_jaxlib
PiperOrigin-RevId: 384494974
2021-07-13 10:26:35 -07:00
jax authors
a8193156a6 Merge pull request #7270 from jakevdp:fix-cpu-notebook
PiperOrigin-RevId: 384486998
2021-07-13 09:53:08 -07:00
Jake VanderPlas
5a0a46cbef Update colab CPU test notebook to be more robust 2021-07-13 09:30:56 -07:00
Peter Hawkins
f33ce0d844 Warn if importing jaxlib on Mac ARM machines.
We can remove this warning when Mac ARM has CI testing.
2021-07-13 09:24:48 -04:00
Peter Hawkins
7d2aec105f Add an option to disable NCCL. 2021-07-13 09:10:29 -04:00
jax authors
208dd1ac3f Merge pull request #7252 from jakevdp:cov-rank-promotion
PiperOrigin-RevId: 384442693
2021-07-13 05:48:51 -07:00
jax authors
c6d614843c Merge pull request #7258 from jakevdp:fix-gamma
PiperOrigin-RevId: 384442646
2021-07-13 05:48:28 -07:00
jax authors
885fa23c76 Merge pull request #7231 from jakevdp:pep-448
PiperOrigin-RevId: 384442628
2021-07-13 05:44:35 -07:00
jax authors
840781d2a8 Merge pull request #7263 from gnecula:shape_poly_dim
PiperOrigin-RevId: 384422182
2021-07-13 03:01:24 -07:00
George Necula
caba2ed9b8 [jax2tf] Support tf.compat.v1.Dimension when parsing polymorphic_shapes 2021-07-13 10:39:38 +03:00
Skye Wanderman-Milne
b2fd6a772b Changes to make jax[tpu] work better in a docker container.
1. In cloud_tpu_init.py, check whether we're on a Cloud TPU VM by
   looking for the libtpu Python package, instead of /lib/libtpu.so
   (which isn't necessarily present in a docker container). JAX now
   relies on the libtpu package instead of the system libtpu.so, so
   this makes more sense either way. This means we'll try/catch an
   ImportError in all non-TPU environments when importing jax, which
   hopefully isn't noticeably slow.

2. Add requests as a jax[tpu] dependency, since it's needed by
   cloud_tpu_init.py. This comes pre-installed on Cloud TPU VMs, but
   may not be installed in docker containers, virtualenvs, etc.

I manually tested by creating the following Dockerfile on a Cloud TPU VM:
```
FROM ubuntu:18.04
RUN apt update && apt install git python3-pip -y
RUN git clone https://github.com/skye/jax && cd jax && git checkout tpu_docker
WORKDIR jax
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
CMD ["python3", "-c", "import jax; print(jax.device_count())"]
```

And then running the following commands:
```
$ sudo docker build -t jax-test .
$ sudo docker run --privileged jax-test
8
```

Note the `--privileged` flags is necessary to let the container access
the TPU devices in /dev.
2021-07-12 17:42:46 -07:00
jax authors
05a13aa011 Merge pull request #7234 from jakevdp:sparse-while
PiperOrigin-RevId: 384350476
2021-07-12 17:10:02 -07:00
Jake VanderPlas
5db97e0bf9 [sparse] add sparse transform rule for lax.while_p 2021-07-12 16:53:24 -07:00
Jake VanderPlas
c45acd70a8 Cleanup: use pep 448 unpacking to simplify some code 2021-07-12 16:30:53 -07:00
jax authors
10569871b7 Merge pull request #7260 from skye:release_indexes
PiperOrigin-RevId: 384341784
2021-07-12 16:27:47 -07:00
Skye Wanderman-Milne
1a650d2e50 Update generate_release_index[es].py to also produce libtpu_releases.html.
Previously, the libtpu-nightly wheels were included in the same index
file as the jaxlib wheels (jax_releases.html). This caused issues
because it would cause `pip install jax[tpu] -f jaxlib_releases.html`
to install a cuda jaxlib, instead of the regular CPU/TPU jaxlib from
pypi.

Instead, we create a separate index file for the libtpu-nightly
wheels, so `pip install jax[tpu] -f libtpu_releases.html` still uses
the jaxlib from pypi.

This also renames generate_release_index.py to generate_release_indexes.py.
2021-07-12 16:02:54 -07:00
Jake VanderPlas
12e435f71e remove lax.broadcast_p
Why? It has been subsumed by lax.broadcast_in_dim_p
2021-07-12 15:33:26 -07:00
jax authors
8a044c1886 Merge pull request #7259 from zhangqiaorjc:repro_cpu_bugs
PiperOrigin-RevId: 384331007
2021-07-12 15:30:14 -07:00
Qiao Zhang
0216a317a0 Change test matrix to use latest jaxlib release. 2021-07-12 15:04:10 -07:00
Qiao Zhang
72b436f9ed Add a test to repro bugs in TFRT CPU backend. 2021-07-12 14:53:49 -07:00
Jake VanderPlas
9e73972d0a Fix jax.scipy.stats.gamma.pdf() for x=0.0, a=1.0 2021-07-12 14:43:44 -07:00