81 Commits

Author SHA1 Message Date
Matthew Johnson
93841df822 fix lax.imag jvp and enable test, fixes #979 2019-07-05 14:32:04 -07:00
Peter Hawkins
a06ba06f97 Update comments. 2019-07-02 13:23:05 -04:00
Peter Hawkins
165df6204b Simplify reduce-precision logic.
Enable TPU gradient tests only up to order 1. The first-order JVP of reduce-window exercises select_and_scatter_add, which is the part changed by this PR.
2019-07-02 11:34:49 -04:00
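Illustration (not from this PR): differentiating a reduce-window max pooling is the kind of program this affects, since its first-order gradient lowers to select_and_scatter_add. The pooling function below is a minimal sketch.

```python
import jax
import jax.numpy as jnp
from jax import lax

def max_pool_1d(x):
    # Windowed max via reduce_window; the VJP of this op lowers to
    # select_and_scatter_add, the primitive exercised by the tests above.
    return lax.reduce_window(x, -jnp.inf, lax.max,
                             window_dimensions=(2,),
                             window_strides=(2,),
                             padding='VALID')

grads = jax.grad(lambda x: max_pool_1d(x).sum())(jnp.arange(8.0))
```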
Peter Hawkins
40560d2c9a Refactor select_and_gather_add implementation to improve readability.
Change implementation to use ReducePrecision to perform half-word reductions.
2019-07-01 22:26:36 -04:00
Peter Hawkins
db369091a2 Add support for higher derivatives of reduce-window-min/max at reduced precision. On CPU/GPU this means support for float64 derivatives, and on TPU this means support for float32 derivatives.
Warn if we are forced to be imprecise.
2019-06-28 20:27:10 -04:00
Peter Hawkins
3e914e17b0 Improve documentation for precision. 2019-06-28 14:06:24 -04:00
Peter Hawkins
bca27fea8b Simplify precision specification: only allow a single precision for an entire operator. 2019-06-28 12:48:44 -04:00
Peter Hawkins
0af9da7662 Add precision option to lax dot and conv APIs.
Set a default precision of "highest" in LU decomposition.
Enable a number of dot and conv tests on TPU under highest precision.
Enable linalg tests that use LU decomposition on TPU.
2019-06-28 10:00:39 -04:00
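A usage sketch of the new option with the lax API (the exact enum spelling may have differed slightly at the time):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), jnp.float32)
y = jnp.ones((128, 128), jnp.float32)

# On TPU the default matmul uses fast, reduced-precision accumulation;
# Precision.HIGHEST requests full float32 multiply-accumulate instead.
z = lax.dot(x, y, precision=lax.Precision.HIGHEST)
```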
Peter Hawkins
3b4521b1f6 Enable convolutions for non float32 types. 2019-06-27 17:17:49 -04:00
Peter Hawkins
990c2df123 Implement a pure Python LU decomposition that can be used on platforms where we do not otherwise have a better implementation.
Restructure xla.lower_fun and trace_unwrapped_to_jaxpr so the instantiate option can be passed to them, separately from any function arguments.
2019-06-27 14:50:29 -04:00
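For illustration, a rough sketch of an unblocked LU factorization in plain NumPy; the actual implementation differs (it uses partial pivoting and is written to be traceable):

```python
import numpy as np

def lu_factor_nopivot(a):
    # Unblocked Doolittle LU without pivoting: L (unit lower triangular)
    # and U are returned packed into a single matrix. Illustrative only.
    a = np.array(a, dtype=np.float64)
    n = a.shape[0]
    for k in range(n - 1):
        a[k+1:, k] /= a[k, k]                               # column of L
        a[k+1:, k+1:] -= np.outer(a[k+1:, k], a[k, k+1:])   # Schur update
    return a
```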
Peter Hawkins
c7afa1eb34
Merge pull request #930 from hawkinsp/master
Merge reduce-window batching rules. Add rule for reduce_window_min.
2019-06-26 13:40:12 -04:00
Peter Hawkins
755d2818e8 Merge reduce-window batching rules. Add batching rule for reduce_window_min. 2019-06-26 10:19:42 -04:00
Matthew Johnson
a663486148 fixes from rebase onto master 2019-06-24 19:45:18 -07:00
Matthew Johnson
40bbf068d8 fix broadcasting papply rule, move to lax_parallel 2019-06-24 19:39:12 -07:00
Matthew Johnson
d64188bcb6 del serial_pmap, simpler papply, add parallelize
The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.

The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. It is also
now only needed for testing, since parallelize (which essentially
composes a soft_pmap with a papply) is likely to be the primary
user-facing API.

This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.

Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
  doesn't need to perform communication)
* fixed miscellaneous bugs encountered along the way
* a few lines cherry-picked from frostig@ branch, namely the fixed
  broadcasting_papply rule and plumbing the `size` argument to papply
  rules
* removed the psplit and psplit_like primitives, replacing them with
  calls to all_to_all where needed
2019-06-24 19:38:26 -07:00
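A usage sketch, assuming soft_pmap is exported at the top level as it was at the time (both soft_pmap and parallelize were experimental):

```python
import jax
import jax.numpy as jnp

# Like pmap, soft_pmap maps a function over a leading axis, but it chunks
# the axis across however many devices are available instead of requiring
# one device per element.
xs = jnp.arange(8.0)
ys = jax.soft_pmap(lambda x: x * 2.0)(xs)
```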
Roy Frostig
33b01733a9
Merge pull request #916 from google/parallelize
parallelization work-in-progress
2019-06-24 16:08:14 -07:00
Roy Frostig
15bc966567 Merge branch 'master' into parallelize 2019-06-24 11:32:59 -07:00
Peter Hawkins
c1bec691c5 Avoid instantiating zeros in dynamic_slice/gather transpose rules. 2019-06-24 13:44:49 -04:00
Justin Lebar
d5ba04b79e Add jax.ops.index_min/max.
These are analogous to index_add.
2019-06-21 19:33:34 -07:00
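A usage sketch with the jax.ops indexing API of the time (since superseded by `x.at[...].min` / `.max`):

```python
import jax.numpy as jnp
from jax.ops import index, index_min

x = jnp.array([3.0, 1.0, 4.0])
# Like index_add, but combines with min: out[0] = min(x[0], 2.0).
y = index_min(x, index[0], 2.0)   # -> [2., 1., 4.]
```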
John Schulman
8b163628fd map stop_gradient over data structures; otherwise it is silently a no-op 2019-06-20 16:23:13 -07:00
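In other words, stop_gradient now recurses into containers; a small sketch:

```python
import jax.numpy as jnp
from jax import lax

params = {'w': jnp.ones(3), 'b': jnp.zeros(3)}
# Previously a dict input would pass through untouched; now stop_gradient
# is mapped over the leaves of the structure.
frozen = lax.stop_gradient(params)
```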
Matthew Johnson
8bc4e379f5 make DeviceArray.__hash__ raise an error
Fixes #883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.

Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
2019-06-19 10:12:13 -07:00
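The visible effect, sketched:

```python
import jax.numpy as jnp

x = jnp.zeros(3)
try:
    hash(x)            # DeviceArray.__hash__ now raises
except TypeError:
    print("DeviceArray is unhashable, like numpy arrays")
```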
Peter Hawkins
04676e4e95 Use _dtype instead of .dtype in dynamic slice transpose rule. 2019-06-18 10:25:10 -04:00
Peter Hawkins
c8b946e1fd Use lax.full instead of broadcast in transpose rule for dynamic_slice. 2019-06-18 10:18:33 -04:00
Peter Hawkins
b4acfe0640
Merge pull request #868 from hawkinsp/master
Implement batching for np.linalg.solve
2019-06-18 05:58:10 -06:00
Matthew Johnson
6dd5423666 manually fuse a transpose into a reshape 2019-06-17 19:39:14 -07:00
Matthew Johnson
3ce87f7874
Merge pull request #859 from jheek/grouped-convs
Grouped convs
2019-06-17 19:25:19 -07:00
Matthew Johnson
a6e374ddb9
Merge pull request #858 from google/improve-conv-batching
improve efficiency of conv batching rules (i.e. vmap rules)
2019-06-17 19:23:07 -07:00
Peter Hawkins
991a5a9f4c Use _dtype instead of .dtype in dynamic slice rule. 2019-06-17 21:39:34 -04:00
Peter Hawkins
4e872b96b9 Fix type mismatch with int32-type indices under a jit with 64-bit types enabled. 2019-06-17 21:18:27 -04:00
Peter Hawkins
0190684ee2
Merge pull request #866 from hawkinsp/master
Implement np.ix_, for non-bool inputs.
2019-06-17 18:39:45 -06:00
Peter Hawkins
ec685bf8ae Implement np.ix_, for non-bool inputs. 2019-06-17 17:08:27 -04:00
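A usage sketch:

```python
import jax.numpy as jnp

x = jnp.arange(16).reshape(4, 4)
rows, cols = jnp.array([0, 2]), jnp.array([1, 3])
# ix_ builds an open mesh: this selects the 2x2 submatrix at the
# intersections of rows {0, 2} and columns {1, 3}.
sub = x[jnp.ix_(rows, cols)]
```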
Jonathan Heek
e3462fd8b1 write out batch_feature_groups to simplify and correct implementation 2019-06-17 21:32:03 +02:00
Matthew Johnson
a1219f10e5
Merge pull request #865 from google/fix-reshape-grad-bug
fix special-form reshape transpose bug (and add tests)
2019-06-17 12:29:45 -07:00
Matthew Johnson
fef68deef6 fix reshape transpose bug (and add tests)
This version of reshape (taking a `dimensions` argument, which
effectively fuses in a transpose) seems only to be used in the JVP rule
for lax._reduce_prod (basically np.product), but its transpose rule was
totally busted and untested.
2019-06-17 11:54:36 -07:00
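For reference, a sketch of the reshape variant in question:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)
# `dimensions` permutes the axes before reshaping, fusing in a transpose:
y = lax.reshape(x, (3, 2), dimensions=(1, 0))   # same as x.T.reshape(3, 2)
```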
Roy Frostig
a113d9f5c9 Merge branch 'master' into parallelize 2019-06-17 09:09:39 -07:00
Peter Hawkins
129e73a258 Use lax.full to create zeros array in gather transpose rule. 2019-06-17 08:19:36 -04:00
Jonathan Heek
077d56529f grouped convolution support 2019-06-17 12:37:19 +02:00
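A usage sketch (shapes are illustrative; with the default NCHW/OIHW layout the kernel's input-channel dimension is the input channel count divided by feature_group_count):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 8, 10, 10))   # NCHW: 8 input channels
w = jnp.ones((16, 4, 3, 3))    # OIHW: 8 / feature_group_count = 4
# feature_group_count=2 splits the channels into two groups that are
# convolved independently and concatenated on the output channel dim.
y = lax.conv_general_dilated(x, w, window_strides=(1, 1), padding='SAME',
                             feature_group_count=2)
```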
Matthew Johnson
ff29d582e8 make all conv vmap rules generate a single call

also plumb feature_group_count and batch_group_count everywhere
2019-06-16 07:50:19 -07:00
Matthew Johnson
1262ca9b30 improve conv rhs batching, add systematic test 2019-06-15 12:01:20 -07:00
James Bradbury
edfe52035b Fix dot batch rule bug and add test + check 2019-06-12 18:02:01 -07:00
Roy Frostig
49672f79b8 parallelization rule for lax.gather
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-06-06 18:51:25 -07:00
Peter Hawkins
bd389b7fcf Add bug number for integer dots. 2019-06-06 19:04:11 -04:00
Peter Hawkins
15361c4dde Add support for integer dot operations.
Lower to a sum of products for integers since XLA currently lacks support for integer dots.
2019-06-06 17:24:43 -04:00
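The lowering idea, sketched in NumPy (illustrative, not the actual rule):

```python
import numpy as np

def int_dot(x, y):
    # out[i, j] = sum_k x[i, k] * y[k, j], written as a broadcast multiply
    # plus a reduction, sidestepping XLA's missing integer dot support.
    return (x[:, :, None] * y[None, :, :]).sum(axis=1)

a = np.arange(6, dtype=np.int32).reshape(2, 3)
b = np.arange(12, dtype=np.int32).reshape(3, 4)
assert (int_dot(a, b) == a @ b).all()
```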
Matthew Johnson
1feefd10ac
Merge pull request #820 from google/simplify-gather-shape-rule
simplify gather shape rule
2019-06-05 19:14:27 -07:00
Matthew Johnson
9e49500c54 simplify gather shape rule
Co-authored-by: Roy Frostig <frostig@google.com>
2019-06-05 17:06:42 -07:00
James Bradbury
9bc5f2aecf fix bug in dot batching rule 2019-06-05 15:17:06 -07:00
Roy Frostig
6cec008a71 Merge branch 'master' into parallelize 2019-06-04 15:31:27 -07:00
Peter Hawkins
e63bd4d592 Add domain test to atanh implementation.
Disable some tests on TPU.
2019-05-31 17:22:19 -04:00
Roy Frostig
771564024d Merge branch 'master' into parallelize 2019-05-31 11:41:31 -07:00
Roy Frostig
800bdca858 sketch of parallelization rule for lax.conv_general_dilated 2019-05-31 11:40:51 -07:00