5145 Commits

Author SHA1 Message Date
Matthew Johnson
7f3078b70d
updtate version and changelog for pypi (#4224) jax-v0.1.76 2020-09-08 08:54:13 -07:00
Matthew Johnson
ed0d8c02f6
tweak lax.py shape broadcasting logic (#4217)
This new implementation is faster, and works for polymorphic shapes without weird tricks. (This new implementation is faster even if we remove the weird tricks for polymorphism.)
2020-09-08 08:27:41 -07:00
Benjamin Chetioui
798a2648f5
[jax2tf] Fix bug in population count and move expect_tf_exception (#4214)
into correctness stats.

The code was using `tf.bitcast` instead of `tf.cast`, but using
`expect_tf_exception` in every case was hiding the errors.
2020-09-08 11:32:53 +03:00
Benjamin Chetioui
e1340f3495
[jax2tf] Fix missing complex64 TPU corner case of scatter_{add,mul} (#4213) 2020-09-07 18:12:35 +03:00
Adam Paszke
0aed1f4ddf Add more context to the axis_frame error message.
Some of the vmap and gmap collective tests have been failing on master
and I can't seem to be able to reproduce them locally. Hopefully, if
this happens again, this extra bit of information will be useful in
debugging the problem.
2020-09-07 16:25:30 +02:00
George Necula
4413bb8a4f
[jax2tf] Do not use jax.random.PRNGKey before in primitive harness (#4211)
We cannot execute JAX functions before the program is initialized
2020-09-07 17:13:11 +03:00
Benjamin Chetioui
be8ea1447f
[jax2tf] Expand coverage of primitives by categorize. (#4209)
* [jax2tf] Expand coverage of primitives by categorize.

This commit adds handling logic for the limitations of:
- qr
- svd
- select_and_gather_add
- reduce_window/reduce_window_{min,max,sum}
- add
- mul
- scatter/scatter_{min,max,mul,add}

Also fixes a bug in a call to _infer_shape_jax, which wasn't
compatible with boolean operands and went undetected due to the
high-level handling of TF exceptions in higher-order primitives.
2020-09-07 16:47:18 +03:00
George Necula
1e84cbe9cc
[jax2tf] Fix random.split when jax_exable_x64 (#4208)
Since we do the threefry with signed integers when converting to TF,
we run into the type promotion 'uint32 - int32 = int64', which
then results in lax.shift_right_logical(uint32, int64), which fails.
2020-09-07 14:41:50 +03:00
Benjamin Chetioui
6c62935d00
[jax2tf] Cleanup the correctness stats layout. (#4201)
* [jax2tf] Cleanup the correctness stats layout.

* Added Google license at the top of the file.
* Cleanup: fix docstring for 80 char boundary.
* Monkey patch/cleanup outside of the loop.
* Removed tensorflow dependency.
* Fixed the name of attributes of Limitation.
2020-09-07 12:03:00 +03:00
George Necula
c6e6ee2dcb
[jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204)
* performance is the same
2020-09-07 11:26:52 +03:00
AdrienCorenflos
96278e67a2
Add reverse flag in associative scan (#4181)
Add optional 'reverse' argument  in associative scan
2020-09-04 09:21:43 -07:00
Benjamin Chetioui
bcf9777bac
[jax2tf] Generator for the documentation of operations with limited support (WIP) (#4193)
* [jax2tf] Draft of a generator for the documentation of operations
with limited support.
2020-09-03 16:56:22 +03:00
George Necula
abdd13884b
[jax2tf] Flip the with_gradient=True; was flipped back by mistake (#4200) 2020-09-03 14:24:04 +03:00
George Necula
5eac47726b
[jax2tf] Implementation of random_gamma (#4192)
* [jax2tf] implementation of random_gamma

The simplest implementation is by converting the JAX own impl_rule,
which rewrites gamma into other JAX primitives.

On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
2020-09-03 14:18:35 +03:00
Alex Riley
708d07d5ff
Add jax.numpy.array_split (#4197) 2020-09-02 16:13:17 -07:00
Matthew Johnson
04f9a7e53d
better jax.numpy.tile implementation (#4190)
Use reshape, broadcast_to, reshape.
2020-09-01 18:16:20 -07:00
Jake Vanderplas
421550a979
copysign: promote to inexact to match numpy & support unsigned inputs (#4188) 2020-09-01 15:48:40 -07:00
Benjamin Chetioui
0cdb1f7ee6
[jax2tf] Indicate the version of TF used in tests in README. (#4185) 2020-09-01 10:35:25 +03:00
Jean-Baptiste Lespiau
bdd65453b4
Add more features to the C++ jax.jit. (#4169)
This mainly follows https://github.com/google/jax/pull/4089 by adding:

- support for disable_jit from C++
- support for jax._cpp_jit on methods.
- supporting applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend.
- concurrency support.

I am not aware of any feature missing (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable.)

See:

- https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see
 cr/328899906 + benchmarks for how numbers were generated)
- The results of the Jax tests when enabling this:
http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many common cause of failure).
2020-09-01 10:34:47 +03:00
Jake Vanderplas
36368a2a6d
jnp.abs(): support boolean inputs (#4186) 2020-08-31 14:11:49 -07:00
Hamza Merzić
44bcf7e776
Fix axis checking and remove extra print statement (#4184)
A series of PRs renaming the frame entries have been submitted, one of them introducing a bug when using omnistaging. This PR fixes that and removes a print comment (assuming added for debugging purposes).
2020-08-31 17:00:34 +03:00
George Necula
b6b1f5e349
[jax2tf] Turn on with_gradient by default (#4180)
As I was writing the demo I realized that it makes more sense for
with_gradient to be set to True by default.

I have also fixed a bug with tie_in in omnistaging.
2020-08-31 10:26:32 +03:00
George Necula
634c6259df
More renaming of master to main in JAX internals (#4179) 2020-08-30 12:38:14 +03:00
Jake Vanderplas
ffbfadd83e
lax.associative_scan: fix docstring examples (#4172)
* lax.associative_scan: fix docstring examples
* add verbiage from #3583
2020-08-30 11:36:47 +03:00
Matthew Johnson
6b6789a53b
applied simple find+sed for 'master' -> 'main' (#4174)
* applied simple find+sed for 'master' -> 'main'

* Rename master->main in JAX API and internals (#4178)

* Started with #4174 
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-08-30 11:16:51 +03:00
Benjamin Chetioui
1a87fd3bc1
Implement a proper shape checking rule for gather. (#4166)
* Implement a proper shape checking rule for gather.

The implementation is based on the corresponding shape inference
code in `tensorflow/compiler/xla/service/shape_inference.cc`. The
tests added in `tests/lax_test.py` are similarly mirroring the
corresponding tests in tensorflow, with slight adaptations for the
particular setting of JAX. Fixes google/jax#2826, and in principle
fixes google/jax#4154 and google/jax#3905.

* Extracted common functions for gather/scatter shape checking rules.
2020-08-29 11:24:03 +03:00
Adam Paszke
a33f4dd8c8
Add support for axis_index inside vmap (#4168)
Also, reorganize the code to put all `axis_index` related functions in
`lax_parallel.py`, next to all other parallel collectives.
2020-08-28 20:03:39 +02:00
Jake Vanderplas
1dab791acb
Avoid calling jnp.sum() on list (#4163) 2020-08-28 09:07:30 -07:00
Benjamin Chetioui
04f9ff7ff4
Addition of one more conclusive polynomial comparison case. (#4167)
* Addition of one more conclusive polynomial comparison case.

In the case when the difference between two polynomials is a
constant, it is possible to conclusively compare them. This commit
adds such a case to masking.Poly.__ge__.

* Added a few relevant tests in tests.masking_test.test_Poly_compare.
2020-08-28 17:27:32 +03:00
Adam Paszke
7210d6f5d0 Add support for binding axis_name in gmap
This allows executing collectives over the gmapped axes. This requires
some extra manipulation of the gmapped jaxpr, since gmap exposes a
single logical axis name, but evaluates the program using multiple
"physical" axes.

This also fixes some bugs around handling `multiple_returns` in
vmap collective implementation.
2020-08-28 14:42:01 +02:00
Jean-Baptiste Lespiau
e95d5701e3
Add benchmarks for specifically the dispatch time. (#4128)
The goal is to distinguish the time it takes for `jitted_f` to return, and the time it takes to return and wait for the result.
We also add one to distinguish the time it takes to call the function with the argument transfer or without it.

e.g.

name                                   time/op
jit_trivial_dispatch                   28.9µs ± 2%
jit_trivial                            31.5µs ± 5%
jit_simple_dispatch                    60.7µs ± 4%
jit_simple                              129µs ±24%
jit_simple_many_args_disptch            390µs ±19%
jit_simple_many_args                    388µs ±16%
jit_dispatch_without_transfer           379µs ± 6%
jit_dispatch_with_transfer              450µs ± 5%
2020-08-27 17:02:13 +03:00
George Necula
36846e0ed9
Revert "Delete batching.last. (#4148)" (#4160)
This reverts commit 4bf3d6e9cccc5de3834e37affae2012e6e3d3180.

This commit fails internal tests.
2020-08-27 12:45:48 +03:00
Benjamin Chetioui
a7faf09025
[jax2tf] Added conversion for scatter*_p primitives. (#4091)
* [jax2tf] Added conversion for scatter*_p primitives.

Limitations:

the conversion works as well as the conversion of the underlying reduction functions (e.g. lax.scatter_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors);
the conversion can not take advantage of the unique_indices parameter. This does not affect correctness, but may affect performance on certain platforms (as stated in the documentation of lax.scatter).

* Put tf.function experimental compile wrapper back on scatter.
* Removed unique_indices=True test cases
* Remove non-deterministic test cases from the scatter harness.

This commit also documents the reasons for ignoring these test
cases and potential pitfalls, in case someone needs to perform
these tests at a later time.
2020-08-27 12:24:13 +03:00
Benjamin Chetioui
4d7396aa02
Implement a proper shape checking rule for scatter. (#4144)
The implementation is based on the corresponding shape inference
code in `tensorflow/compiler/xla/service/shape_inference.cc`. The
tests added in `tests/lax_test.py` are similarly mirroring the
corresponding tests in tensorflow, with slight adaptations for
the particular setting of JAX.
2020-08-27 12:04:32 +03:00
Benjamin Chetioui
80114e51d6
Add a boolean to _check_shapelike to accept or reject shapes (#4108)
* Add a boolean to _check_shapelike to accept or reject shapes
corresponding to arrays of 0 elements. (Fixes google/jax#3972).

* Added test for failures referenced in issue 3972.
2020-08-27 10:47:19 +03:00
Benjamin Chetioui
1dc71b2f41
[jax2tf] Add testing for add/mul/min/max conversion. (#4142)
* [jax2tf] Add testing for add/mul/min/max conversion.

Only certain types are supported for each of the operations above.
This commit adds previously missing tests to make this explicit.
2020-08-27 10:46:32 +03:00
George Necula
c76b84f6e2
Revert "Increase tolerance for CPU test LaxBackedNumpyTests::testCorrCoef (#4080)" (#4151)
This reverts commit 22b92c5122ab5af6f5e4560f9be08f5649ae7653.

We revert this change because the LLVM bug that made us relax the
test tolerance is now fixed.
2020-08-27 10:34:53 +03:00
George Necula
57f49b68a6
Fix bug in omnistaging_enabler (#4159)
This code was failing with "KeyError: psum" for the tests
"//third_party/py/flax/...". I suspect that the error is due to the
ordering of the omnistaging enablers, changed in #4152.

I am not sure of this fix, but this seemed to be enough for all the
presubmit tests to pass and allow the copybara import.
2020-08-27 10:05:24 +03:00
George Necula
417c9ff764
Fix pytype error (#4158) 2020-08-27 09:41:16 +03:00
Jake Vanderplas
29073be0ab
cleanup: remove duplicate line (#4156) 2020-08-26 21:13:33 -07:00
Tom Hennigan
f0fb7d0925
Use omnistaging env var even when not using absl flags for config. (#4152) 2020-08-26 14:06:27 -07:00
Matthew Johnson
1d93991003
allow random.choice to accept ndarray input (#4145)
* allow random.choice to accept ndarray `a`

follow-up to #4137 to allow ndarray inputs to be passed

* add jax.random.choice tests to cover ndarray input

* don't use callables in test params

it can mess with pytest-xdist because of hashing by id
2020-08-26 10:21:56 -07:00
Peter Hawkins
01319fb63d
Speed up and clean up geomspace test. (#4149)
* Speed up and clean up geomspace test.
2020-08-25 13:05:06 -04:00
Peter Hawkins
4bf3d6e9cc
Delete batching.last. (#4148)
A -1 axis works just as well at head.
2020-08-25 12:53:18 -04:00
Peter Hawkins
8c8060e130
Remove workaround for illegal vmap out_axes. (#4147) 2020-08-25 12:53:02 -04:00
Jake Vanderplas
6d54eb563e
Do not call asarray() on inputs of jax.random.choice (#4137) 2020-08-25 05:47:43 -07:00
Jean-Baptiste Lespiau
f959219acb
Rename collectives into "collective operations" for the pmap function. (#4136)
It is just because it serves as the entry point, and this term leads to good Google results, such as https://en.wikipedia.org/wiki/Collective_operation, while the current "collectives" do not.
2020-08-25 05:39:45 -07:00
Matthew Johnson
f4b05bc9ea
make pe.abstract_eval_fun use omnistaging (#4139) 2020-08-25 05:38:41 -07:00
Matthew Johnson
04173b3345
Merge pull request #4140 from sharadmv/patch-2
Remove frame check assertion in `extend_axis_env`.
2020-08-25 05:38:20 -07:00
Sharad Vikram
774b5f688e Remove frame check assertion in extend_axis_env. 2020-08-24 21:13:30 -07:00