328 Commits

Author SHA1 Message Date
Matthew Johnson
51ca57d5fc
check matmul inputs aren't scalar (#3725)
also dot_general shape rule should check dimension numbers are in range

fixes #3718
2020-07-11 20:47:22 -07:00
Jake Vanderplas
60d852773e
lexicographic sort_p: accept num_keys rather than comparator (#3715) 2020-07-10 09:58:35 -07:00
Jake Vanderplas
d2f9c46a0c
Remove some non-inclusive language (#3710) 2020-07-10 09:29:06 -07:00
Jake Vanderplas
804e449389
Generalize lax.sort to support lexicographic sorts. (#3709) 2020-07-09 20:05:19 -07:00
Roman Novak
4442c333ce
Add support for 0d transpose convolution (#3643)
* Allow 0d transpose convolution

* Add test for 0d conv transpose

* remove whitespace
2020-07-02 14:38:35 -07:00
Matthew Johnson
65c4d755de
fix bug in categorical test, disable #3611 on tpu (#3633)
* fix bug in categorical test, disable #3611 on tpu

Disabling #3611 on TPU pending a TPU compilation bug.

* unskip a test
2020-07-01 14:15:48 -07:00
Peter Hawkins
141fabbbf5
Reimplement argmin/argmax using a single pass variadic reduction. (#3611) 2020-07-01 11:01:22 -04:00
Matthew Johnson
eb2a227588
fix reduction repeated axis error (#3618)
* fix reduction repeated axis error

* deflake
2020-06-30 21:18:46 -07:00
Jake Vanderplas
db8f66d508
Rework type support for lax cumulative reductions (#3609) 2020-06-30 11:36:27 -07:00
Peter Hawkins
420ef4e0a8
Fix shape rule for lax.pad for input dimensions of size 0. (#3608) 2020-06-30 12:07:38 -04:00
Erich Elsen
aa6585f995 bool -> bool_ for reasons that make no sense, (bool used to be any?!) 2020-06-29 19:20:19 +01:00
Erich Elsen
b46bd2301c add support bool identity values 2020-06-29 19:13:41 +01:00
Erich Elsen
77a023df48 change ending tick mark style 2020-06-29 18:13:36 +01:00
Erich Elsen
491fcbb202 floating point identity to inf 2020-06-29 00:50:14 +01:00
Erich Elsen
b8d0de6365 remove trailing whitespace 2020-06-28 21:33:42 +01:00
Erich Elsen
290d608e9d remove now unneeded type def 2020-06-28 20:41:48 +01:00
Erich Elsen
1f15ffc45f consolidate jvp rule definitions 2020-06-28 20:39:20 +01:00
Erich Elsen
a98249d766 actually return the primitive 2020-06-28 20:31:30 +01:00
Erich Elsen
a189737ecb add generic reducer primitive generator and replace prod/max/min with it. 2020-06-28 20:28:31 +01:00
Erich Elsen
d3f6d85da5 remove unit and determine automatically for all ops 2020-06-28 20:21:35 +01:00
Erich Elsen
4fe9c1d624 fix other branch 2020-06-28 20:14:14 +01:00
Erich Elsen
1e33e5346e account for different names of reducer in tpu function 2020-06-28 20:10:27 +01:00
Erich Elsen
294d6f893f Also update custom tpu rule to set unit correctly based on dtype 2020-06-28 20:06:43 +01:00
Erich Elsen
a54a38f691 Add default value of None for unit in TPU impl of cummax/cummin 2020-06-28 19:53:47 +01:00
Erich Elsen
e2fa89dbec onp.finfo -> jnp.finfo for bfloat16 2020-06-28 19:49:36 +01:00
Erich Elsen
ae9e6851cc use correct iinfo finfo names 2020-06-28 19:44:36 +01:00
Erich Elsen
812d246295 don't require passing identity value. It isn't the initial value - identity is required for implementation correctness 2020-06-28 19:33:20 +01:00
Erich Elsen
95e15b64e3 fix typo 2020-06-28 18:37:50 +01:00
Erich Elsen
bf06633a87 add tests 2020-06-28 18:21:09 +01:00
Roy Frostig
ccb640afdb lax.sort: stable by default 2020-06-26 20:37:23 -07:00
Matthew Johnson
11caa21eca
ensure lax.reduce monoid test uses original numpy (#3573) 2020-06-26 11:44:16 -07:00
Norman Casagrande
99a43f20db
Added missing is_stable argument to lax.sort (#3553) 2020-06-26 10:40:00 -07:00
Jamie Townsend
c9670d50c5
Fix lazy broadcast issue (#3536) 2020-06-25 07:50:11 -07:00
Jake Vanderplas
d5a5d301f2
lax.sort: allow any sequence of Arrays, not just tuples (#3367) 2020-06-23 08:28:04 -07:00
Srinivas Vasudevan
927c209148
Add random_gamma_grad and use in jax.random.gamma (#3281) 2020-06-19 09:34:18 -04:00
Jacob Kelly
575216e094
add jet primitives, refactor tests (#3468)
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-06-16 19:48:25 -07:00
Adam Paszke
4d40b208ed
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible function. The general idea is that thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do, is decorate a
function with `@invertible`, which will make AD apply the more efficient
transpose than usual.

The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of "invertible" function is a one that produces a jaxpr
  that can be inverted correctly if only we iterate over its equations
  in reverse. This is a bit strict, because users generally don't have
  too much control over that, and there are functions that produce
  jaxprs which will be treated as invertible when one topological
  ordering of equations is used, while they will be considered
  non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
  that zero argument pruning in JVPTrace makes it pretty much impossible
  to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
  that all the VJPs of primal primitives are jittable (not sure if
  that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
  `custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
Stephan Hoyer
3deada9ede
Document valid enum values for precision. (#3441)
This is a little tricky to figure out otherwise.
2020-06-14 21:42:45 -07:00
Jake Vanderplas
71461a37f3
Revert "Initial implementation of variadic lax.reduce() (#3342)" (#3384)
This reverts commit 99401c5a844cc19c6ce66cc26997f999c9ecf6d8.
2020-06-09 16:09:50 -04:00
Jake Vanderplas
99401c5a84
Initial implementation of variadic lax.reduce() (#3342) 2020-06-09 09:22:29 -07:00
Matthew Johnson
866c17c32e fix a couple ad_util.Zero type checks 2020-06-08 13:22:13 -07:00
Adam Paszke
3f1d3a73ac Remove example from ad.instantiate_zeros, fix vmap bug 2020-06-05 15:52:01 +00:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Jake Vanderplas
b187663a87
deflake jax/lax & add to flake8 check (#3310) 2020-06-04 13:50:44 -07:00
Roy Frostig
6015a2a689 introduce lax.switch 2020-06-03 22:19:15 -07:00
Skye Wanderman-Milne
5ad9feda5f
Fix handling of infeed token inside sharded_jit (#3313) 2020-06-03 15:23:49 -07:00
Julius Kunze
d1dbf7c7d8
Implement mask for some primitives + jit. (#2922)
* Implement mask for slice, conv, pad, transpose, where

* Remove tentative mask(jit)

* Add explanatory comment to dot_general masking rule

* Rm reshape from select masking rule

* Rm unnecessary check from lax slice abstract_eval rule

* Revert to standard indentation in masking_test.py

* Begin simplifying masking tests

* Finish drafting masking check function

* More progress simplifying tests

* Add conv masking in batch dim

* Finish fixing up tests

* Revert to old API, making out_shape compulsory again

* More efficient conv masking rule

* Tidy up masking_test imports

* Check that out tree is preserved by masking

* fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-03 13:40:48 -07:00
Jake Vanderplas
0db57cb541
Fix validation code in lax.conv (#3279) 2020-06-03 10:33:19 -07:00
Peter Hawkins
dd81a8dded
Fix some type errors in lax.py found by pytype. (#3292) 2020-06-02 10:27:14 -04:00
Jamie Townsend
3909875f9d
Improve speed of tracing dynamic_update_slice (#3247)
* Improve tracing performance of _dynamic_slice_indices

* More precisely preserve semantics of dynamic_slice_indices

* Use safe_map in dynamic_slice_indices
2020-06-02 09:37:32 -04:00