5174 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
6dc161cf27
Pin pygments version in RTD build. (#4267)
This fixes our RTD failures, which were caused by RTD installing an older version of pygments:
```
jupyterlab-pygments 0.1.1 requires pygments<3,>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
nbconvert 6.0.1 requires pygments>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
```
2020-09-11 11:16:54 -07:00
Alvaro
ca1d8f4109
Fixing weird behavior in segment_sum when num_segments is None (#4034)
Co-authored-by: alvarosg <alvarosg@google.com>
2020-09-11 13:51:42 -04:00
Jake Vanderplas
3b7329c92e
Call check_arraylike() in jax.numpy reductions (#4195) 2020-09-11 10:04:47 -07:00
Adam Paszke
40fb01b4bd Extend axis env while translating the pmapped jaxpr to XLA
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
case of scan! Its translation rule reevaluates the jaxpr as a function,
and if it contains collectives such as `axis_index` it can fail due to
axis being missing.
2020-09-11 17:56:32 +02:00
Jake Vanderplas
8a18b10fa4
implement jnp.apply_along_axis (#4253) 2020-09-11 08:47:05 -07:00
Qiao Zhang
7e694bd3b6
Update README to point to jaxlib-0.1.55. (#4256) 2020-09-10 18:28:44 -07:00
Skye Wanderman-Milne
fc39332c4f
Clarify when jaxlib version should be bumped. (#4250) 2020-09-10 15:53:04 -07:00
Jake Vanderplas
2a33b3d388
fix documentation typo (#4252) 2020-09-10 11:23:29 -07:00
Qiao Zhang
82af356b4c
Bump TF hash to get an upstream LLVM GCC fix. (#4251) jaxlib-v0.1.55 2020-09-10 10:10:33 -07:00
Peter Hawkins
cf65f6b24e
Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241)
The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants.
2020-09-10 11:16:35 -04:00
Peter Hawkins
b67e42a373
Revert "Revert "Delete batching.last. (#4148)" (#4160)" (#4242)
This reverts commit 36846e0ed96cc613e419ac85d9c3d54a49aa9ebc.
2020-09-10 09:38:14 -04:00
Qiao Zhang
26a53ae554
Add comments for residuals from f_bwd. (#4244) 2020-09-10 13:58:28 +03:00
George Necula
0962ceb057
[jax2tf] Fix test failure on TPUs (#4247) 2020-09-10 13:55:57 +03:00
Benjamin Chetioui
29f97afa29
[jax2tf] Cleanup test_unary_elementwise. (#4246)
* [jax2tf] Cleanup test_unary_elementwise.
2020-09-10 12:59:44 +03:00
George Necula
f0a3fd4a86
[jax2tf] Moved the limitations for XlaSort to correctness_stats (#4237) 2020-09-10 12:19:22 +03:00
Qiao Zhang
adb344880b
Reorder nocuda/cuda build to fail early. (#4243) 2020-09-09 17:23:43 -07:00
Qiao Zhang
0b04439f11
Update install_cuda script to specify cublas. (#4240) 2020-09-09 17:16:58 -07:00
Qiao Zhang
a14133aa33
Update TF dep to a passing commit hash. (#4239) 2020-09-09 16:36:30 -07:00
Adam Paszke
3f8aaabbcc Interrupt lu transformation generators whenever an exception occurs
This fixes some errors that have been appearing in our CI from time to
time. All transformations are implemented as generators, but they
haven't been explicitly aborted when an exception has been raised.
Instead, they only got closed when they got garbage collected, which
could happen at an unspecified later time, potentially leading to a
corruption of global state, which could have been modified after the
exception was handled.

Note that this implementation doesn't propagate the original exception
into the argument transformations, and doesn't allow them to handle the
error either. Such an extension would be possible, but throwing an
exception into a generator mutates the exception object, clobbering
the nice traceback that we would usually carry. One can work around
those issues, but it feels really hacky and we don't need it right now
anyway, so I figured we'll be better off with the simple thing for the
time being.
2020-09-09 20:43:05 +02:00
Benjamin Chetioui
70891f46cd
[jax2tf] Add a template file for documentation generation. (#4219)
* [jax2tf] Add a template file for documentation generation.

The documentation now gives instructions about how to
regenerate it, as well as when it was last generated.

* Added a list of conversions that are not yet implemented.
2020-09-09 17:48:00 +03:00
Benjamin Chetioui
053cd5aa39
[jax2tf] Clean up test_dynamic_slice. (#4236)
* [jax2tf] Clean up test_dynamic_slice.

With the XLA nested compilation bug fixed, this should now work
fine.
2020-09-09 15:43:36 +03:00
Roman Ring
bff24bddbb
Add axis_index_groups support to all_gather. (#4194) 2020-09-09 15:02:45 +03:00
Benjamin Chetioui
f908f6f25c
[jax2tf] Updated test_pad to test all dtypes and remove old (#4235)
skipped test.
2020-09-09 13:20:59 +03:00
George Necula
ee38e7166c
[jax2tf] Clean up code for XlaGather, experimental_compile not necessary (#4030)
* [jax2tf] Clean up code for XlaGather, experimental_compile not necessary

Now that XlaGather has been fixed in XLA, we do not need to use
experimental_compile workaround (which was not working anyway when
put in a SavedModel).

This fix requires a recent tf-nightly installation.
2020-09-09 11:34:22 +03:00
Qiao Zhang
3cf7336753
Fix Dockerfile wheel installation issues. (#4232) 2020-09-08 21:28:34 -07:00
Matthew Johnson
745d90d036
improve lax.pad shape rule (#4234)
It's now:
  * better tested
  * better at catching errors
  * faster
  * easier to read
2020-09-08 21:14:25 -07:00
Skye Wanderman-Milne
cf2d15d4bb
jaxlib build fixes. (#4066)
1. `wheel.pep425tags` has been removed as of
   https://github.com/pypa/setuptools/pull/1829. Use the new
   `packaging.tags` instead.

2. Add `--allow-downgrades` to cuda install command. I'm not sure this
   is always necessary, but I ran into it, I'm guessing due to a cached
   docker image.
2020-09-08 18:23:42 -07:00
Qiao Zhang
4600dd7957
Update jaxlib version for dlpack fix. (#4231) 2020-09-08 17:20:48 -07:00
Jake Vanderplas
9a70be2419
Add test for dtype coverage of jax.numpy ufuncs (#3913) 2020-09-08 13:30:57 -07:00
Matthew Johnson
7f3078b70d
updtate version and changelog for pypi (#4224) jax-v0.1.76 2020-09-08 08:54:13 -07:00
Matthew Johnson
ed0d8c02f6
tweak lax.py shape broadcasting logic (#4217)
This new implementation is faster, and works for polymorphic shapes without weird tricks. (This new implementation is faster even if we remove the weird tricks for polymorphism.)
2020-09-08 08:27:41 -07:00
Benjamin Chetioui
798a2648f5
[jax2tf] Fix bug in population count and move expect_tf_exception (#4214)
into correctness stats.

The code was using `tf.bitcast` instead of `tf.cast`, but using
`expect_tf_exception` in every case was hiding the errors.
2020-09-08 11:32:53 +03:00
Benjamin Chetioui
e1340f3495
[jax2tf] Fix missing complex64 TPU corner case of scatter_{add,mul} (#4213) 2020-09-07 18:12:35 +03:00
Adam Paszke
0aed1f4ddf Add more context to the axis_frame error message.
Some of the vmap and gmap collective tests have been failing on master
and I can't seem to be able to reproduce them locally. Hopefully, if
this happens again, this extra bit of information will be useful in
debugging the problem.
2020-09-07 16:25:30 +02:00
George Necula
4413bb8a4f
[jax2tf] Do not use jax.random.PRNGKey before in primitive harness (#4211)
We cannot execute JAX functions before the program is initialized
2020-09-07 17:13:11 +03:00
Benjamin Chetioui
be8ea1447f
[jax2tf] Expand coverage of primitives by categorize. (#4209)
* [jax2tf] Expand coverage of primitives by categorize.

This commit adds handling logic for the limitations of:
- qr
- svd
- select_and_gather_add
- reduce_window/reduce_window_{min,max,sum}
- add
- mul
- scatter/scatter_{min,max,mul,add}

Also fixes a bug in a call to _infer_shape_jax, which wasn't
compatible with boolean operands and went undetected due to the
high-level handling of TF exceptions in higher-order primitives.
2020-09-07 16:47:18 +03:00
George Necula
1e84cbe9cc
[jax2tf] Fix random.split when jax_exable_x64 (#4208)
Since we do the threefry with signed integers when converting to TF,
we run into the type promotion 'uint32 - int32 = int64', which
then results in lax.shift_right_logical(uint32, int64), which fails.
2020-09-07 14:41:50 +03:00
Benjamin Chetioui
6c62935d00
[jax2tf] Cleanup the correctness stats layout. (#4201)
* [jax2tf] Cleanup the correctness stats layout.

* Added Google license at the top of the file.
* Cleanup: fix docstring for 80 char boundary.
* Monkey patch/cleanup outside of the loop.
* Removed tensorflow dependency.
* Fixed the name of attributes of Limitation.
2020-09-07 12:03:00 +03:00
George Necula
c6e6ee2dcb
[jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204)
* performance is the same
2020-09-07 11:26:52 +03:00
AdrienCorenflos
96278e67a2
Add reverse flag in associative scan (#4181)
Add optional 'reverse' argument  in associative scan
2020-09-04 09:21:43 -07:00
Benjamin Chetioui
bcf9777bac
[jax2tf] Generator for the documentation of operations with limited support (WIP) (#4193)
* [jax2tf] Draft of a generator for the documentation of operations
with limited support.
2020-09-03 16:56:22 +03:00
George Necula
abdd13884b
[jax2tf] Flip the with_gradient=True; was flipped back by mistake (#4200) 2020-09-03 14:24:04 +03:00
George Necula
5eac47726b
[jax2tf] Implementation of random_gamma (#4192)
* [jax2tf] implementation of random_gamma

The simplest implementation is by converting the JAX own impl_rule,
which rewrites gamma into other JAX primitives.

On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
2020-09-03 14:18:35 +03:00
Alex Riley
708d07d5ff
Add jax.numpy.array_split (#4197) 2020-09-02 16:13:17 -07:00
Matthew Johnson
04f9a7e53d
better jax.numpy.tile implementation (#4190)
Use reshape, broadcast_to, reshape.
2020-09-01 18:16:20 -07:00
Jake Vanderplas
421550a979
copysign: promote to inexact to match numpy & support unsigned inputs (#4188) 2020-09-01 15:48:40 -07:00
Benjamin Chetioui
0cdb1f7ee6
[jax2tf] Indicate the version of TF used in tests in README. (#4185) 2020-09-01 10:35:25 +03:00
Jean-Baptiste Lespiau
bdd65453b4
Add more features to the C++ jax.jit. (#4169)
This mainly follows https://github.com/google/jax/pull/4089 by adding:

- support for disable_jit from C++
- support for jax._cpp_jit on methods.
- supporting applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend.
- concurrency support.

I am not aware of any feature missing (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable.)

See:

- https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see
 cr/328899906 + benchmarks for how numbers were generated)
- The results of the Jax tests when enabling this:
http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many common cause of failure).
2020-09-01 10:34:47 +03:00
Jake Vanderplas
36368a2a6d
jnp.abs(): support boolean inputs (#4186) 2020-08-31 14:11:49 -07:00
Hamza Merzić
44bcf7e776
Fix axis checking and remove extra print statement (#4184)
A series of PRs renaming the frame entries have been submitted, one of them introducing a bug when using omnistaging. This PR fixes that and removes a print comment (assuming added for debugging purposes).
2020-08-31 17:00:34 +03:00