3232 Commits

Author SHA1 Message Date
Peter Hawkins
cf65f6b24e
Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241)
The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants.
2020-09-10 11:16:35 -04:00
Peter Hawkins
b67e42a373
Revert "Revert "Delete batching.last. (#4148)" (#4160)" (#4242)
This reverts commit 36846e0ed96cc613e419ac85d9c3d54a49aa9ebc.
2020-09-10 09:38:14 -04:00
George Necula
0962ceb057
[jax2tf] Fix test failure on TPUs (#4247) 2020-09-10 13:55:57 +03:00
Benjamin Chetioui
29f97afa29
[jax2tf] Cleanup test_unary_elementwise. (#4246)
* [jax2tf] Cleanup test_unary_elementwise.
2020-09-10 12:59:44 +03:00
George Necula
f0a3fd4a86
[jax2tf] Moved the limitations for XlaSort to correctness_stats (#4237) 2020-09-10 12:19:22 +03:00
Adam Paszke
3f8aaabbcc Interrupt lu transformation generators whenever an exception occurs
This fixes some errors that have been appearing in our CI from time to
time. All transformations are implemented as generators, but they
haven't been explicitly aborted when an exception has been raised.
Instead, they only got closed when they got garbage collected, which
could happen at an unspecified later time, potentially leading to a
corruption of global state, which could have been modified after the
exception was handled.

Note that this implementation doesn't propagate the original exception
into the argument transformations, and doesn't allow them to handle the
error either. Such an extension would be possible, but throwing an
exception into a generator mutates the exception object, clobbering
the nice traceback that we would usually carry. One can work around
those issues, but it feels really hacky and we don't need it right now
anyway, so I figured we'll be better off with the simple thing for the
time being.
2020-09-09 20:43:05 +02:00
Benjamin Chetioui
70891f46cd
[jax2tf] Add a template file for documentation generation. (#4219)
* [jax2tf] Add a template file for documentation generation.

The documentation now gives instructions about how to
regenerate it, as well as when it was last generated.

* Added a list of conversions that are not yet implemented.
2020-09-09 17:48:00 +03:00
Benjamin Chetioui
053cd5aa39
[jax2tf] Clean up test_dynamic_slice. (#4236)
* [jax2tf] Clean up test_dynamic_slice.

With the XLA nested compilation bug fixed, this should now work
fine.
2020-09-09 15:43:36 +03:00
Roman Ring
bff24bddbb
Add axis_index_groups support to all_gather. (#4194) 2020-09-09 15:02:45 +03:00
Benjamin Chetioui
f908f6f25c
[jax2tf] Updated test_pad to test all dtypes and remove old (#4235)
skipped test.
2020-09-09 13:20:59 +03:00
George Necula
ee38e7166c
[jax2tf] Clean up code for XlaGather, experimental_compile not necessary (#4030)
* [jax2tf] Clean up code for XlaGather, experimental_compile not necessary

Now that XlaGather has been fixed in XLA, we do not need to use
experimental_compile workaround (which was not working anyway when
put in a SavedModel).

This fix requires a recent tf-nightly installation.
2020-09-09 11:34:22 +03:00
Matthew Johnson
745d90d036
improve lax.pad shape rule (#4234)
It's now:
  * better tested
  * better at catching errors
  * faster
  * easier to read
2020-09-08 21:14:25 -07:00
Matthew Johnson
7f3078b70d
updtate version and changelog for pypi (#4224) 2020-09-08 08:54:13 -07:00
Matthew Johnson
ed0d8c02f6
tweak lax.py shape broadcasting logic (#4217)
This new implementation is faster, and works for polymorphic shapes without weird tricks. (This new implementation is faster even if we remove the weird tricks for polymorphism.)
2020-09-08 08:27:41 -07:00
Benjamin Chetioui
798a2648f5
[jax2tf] Fix bug in population count and move expect_tf_exception (#4214)
into correctness stats.

The code was using `tf.bitcast` instead of `tf.cast`, but using
`expect_tf_exception` in every case was hiding the errors.
2020-09-08 11:32:53 +03:00
Benjamin Chetioui
e1340f3495
[jax2tf] Fix missing complex64 TPU corner case of scatter_{add,mul} (#4213) 2020-09-07 18:12:35 +03:00
Adam Paszke
0aed1f4ddf Add more context to the axis_frame error message.
Some of the vmap and gmap collective tests have been failing on master
and I can't seem to be able to reproduce them locally. Hopefully, if
this happens again, this extra bit of information will be useful in
debugging the problem.
2020-09-07 16:25:30 +02:00
George Necula
4413bb8a4f
[jax2tf] Do not use jax.random.PRNGKey before in primitive harness (#4211)
We cannot execute JAX functions before the program is initialized
2020-09-07 17:13:11 +03:00
Benjamin Chetioui
be8ea1447f
[jax2tf] Expand coverage of primitives by categorize. (#4209)
* [jax2tf] Expand coverage of primitives by categorize.

This commit adds handling logic for the limitations of:
- qr
- svd
- select_and_gather_add
- reduce_window/reduce_window_{min,max,sum}
- add
- mul
- scatter/scatter_{min,max,mul,add}

Also fixes a bug in a call to _infer_shape_jax, which wasn't
compatible with boolean operands and went undetected due to the
high-level handling of TF exceptions in higher-order primitives.
2020-09-07 16:47:18 +03:00
George Necula
1e84cbe9cc
[jax2tf] Fix random.split when jax_exable_x64 (#4208)
Since we do the threefry with signed integers when converting to TF,
we run into the type promotion 'uint32 - int32 = int64', which
then results in lax.shift_right_logical(uint32, int64), which fails.
2020-09-07 14:41:50 +03:00
Benjamin Chetioui
6c62935d00
[jax2tf] Cleanup the correctness stats layout. (#4201)
* [jax2tf] Cleanup the correctness stats layout.

* Added Google license at the top of the file.
* Cleanup: fix docstring for 80 char boundary.
* Monkey patch/cleanup outside of the loop.
* Removed tensorflow dependency.
* Fixed the name of attributes of Limitation.
2020-09-07 12:03:00 +03:00
George Necula
c6e6ee2dcb
[jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204)
* performance is the same
2020-09-07 11:26:52 +03:00
AdrienCorenflos
96278e67a2
Add reverse flag in associative scan (#4181)
Add optional 'reverse' argument  in associative scan
2020-09-04 09:21:43 -07:00
Benjamin Chetioui
bcf9777bac
[jax2tf] Generator for the documentation of operations with limited support (WIP) (#4193)
* [jax2tf] Draft of a generator for the documentation of operations
with limited support.
2020-09-03 16:56:22 +03:00
George Necula
abdd13884b
[jax2tf] Flip the with_gradient=True; was flipped back by mistake (#4200) 2020-09-03 14:24:04 +03:00
George Necula
5eac47726b
[jax2tf] Implementation of random_gamma (#4192)
* [jax2tf] implementation of random_gamma

The simplest implementation is by converting the JAX own impl_rule,
which rewrites gamma into other JAX primitives.

On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
2020-09-03 14:18:35 +03:00
Alex Riley
708d07d5ff
Add jax.numpy.array_split (#4197) 2020-09-02 16:13:17 -07:00
Matthew Johnson
04f9a7e53d
better jax.numpy.tile implementation (#4190)
Use reshape, broadcast_to, reshape.
2020-09-01 18:16:20 -07:00
Jake Vanderplas
421550a979
copysign: promote to inexact to match numpy & support unsigned inputs (#4188) 2020-09-01 15:48:40 -07:00
Benjamin Chetioui
0cdb1f7ee6
[jax2tf] Indicate the version of TF used in tests in README. (#4185) 2020-09-01 10:35:25 +03:00
Jean-Baptiste Lespiau
bdd65453b4
Add more features to the C++ jax.jit. (#4169)
This mainly follows https://github.com/google/jax/pull/4089 by adding:

- support for disable_jit from C++
- support for jax._cpp_jit on methods.
- supporting applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend.
- concurrency support.

I am not aware of any feature missing (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable.)

See:

- https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see
 cr/328899906 + benchmarks for how numbers were generated)
- The results of the Jax tests when enabling this:
http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many common cause of failure).
2020-09-01 10:34:47 +03:00
Jake Vanderplas
36368a2a6d
jnp.abs(): support boolean inputs (#4186) 2020-08-31 14:11:49 -07:00
Hamza Merzić
44bcf7e776
Fix axis checking and remove extra print statement (#4184)
A series of PRs renaming the frame entries have been submitted, one of them introducing a bug when using omnistaging. This PR fixes that and removes a print comment (assuming added for debugging purposes).
2020-08-31 17:00:34 +03:00
George Necula
b6b1f5e349
[jax2tf] Turn on with_gradient by default (#4180)
As I was writing the demo I realized that it makes more sense for
with_gradient to be set to True by default.

I have also fixed a bug with tie_in in omnistaging.
2020-08-31 10:26:32 +03:00
George Necula
634c6259df
More renaming of master to main in JAX internals (#4179) 2020-08-30 12:38:14 +03:00
Jake Vanderplas
ffbfadd83e
lax.associative_scan: fix docstring examples (#4172)
* lax.associative_scan: fix docstring examples
* add verbiage from #3583
2020-08-30 11:36:47 +03:00
Matthew Johnson
6b6789a53b
applied simple find+sed for 'master' -> 'main' (#4174)
* applied simple find+sed for 'master' -> 'main'

* Rename master->main in JAX API and internals (#4178)

* Started with #4174 
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-08-30 11:16:51 +03:00
Benjamin Chetioui
1a87fd3bc1
Implement a proper shape checking rule for gather. (#4166)
* Implement a proper shape checking rule for gather.

The implementation is based on the corresponding shape inference
code in `tensorflow/compiler/xla/service/shape_inference.cc`. The
tests added in `tests/lax_test.py` are similarly mirroring the
corresponding tests in tensorflow, with slight adaptations for the
particular setting of JAX. Fixes google/jax#2826, and in principle
fixes google/jax#4154 and google/jax#3905.

* Extracted common functions for gather/scatter shape checking rules.
2020-08-29 11:24:03 +03:00
Adam Paszke
a33f4dd8c8
Add support for axis_index inside vmap (#4168)
Also, reorganize the code to put all `axis_index` related functions in
`lax_parallel.py`, next to all other parallel collectives.
2020-08-28 20:03:39 +02:00
Benjamin Chetioui
04f9ff7ff4
Addition of one more conclusive polynomial comparison case. (#4167)
* Addition of one more conclusive polynomial comparison case.

In the case when the difference between two polynomials is a
constant, it is possible to conclusively compare them. This commit
adds such a case to masking.Poly.__ge__.

* Added a few relevant tests in tests.masking_test.test_Poly_compare.
2020-08-28 17:27:32 +03:00
Adam Paszke
7210d6f5d0 Add support for binding axis_name in gmap
This allows executing collectives over the gmapped axes. This requires
some extra manipulation of the gmapped jaxpr, since gmap exposes a
single logical axis name, but evaluates the program using multiple
"physical" axes.

This also fixes some bugs around handling `multiple_returns` in
vmap collective implementation.
2020-08-28 14:42:01 +02:00
George Necula
36846e0ed9
Revert "Delete batching.last. (#4148)" (#4160)
This reverts commit 4bf3d6e9cccc5de3834e37affae2012e6e3d3180.

This commit fails internal tests.
2020-08-27 12:45:48 +03:00
Benjamin Chetioui
a7faf09025
[jax2tf] Added conversion for scatter*_p primitives. (#4091)
* [jax2tf] Added conversion for scatter*_p primitives.

Limitations:

the conversion works as well as the conversion of the underlying reduction functions (e.g. lax.scatter_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors);
the conversion can not take advantage of the unique_indices parameter. This does not affect correctness, but may affect performance on certain platforms (as stated in the documentation of lax.scatter).

* Put tf.function experimental compile wrapper back on scatter.
* Removed unique_indices=True test cases
* Remove non-deterministic test cases from the scatter harness.

This commit also documents the reasons for ignoring these test
cases and potential pitfalls, in case someone needs to perform
these tests at a later time.
2020-08-27 12:24:13 +03:00
Benjamin Chetioui
4d7396aa02
Implement a proper shape checking rule for scatter. (#4144)
The implementation is based on the corresponding shape inference
code in `tensorflow/compiler/xla/service/shape_inference.cc`. The
tests added in `tests/lax_test.py` are similarly mirroring the
corresponding tests in tensorflow, with slight adaptations for
the particular setting of JAX.
2020-08-27 12:04:32 +03:00
Benjamin Chetioui
80114e51d6
Add a boolean to _check_shapelike to accept or reject shapes (#4108)
* Add a boolean to _check_shapelike to accept or reject shapes
corresponding to arrays of 0 elements. (Fixes google/jax#3972).

* Added test for failures referenced in issue 3972.
2020-08-27 10:47:19 +03:00
Benjamin Chetioui
1dc71b2f41
[jax2tf] Add testing for add/mul/min/max conversion. (#4142)
* [jax2tf] Add testing for add/mul/min/max conversion.

Only certain types are supported for each of the operations above.
This commit adds previously missing tests to make this explicit.
2020-08-27 10:46:32 +03:00
George Necula
57f49b68a6
Fix bug in omnistaging_enabler (#4159)
This code was failing with "KeyError: psum" for the tests
"//third_party/py/flax/...". I suspect that the error is due to the
ordering of the omnistaging enablers, changed in #4152.

I am not sure of this fix, but this seemed to be enough for all the
presubmit tests to pass and allow the copybara import.
2020-08-27 10:05:24 +03:00
George Necula
417c9ff764
Fix pytype error (#4158) 2020-08-27 09:41:16 +03:00
Jake Vanderplas
29073be0ab
cleanup: remove duplicate line (#4156) 2020-08-26 21:13:33 -07:00
Tom Hennigan
f0fb7d0925
Use omnistaging env var even when not using absl flags for config. (#4152) 2020-08-26 14:06:27 -07:00