Matthew Johnson
93841df822
fix lax.imag jvp and enable test, fixes #979
2019-07-05 14:32:04 -07:00
Peter Hawkins
a06ba06f97
Update comments.
2019-07-02 13:23:05 -04:00
Peter Hawkins
165df6204b
Simplify reduce-precision logic.
...
Enable TPU gradient tests only up to order 1. The first-order JVP of reduce-window tests select_and_scatter_add, which is the part changed by this PR.
2019-07-02 11:34:49 -04:00
Peter Hawkins
40560d2c9a
Refactor select_and_gather_add implementation to improve readability.
...
Change implementation to use ReducePrecision to perform half-word reductions.
2019-07-01 22:26:36 -04:00
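The half-word packing trick behind select_and_gather_add can be illustrated in plain Python. This is a sketch under simplifying assumptions (full-width 64-bit words are available, and operands are non-negative floats, whose IEEE bit patterns are monotone in value); all names below are hypothetical, not the actual implementation:

```python
import struct

def _bits(x):
    # Reinterpret a float32's bytes as an unsigned 32-bit integer.
    return struct.unpack('<I', struct.pack('<f', x))[0]

def _float(bits):
    # Inverse of _bits: reinterpret a 32-bit integer as a float32.
    return struct.unpack('<f', struct.pack('<I', bits))[0]

def pack(operand, tangent):
    # Operand bits go in the high half and tangent bits in the low half,
    # so comparing packed 64-bit words compares operands first.
    return (_bits(operand) << 32) | _bits(tangent)

def select_and_gather_add_window(operands, tangents):
    # One window of the max-pool gradient: return the tangent paired with
    # the maximal operand, using only a max over packed words.
    best = max(pack(o, t) for o, t in zip(operands, tangents))
    return _float(best & 0xFFFFFFFF)
```

On backends where a double-width word is unavailable, the approach described in these commits instead uses ReducePrecision to shrink each value so the pair fits in a single word, at some cost in precision.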
Peter Hawkins
db369091a2
Add support for higher derivatives of reduce-window-min/max at reduced precision. On CPU/GPU this means support for float64 derivatives, and on TPU this means support for float32 derivatives.
...
Warn if we are forced to be imprecise.
2019-06-28 20:27:10 -04:00
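For reference, the reduce-window-max semantics whose higher derivatives this commit addresses amount to a sliding-window reduction; a minimal stride-1, 1-D sketch in plain Python (hypothetical helper name):

```python
def reduce_window_max(xs, window):
    # Stride-1 sliding-window max: one output per full window position.
    return [max(xs[i:i + window]) for i in range(len(xs) - window + 1)]
```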
Peter Hawkins
3e914e17b0
Improve documentation for precision.
2019-06-28 14:06:24 -04:00
Peter Hawkins
bca27fea8b
Simplify precision specification: only allow a single precision for an entire operator.
2019-06-28 12:48:44 -04:00
Peter Hawkins
0af9da7662
Add precision option to lax dot and conv APIs.
...
Set a default precision of "highest" in LU decomposition.
Enable a number of dot and conv tests on TPU under highest precision.
Enable linalg tests that use LU decomposition on TPU.
2019-06-28 10:00:39 -04:00
Peter Hawkins
3b4521b1f6
Enable convolutions for non float32 types.
2019-06-27 17:17:49 -04:00
Peter Hawkins
990c2df123
Implement a pure Python LU decomposition that can be used on platforms where we do not otherwise have a better implementation.
...
Restructure xla.lower_fun and trace_unwrapped_to_jaxpr so the instantiate option can be passed to them, separately from any function arguments.
2019-06-27 14:50:29 -04:00
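A pure-Python LU decomposition with partial pivoting, in the spirit of the fallback this commit describes (a sketch on nested lists, not the actual implementation):

```python
def lu(a):
    # Doolittle LU with partial pivoting. Returns the row permutation p
    # and the factors packed into one matrix: the unit lower-triangular L
    # below the diagonal, U on and above it.
    n = len(a)
    a = [row[:] for row in a]
    p = list(range(n))
    for k in range(n):
        # Pivot on the largest magnitude entry in column k.
        piv = max(range(k, n), key=lambda i: abs(a[i][k]))
        a[k], a[piv] = a[piv], a[k]
        p[k], p[piv] = p[piv], p[k]
        # Eliminate below the pivot, storing the multipliers in place.
        for i in range(k + 1, n):
            a[i][k] /= a[k][k]
            for j in range(k + 1, n):
                a[i][j] -= a[i][k] * a[k][j]
    return p, a
```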
Peter Hawkins
c7afa1eb34
Merge pull request #930 from hawkinsp/master
...
Merge reduce-window batching rules. Add rule for reduce_window_min.
2019-06-26 13:40:12 -04:00
Peter Hawkins
755d2818e8
Merge reduce-window batching rules. Add batching rule for reduce_window_min.
2019-06-26 10:19:42 -04:00
Matthew Johnson
a663486148
fixes from rebase onto master
2019-06-24 19:45:18 -07:00
Matthew Johnson
40bbf068d8
fix broadcasting papply rule, move to lax_parallel
2019-06-24 19:39:12 -07:00
Matthew Johnson
d64188bcb6
del serial_pmap, simpler papply, add parallelize
...
The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.
The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. It is now
needed only for testing, since parallelize (which essentially composes a
soft_pmap with a papply) is likely to be the primary user-facing API.
This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.
Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
doesn't need to perform communication)
* fixed misc bugs encountered along the way
* a few lines cherry-picked from frostig@ branch, namely the fixed
broadcasting_papply rule and plumbing the `size` argument to papply
rules
* removed the psplit and psplit_like primitives, replacing them with
calls to all_to_all where needed
2019-06-24 19:38:26 -07:00
Roy Frostig
33b01733a9
Merge pull request #916 from google/parallelize
...
parallelization work-in-progress
2019-06-24 16:08:14 -07:00
Roy Frostig
15bc966567
Merge branch 'master' into parallelize
2019-06-24 11:32:59 -07:00
Peter Hawkins
c1bec691c5
Avoid instantiating zeros in dynamic_slice/gather transpose rules.
2019-06-24 13:44:49 -04:00
Justin Lebar
d5ba04b79e
Add jax.ops.index_min/max.
...
These are analogous to index_add.
2019-06-21 19:33:34 -07:00
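A functional sketch of the scatter-min semantics these ops provide, for a 1-D operand (hypothetical helper; the real ops build on XLA scatter, and index_max is analogous with max):

```python
def index_min(x, idx, y):
    # Out-of-place update: each indexed slot keeps the minimum of its
    # current value and the corresponding update value.
    out = list(x)
    for i, v in zip(idx, y):
        out[i] = min(out[i], v)
    return out
```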
John Schulman
8b163628fd
map stop_gradient over data structures; otherwise it is silently a no-op
2019-06-20 16:23:13 -07:00
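The fix amounts to applying the function to the leaves of a nested container rather than to the container itself; a minimal sketch of such a tree map (hypothetical, much simplified relative to jax.tree_util):

```python
def tree_map(f, tree):
    # Recurse into dicts, lists, and tuples; apply f only at the leaves.
    if isinstance(tree, dict):
        return {k: tree_map(f, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(f, v) for v in tree)
    return f(tree)
```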
Matthew Johnson
8bc4e379f5
make DeviceArray.__hash__ raise an error
...
Fixes #883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.
Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
2019-06-19 10:12:13 -07:00
Peter Hawkins
04676e4e95
Use _dtype instead of .dtype in dynamic slice transpose rule.
2019-06-18 10:25:10 -04:00
Peter Hawkins
c8b946e1fd
Use lax.full instead of broadcast in transpose rule for dynamic_slice.
2019-06-18 10:18:33 -04:00
Peter Hawkins
b4acfe0640
Merge pull request #868 from hawkinsp/master
...
Implement batching for np.linalg.solve
2019-06-18 05:58:10 -06:00
Matthew Johnson
6dd5423666
manually fuse a transpose into a reshape
2019-06-17 19:39:14 -07:00
Matthew Johnson
3ce87f7874
Merge pull request #859 from jheek/grouped-convs
...
Grouped convs
2019-06-17 19:25:19 -07:00
Matthew Johnson
a6e374ddb9
Merge pull request #858 from google/improve-conv-batching
...
improve efficiency of conv batching rules (i.e. vmap rules)
2019-06-17 19:23:07 -07:00
Peter Hawkins
991a5a9f4c
Use _dtype instead of .dtype in dynamic slice rule.
2019-06-17 21:39:34 -04:00
Peter Hawkins
4e872b96b9
Fix type mismatch with int32-type indices under a jit with 64-bit types enabled.
2019-06-17 21:18:27 -04:00
Peter Hawkins
0190684ee2
Merge pull request #866 from hawkinsp/master
...
Implement np.ix_, for non-bool inputs.
2019-06-17 18:39:45 -06:00
Peter Hawkins
ec685bf8ae
Implement np.ix_, for non-bool inputs.
2019-06-17 17:08:27 -04:00
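np.ix_ builds an "open mesh" from integer index sequences, so that indexing with it selects the cross product of the given rows and columns; a 2-D pure-Python sketch of the resulting gather (hypothetical helper):

```python
def take_ix(m, rows, cols):
    # Equivalent of m[np.ix_(rows, cols)] for a nested-list matrix:
    # select the submatrix at the cross product of rows and cols.
    return [[m[r][c] for c in cols] for r in rows]
```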
Jonathan Heek
e3462fd8b1
write out batch_feature_groups to simplify and correct implementation
2019-06-17 21:32:03 +02:00
Matthew Johnson
a1219f10e5
Merge pull request #865 from google/fix-reshape-grad-bug
...
fix special-form reshape transpose bug (and add tests)
2019-06-17 12:29:45 -07:00
Matthew Johnson
fef68deef6
fix reshape transpose bug (and add tests)
...
This version of reshape (taking a `dimensions` argument, which
effectively fuses in a transpose) seems only to be used in the JVP rule
for lax._reduce_prod (basically np.product), but its transpose rule was
totally busted and untested.
2019-06-17 11:54:36 -07:00
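The reshape variant in question takes a `dimensions` permutation that is applied before reshaping, i.e. it behaves like a transpose fused with a reshape; a 2-D-to-1-D sketch of that semantics (hypothetical helper, handling only the two 2-D permutations):

```python
def reshape_flat(x, dimensions):
    # First permute the axes of the nested-list matrix by `dimensions`,
    # then read the elements off in row-major order.
    if dimensions == (1, 0):
        x = [list(row) for row in zip(*x)]  # transpose
    return [v for row in x for v in row]
```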
Roy Frostig
a113d9f5c9
Merge branch 'master' into parallelize
2019-06-17 09:09:39 -07:00
Peter Hawkins
129e73a258
Use lax.full to create zeros array in gather transpose rule.
2019-06-17 08:19:36 -04:00
Jonathan Heek
077d56529f
grouped convolution support
2019-06-17 12:37:19 +02:00
Matthew Johnson
ff29d582e8
make all conv vmap rules generate a single call
...
also plumb feature_group_count and batch_group_count everywhere
2019-06-16 07:50:19 -07:00
Matthew Johnson
1262ca9b30
improve conv rhs batching, add systematic test
2019-06-15 12:01:20 -07:00
James Bradbury
edfe52035b
Fix dot batch rule bug and add test + check
2019-06-12 18:02:01 -07:00
Roy Frostig
49672f79b8
parallelization rule for lax.gather
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-06-06 18:51:25 -07:00
Peter Hawkins
bd389b7fcf
Add bug number for integer dots.
2019-06-06 19:04:11 -04:00
Peter Hawkins
15361c4dde
Add support for integer dot operations.
...
Lower to a sum of products for integers since XLA currently lacks support for integer dots.
2019-06-06 17:24:43 -04:00
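The lowering described here can be sketched directly: an integer dot becomes an explicit sum of elementwise products (a plain-Python illustration of the idea, not the XLA lowering itself):

```python
def int_dot(a, b):
    # Matrix product via an explicit sum of products, the fallback used
    # when the backend lacks a native integer dot.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]
```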
Matthew Johnson
1feefd10ac
Merge pull request #820 from google/simplify-gather-shape-rule
...
simplify gather shape rule
2019-06-05 19:14:27 -07:00
Matthew Johnson
9e49500c54
simplify gather shape rule
...
Co-authored-by: Roy Frostig <frostig@google.com>
2019-06-05 17:06:42 -07:00
James Bradbury
9bc5f2aecf
fix bug in dot batching rule
2019-06-05 15:17:06 -07:00
Roy Frostig
6cec008a71
Merge branch 'master' into parallelize
2019-06-04 15:31:27 -07:00
Peter Hawkins
e63bd4d592
Add domain test to atanh implementation.
...
Disable some tests on tpu.
2019-05-31 17:22:19 -04:00
Roy Frostig
771564024d
Merge branch 'master' into parallelize
2019-05-31 11:41:31 -07:00
Roy Frostig
800bdca858
sketch of parallelization rule for lax.conv_general_dilated
2019-05-31 11:40:51 -07:00