342 Commits

Author SHA1 Message Date
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and can also reduce memory fragmentation. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 for more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
30980742c5
refine population_count type check (#3887)
* refine population_count type check

fixes #3886

* allow signed/unsigned ints for population_count

https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/shape_inference.cc;l=314?q=xla%20f:shape_inference.cc

* make lax_reference.population_count handle signed
2020-07-28 19:46:00 -07:00
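The type-check refinement above lets population_count accept both signed and unsigned integers. A minimal pure-Python sketch (not JAX's implementation; the `bits` parameter is illustrative) of the semantics: a signed value is reinterpreted as its two's-complement bit pattern within the given width before counting set bits.

```python
def population_count(x: int, bits: int = 32) -> int:
    """Count set bits in the `bits`-wide two's-complement pattern of x.

    Masking with (1 << bits) - 1 reinterprets negative values as their
    unsigned bit pattern, so e.g. -1 in 8 bits has all 8 bits set.
    """
    mask = (1 << bits) - 1
    return bin(x & mask).count("1")
```

For example, `population_count(5)` gives 2 and `population_count(-1, bits=8)` gives 8.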
Jamie Townsend
e28db33b01
Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883 (#3888)
* Add test for issue 3883

* Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883
2020-07-28 18:39:32 -07:00
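For context on the dynamic_slice semantics being batched above: XLA clamps each start index so the slice always stays in bounds. A 1-D pure-Python sketch of that behavior (illustrative only, not JAX's implementation):

```python
def dynamic_slice(operand, start_indices, slice_sizes):
    """1-D dynamic slice with XLA-style clamping of the start index.

    The start index is clamped to [0, len(operand) - size], so the
    returned slice always has exactly `size` elements.
    """
    (start,) = start_indices
    (size,) = slice_sizes
    start = max(0, min(start, len(operand) - size))
    return operand[start:start + size]
```

So a start index that would run past the end, e.g. start 4 with size 3 on a length-5 operand, is clamped back to start 2.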
Peter Hawkins
7da1cba66a
Remove fallback to 2-pass algorithm for argmin/argmax on TPU. (#3831)
(The compiler problem that prompted the workaround seems to be fixed.)
2020-07-23 11:35:29 -04:00
James Bradbury
f574b11499
Support preconditions on scatter indices (#3147)
* wire through precondition flags to XLA scatter

* use scatter precondition flags in tests

* fix DUS batching rule

* make new arguments kw-only

* onp -> np

* fix jax2tf for new args

* fix more test failures
2020-07-21 23:16:27 -07:00
Peter Hawkins
a6e2d20b31
Add support for base dilation and window dilation to reduce window op… (#3803) 2020-07-20 17:27:24 -04:00
Stephan Hoyer
fe99a06ddf
Error message and docstring updates RE: dynamic_slice (#3795)
This should clarify the underlying issues from #1007 and #3794.

It might be worth mentioning masking, but that's a little big for fitting into
an error message. Maybe once the masking transformation is non-experimental or
if we had a dedicated doc page.
2020-07-20 09:08:54 -04:00
Peter Hawkins
e2e73a854a Relax dimension ordering rules for dot_general.
JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions.

Since batch dimensions are now allowed in arbitrary positions, it's not hard to
generalize the dot_general batching rule to avoid performing any transposes
(#2972).

In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`.
2020-07-16 19:36:22 -04:00
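To illustrate the relaxation above: batch dimensions may now sit anywhere in each operand. One way to handle that (a pure-Python sketch on 3-D nested lists, not JAX's implementation — the commit notes the batching rule can even avoid transposes) is to canonicalize by moving each operand's batch axis to the front, then take a per-batch matrix product:

```python
def transpose3(x, perm):
    """Permute axes of a 3-D nested list: result axis n is x's axis perm[n]."""
    shape = [len(x), len(x[0]), len(x[0][0])]
    out_shape = [shape[p] for p in perm]
    out = [[[None] * out_shape[2] for _ in range(out_shape[1])]
           for _ in range(out_shape[0])]
    for i in range(out_shape[0]):
        for j in range(out_shape[1]):
            for k in range(out_shape[2]):
                src = [0, 0, 0]
                src[perm[0]], src[perm[1]], src[perm[2]] = i, j, k
                out[i][j][k] = x[src[0]][src[1]][src[2]]
    return out

def batched_matmul(lhs, rhs, lhs_batch_dim=0, rhs_batch_dim=0):
    """Batch-of-matrix product where each operand's batch axis may sit
    in an arbitrary position; the axis is moved to the front first."""
    lhs = transpose3(lhs, [lhs_batch_dim] + [d for d in range(3) if d != lhs_batch_dim])
    rhs = transpose3(rhs, [rhs_batch_dim] + [d for d in range(3) if d != rhs_batch_dim])
    out = []
    for lm, rm in zip(lhs, rhs):
        m, kk, n = len(lm), len(lm[0]), len(rm[0])
        out.append([[sum(lm[i][t] * rm[t][j] for t in range(kk))
                     for j in range(n)] for i in range(m)])
    return out
```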
Jake Vanderplas
a7c2cdea64
Cleanup: convert uses of import numpy as onp in library code (#3754) 2020-07-14 13:05:31 -07:00
James Bradbury
f78ccf1c39
Fix tuple() in reduce_window padding (#3748)
* Fix tuple() in reduce_window padding

* Update lax.py
2020-07-13 18:16:11 -07:00
James Bradbury
6017205cea
Add defensive tuple() in lax.reduce_window (#3741) 2020-07-13 14:37:46 -07:00
Peter Hawkins
3c6cd5fb94
Implement complex convolutions on CPU and GPU. (#3735)
Lowers using Gauss's complex multiplication algorithm (which is also what the XLA:TPU implementation does internally).
2020-07-13 14:44:24 -04:00
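The scalar identity behind Gauss's trick, sketched in plain Python: a complex product costs three real multiplications instead of the naive four. In the convolution lowering, each real multiply correspondingly becomes a real-valued convolution.

```python
def gauss_complex_mul(a, b, c, d):
    """Multiply (a + bi) * (c + di) using Gauss's three-multiplication
    trick. Returns (real, imag).

    k1 - k3 = c(a+b) - b(c+d) = ac - bd   (real part)
    k1 + k2 = c(a+b) + a(d-c) = ad + bc   (imaginary part)
    """
    k1 = c * (a + b)
    k2 = a * (d - c)
    k3 = b * (c + d)
    return k1 - k3, k1 + k2
```

For example, (1 + 2i)(3 + 4i) = -5 + 10i.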
Peter Hawkins
71253ac4c1
Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
* Generalize reduce-window padding to support (lo, hi) pairs, as XLA does.

This turns out to simplify the code slightly, too.

* Fix select_and_gather_add batching rule and test.

* Fix documentation text to refer to ReduceWindowWithGeneralPadding.
2020-07-13 09:49:52 -04:00
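A 1-D pure-Python sketch of the generalized padding (illustrative only, not JAX's implementation): each edge gets its own pad amount, the pad value is the reduction's identity, and each window is then reduced independently.

```python
def reduce_window_1d(xs, window, init, op, padding):
    """1-D reduce-window with asymmetric (lo, hi) padding.

    Pads with the reduction's identity `init`, then applies the binary
    reducer `op` over every contiguous window of length `window`.
    """
    lo, hi = padding
    padded = [init] * lo + list(xs) + [init] * hi
    out = []
    for start in range(len(padded) - window + 1):
        acc = init
        for v in padded[start:start + window]:
            acc = op(acc, v)
        out.append(acc)
    return out
```

A max-reduce with padding (1, 0) pads only on the left; a sum-reduce with (1, 1) pads both edges with 0.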
Peter Hawkins
a9da06ce75
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729) 2020-07-13 09:43:19 -04:00
Matthew Johnson
51ca57d5fc
check matmul inputs aren't scalar (#3725)
also make the dot_general shape rule check that dimension numbers are in range

fixes #3718
2020-07-11 20:47:22 -07:00
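A sketch of the kind of validation the shape rule now performs (function name and signature are illustrative, not JAX's API): reject scalar operands up front, and verify every contracting or batch dimension number is in range for its operand's rank.

```python
def check_dot_dims(lhs_rank, rhs_rank, contracting, batch):
    """Validate dot inputs: no scalar operands, and all dimension
    numbers (pairs of lhs/rhs axes) in range for each operand's rank."""
    if lhs_rank == 0 or rhs_rank == 0:
        raise TypeError("dot requires operands of at least rank 1, got a scalar")
    for ld, rd in list(contracting) + list(batch):
        if not (0 <= ld < lhs_rank and 0 <= rd < rhs_rank):
            raise TypeError(f"dimension numbers {(ld, rd)} out of range")
```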
Jake Vanderplas
60d852773e
lexicographic sort_p: accept num_keys rather than comparator (#3715) 2020-07-10 09:58:35 -07:00
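The num_keys semantics, sketched in plain Python (illustrative, not JAX's implementation): the first num_keys operands are the sort keys, compared lexicographically in order, and the remaining operands are co-sorted along with them.

```python
def lex_sort(operands, num_keys=1):
    """Stable lexicographic sort: the first `num_keys` operands are the
    keys; all operands are reordered together by those keys."""
    rows = list(zip(*operands))                 # one tuple per element
    rows.sort(key=lambda row: row[:num_keys])   # stable, key-prefix compare
    return [list(col) for col in zip(*rows)]
```

With two keys, ties in the first key are broken by the second; the payload operands follow along.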
Jake Vanderplas
d2f9c46a0c
Remove some non-inclusive language (#3710) 2020-07-10 09:29:06 -07:00
Jake Vanderplas
804e449389
Generalize lax.sort to support lexicographic sorts. (#3709) 2020-07-09 20:05:19 -07:00
Roman Novak
4442c333ce
Add support for 0d transpose convolution (#3643)
* Allow 0d transpose convolution

* Add test for 0d conv transpose

* remove whitespace
2020-07-02 14:38:35 -07:00
Matthew Johnson
65c4d755de
fix bug in categorical test, disable #3611 on tpu (#3633)
* fix bug in categorical test, disable #3611 on tpu

Disabling #3611 on TPU pending a TPU compilation bug.

* unskip a test
2020-07-01 14:15:48 -07:00
Peter Hawkins
141fabbbf5
Reimplement argmin/argmax using a single pass variadic reduction. (#3611) 2020-07-01 11:01:22 -04:00
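The single-pass idea, sketched in plain Python (not the XLA lowering): reduce over (value, index) pairs with a combiner that keeps the larger value, instead of one pass to find the max and a second to find its index. Ties resolve to the smaller index.

```python
from functools import reduce

def argmax_single_pass(xs):
    """argmax as a single reduction over (value, index) pairs."""
    def combine(a, b):
        (av, ai), (bv, bi) = a, b
        # prefer the larger value; on ties, the smaller index
        if bv > av or (bv == av and bi < ai):
            return b
        return a
    return reduce(combine, zip(xs, range(len(xs))))[1]
```

argmin is the same reduction with the comparison flipped.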
Matthew Johnson
eb2a227588
fix reduction repeated axis error (#3618)
* fix reduction repeated axis error

* deflake
2020-06-30 21:18:46 -07:00
Jake Vanderplas
db8f66d508
Rework type support for lax cumulative reductions (#3609) 2020-06-30 11:36:27 -07:00
Peter Hawkins
420ef4e0a8
Fix shape rule for lax.pad for input dimensions of size 0. (#3608) 2020-06-30 12:07:38 -04:00
Erich Elsen
aa6585f995 bool -> bool_ for reasons that make no sense (bool used to be any?!) 2020-06-29 19:20:19 +01:00
Erich Elsen
b46bd2301c add support for bool identity values 2020-06-29 19:13:41 +01:00
Erich Elsen
77a023df48 change ending tick mark style 2020-06-29 18:13:36 +01:00
Erich Elsen
491fcbb202 floating point identity to inf 2020-06-29 00:50:14 +01:00
Erich Elsen
b8d0de6365 remove trailing whitespace 2020-06-28 21:33:42 +01:00
Erich Elsen
290d608e9d remove now unneeded type def 2020-06-28 20:41:48 +01:00
Erich Elsen
1f15ffc45f consolidate jvp rule definitions 2020-06-28 20:39:20 +01:00
Erich Elsen
a98249d766 actually return the primitive 2020-06-28 20:31:30 +01:00
Erich Elsen
a189737ecb add generic reducer primitive generator and replace prod/max/min with it. 2020-06-28 20:28:31 +01:00
Erich Elsen
d3f6d85da5 remove unit and determine automatically for all ops 2020-06-28 20:21:35 +01:00
Erich Elsen
4fe9c1d624 fix other branch 2020-06-28 20:14:14 +01:00
Erich Elsen
1e33e5346e account for different names of reducer in tpu function 2020-06-28 20:10:27 +01:00
Erich Elsen
294d6f893f Also update custom tpu rule to set unit correctly based on dtype 2020-06-28 20:06:43 +01:00
Erich Elsen
a54a38f691 Add default value of None for unit in TPU impl of cummax/cummin 2020-06-28 19:53:47 +01:00
Erich Elsen
e2fa89dbec onp.finfo -> jnp.finfo for bfloat16 2020-06-28 19:49:36 +01:00
Erich Elsen
ae9e6851cc use correct iinfo finfo names 2020-06-28 19:44:36 +01:00
Erich Elsen
812d246295 don't require passing the identity value. It isn't an initial value; the identity is required for implementation correctness 2020-06-28 19:33:20 +01:00
Erich Elsen
95e15b64e3 fix typo 2020-06-28 18:37:50 +01:00
Erich Elsen
bf06633a87 add tests 2020-06-28 18:21:09 +01:00
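The run of commits above derives each reducer's identity element automatically from the op and dtype instead of requiring callers to pass it. A pure-Python sketch of that mapping (the string-based op/dtype names and the 32-bit integer bounds are illustrative, not JAX's API):

```python
import math

def reduction_identity(op, dtype):
    """Pick the identity element for a reducer from its op and dtype.

    `op` is one of 'add', 'mul', 'max', 'min'; `dtype` is one of
    'float', 'int', 'bool'. Integer bounds assume 32 bits for the sketch;
    floating-point min/max use +/-inf as their identities.
    """
    if op == 'add':
        return {'float': 0.0, 'int': 0, 'bool': False}[dtype]
    if op == 'mul':
        return {'float': 1.0, 'int': 1, 'bool': True}[dtype]
    if op == 'max':
        return {'float': -math.inf, 'int': -(2 ** 31), 'bool': False}[dtype]
    if op == 'min':
        return {'float': math.inf, 'int': 2 ** 31 - 1, 'bool': True}[dtype]
    raise ValueError(f"no identity known for op {op!r}")
```

For booleans, max behaves as logical-or (identity False) and min as logical-and (identity True), which is why bool needed its own identity values.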
Roy Frostig
ccb640afdb lax.sort: stable by default 2020-06-26 20:37:23 -07:00
Matthew Johnson
11caa21eca
ensure lax.reduce monoid test uses original numpy (#3573) 2020-06-26 11:44:16 -07:00
Norman Casagrande
99a43f20db
Added missing is_stable argument to lax.sort (#3553) 2020-06-26 10:40:00 -07:00
Jamie Townsend
c9670d50c5
Fix lazy broadcast issue (#3536) 2020-06-25 07:50:11 -07:00
Jake Vanderplas
d5a5d301f2
lax.sort: allow any sequence of Arrays, not just tuples (#3367) 2020-06-23 08:28:04 -07:00
Srinivas Vasudevan
927c209148
Add random_gamma_grad and use in jax.random.gamma (#3281) 2020-06-19 09:34:18 -04:00
Jacob Kelly
575216e094
add jet primitives, refactor tests (#3468)
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-06-16 19:48:25 -07:00