77 Commits

Author SHA1 Message Date
Peter Hawkins
db369091a2 Add support for higher derivatives of reduce-window-min/max at reduced precision. On CPU/GPU this means support for float64 derivatives, and on TPU this means support for float32 derivatives.
Warn if we are forced to be imprecise.
2019-06-28 20:27:10 -04:00
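
As a rough, illustrative sketch (not taken from the commit itself): higher-order derivatives of a windowed max can now be taken directly; the pooling shape and input values below are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax

def max_pool_sum(x):
    # 1-D max pooling: window 2, stride 1, no padding, then a scalar reduction.
    pooled = lax.reduce_window(x, -jnp.inf, lax.max, (2,), (1,), 'VALID')
    return jnp.sum(pooled)

x = jnp.array([1.0, 3.0, 2.0, 5.0, 4.0])
g = jax.grad(max_pool_sum)(x)       # first derivative
h = jax.hessian(max_pool_sum)(x)    # second derivative of the reduce-window max
```
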
Peter Hawkins
3e914e17b0 Improve documentation for precision. 2019-06-28 14:06:24 -04:00
Peter Hawkins
bca27fea8b Simplify precision specification: only allow a single precision for an entire operator. 2019-06-28 12:48:44 -04:00
Peter Hawkins
0af9da7662 Add precision option to lax dot and conv APIs.
Set a default precision of "highest" in LU decomposition.
Enable a number of dot and conv tests on TPU under highest precision.
Enable linalg tests that use LU decomposition on TPU.
2019-06-28 10:00:39 -04:00
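
For illustration, a minimal sketch of the per-operator precision option on lax.dot (shapes and values are arbitrary):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), jnp.float32)
y = jnp.ones((128, 128), jnp.float32)

z_default = lax.dot(x, y)                                   # backend-default precision
z_highest = lax.dot(x, y, precision=lax.Precision.HIGHEST)  # request highest precision
```
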
Peter Hawkins
3b4521b1f6 Enable convolutions for non float32 types. 2019-06-27 17:17:49 -04:00
Peter Hawkins
990c2df123 Implement a pure Python LU decomposition that can be used on platforms where we do not otherwise have a better implementation.
Restructure xla.lower_fun and trace_unwrapped_to_jaxpr so the instantiate option can be passed to them, separately from any function arguments.
2019-06-27 14:50:29 -04:00
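
A hedged usage sketch via the jax.scipy.linalg wrappers (assuming, per the commit, that the pure Python path is used only where no better platform implementation exists); matrix values are arbitrary:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

a = jnp.array([[4.0, 3.0], [6.0, 3.0]])
b = jnp.array([1.0, 2.0])

lu_and_piv = lu_factor(a)      # LU factorization of a
x = lu_solve(lu_and_piv, b)    # solve a @ x = b using the factors
```
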
Peter Hawkins
c7afa1eb34
Merge pull request #930 from hawkinsp/master
Merge reduce-window batching rules. Add rule for reduce_window_min.
2019-06-26 13:40:12 -04:00
Peter Hawkins
755d2818e8 Merge reduce-window batching rules. Add batching rule for reduce_window_min. 2019-06-26 10:19:42 -04:00
Matthew Johnson
a663486148 fixes from rebase onto master 2019-06-24 19:45:18 -07:00
Matthew Johnson
40bbf068d8 fix broadcasting papply rule, move to lax_parallel 2019-06-24 19:39:12 -07:00
Matthew Johnson
d64188bcb6 del serial_pmap, simpler papply, add parallelize
The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.

The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. papply is also
now only needed for testing, since parallelize (which essentially
composes a soft_pmap with a papply) is likely to be the primary
user-facing API.

This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.

Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
  doesn't need to perform communication)
* fixed misc bugs encountered along the way
* a few lines cherry-picked from frostig@ branch, namely the fixed
  broadcasting_papply rule and plumbing the `size` argument to papply
  rules
* removed the psplit and psplit_like primitives and replaced them with
  calls to all_to_all where needed
2019-06-24 19:38:26 -07:00
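
A hedged sketch of soft_pmap, assuming it is exposed as jax.soft_pmap as it was around this time (it was later removed from the public API): like pmap it maps over a leading axis, but it tolerates more mapped elements than physical devices.

```python
import jax
import jax.numpy as jnp

# Map a simple function over a leading axis of length 8, which may exceed
# the number of available devices.
double = jax.soft_pmap(lambda x: 2.0 * x)
out = double(jnp.arange(8.0))
```
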
Roy Frostig
33b01733a9
Merge pull request #916 from google/parallelize
parallelization work-in-progress
2019-06-24 16:08:14 -07:00
Roy Frostig
15bc966567 Merge branch 'master' into parallelize 2019-06-24 11:32:59 -07:00
Peter Hawkins
c1bec691c5 Avoid instantiating zeros in dynamic_slice/gather transpose rules. 2019-06-24 13:44:49 -04:00
Justin Lebar
d5ba04b79e Add jax.ops.index_min/max.
These are analogous to index_add.
2019-06-21 19:33:34 -07:00
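
A minimal sketch of the new functions, using the jax.ops indexed-update API of this era (later removed in favor of .at[] indexing); values are arbitrary:

```python
import jax.numpy as jnp
from jax import ops

x = jnp.array([5.0, 2.0, 7.0, 1.0])
# Pure equivalent of x[1:3] = minimum(x[1:3], values), analogous to index_add.
y = ops.index_min(x, ops.index[1:3], jnp.array([1.0, 9.0]))
# y == [5., 1., 7., 1.]
```
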
John Schulman
8b163628fd map stop_gradient over data structures; otherwise it is silently a no-op 2019-06-20 16:23:13 -07:00
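
A hedged sketch of what this enables: applying stop_gradient leaf-wise to a whole pytree of parameters (names and values are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import lax

params = {'w': jnp.ones(3), 'b': jnp.zeros(3)}

def loss(p):
    frozen = lax.stop_gradient(p)            # applied to every leaf of the dict
    return jnp.sum((p['w'] - frozen['b']) ** 2)

grads = jax.grad(loss)(params)               # no gradient flows through `frozen`
```
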
Matthew Johnson
8bc4e379f5 make DeviceArray.__hash__ raise an error
Fixes #883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.

Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
2019-06-19 10:12:13 -07:00
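
A short illustration of the new behavior:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
try:
    hash(x)                   # raises TypeError after this change
except TypeError as e:
    print('DeviceArray is unhashable:', e)
```
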
Peter Hawkins
04676e4e95 Use _dtype instead of .dtype in dynamic slice transpose rule. 2019-06-18 10:25:10 -04:00
Peter Hawkins
c8b946e1fd Use lax.full instead of broadcast in transpose rule for dynamic_slice. 2019-06-18 10:18:33 -04:00
Peter Hawkins
b4acfe0640
Merge pull request #868 from hawkinsp/master
Implement batching for np.linalg.solve
2019-06-18 05:58:10 -06:00
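
For illustration, batched solves now work by composing vmap with np.linalg.solve (matrix values are arbitrary):

```python
import jax
import jax.numpy as jnp

a = jnp.stack([2.0 * jnp.eye(2), 3.0 * jnp.eye(2)])   # batch of 2x2 systems
b = jnp.ones((2, 2))                                   # matching right-hand sides
x = jax.vmap(jnp.linalg.solve)(a, b)                   # one solve per batch element
```
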
Matthew Johnson
6dd5423666 manually fuse a transpose into a reshape 2019-06-17 19:39:14 -07:00
Matthew Johnson
3ce87f7874
Merge pull request #859 from jheek/grouped-convs
Grouped convs
2019-06-17 19:25:19 -07:00
Matthew Johnson
a6e374ddb9
Merge pull request #858 from google/improve-conv-batching
improve efficiency of conv batching rules (i.e. vmap rules)
2019-06-17 19:23:07 -07:00
Peter Hawkins
991a5a9f4c Use _dtype instead of .dtype in dynamic slice rule. 2019-06-17 21:39:34 -04:00
Peter Hawkins
4e872b96b9 Fix type mismatch with int32-type indices under a jit with 64-bit types enabled. 2019-06-17 21:18:27 -04:00
Peter Hawkins
0190684ee2
Merge pull request #866 from hawkinsp/master
Implement np.ix_, for non-bool inputs.
2019-06-17 18:39:45 -06:00
Peter Hawkins
ec685bf8ae Implement np.ix_, for non-bool inputs. 2019-06-17 17:08:27 -04:00
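
A minimal usage sketch (index values are arbitrary):

```python
import jax.numpy as jnp

x = jnp.arange(16).reshape(4, 4)
rows = jnp.array([0, 2])
cols = jnp.array([1, 3])
sub = x[jnp.ix_(rows, cols)]   # 2x2 block at the crossings of the chosen rows/cols
```
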
Jonathan Heek
e3462fd8b1 write out batch_feature_groups to simplify and correct implementation 2019-06-17 21:32:03 +02:00
Matthew Johnson
a1219f10e5
Merge pull request #865 from google/fix-reshape-grad-bug
fix special-form reshape transpose bug (and add tests)
2019-06-17 12:29:45 -07:00
Matthew Johnson
fef68deef6 fix reshape transpose bug (and add tests)
This version of reshape (taking a `dimensions` argument, which
effectively fuses in a transpose) seems only to be used in the JVP rule
for lax._reduce_prod (basically np.product), but its transpose rule was
totally busted and untested.
2019-06-17 11:54:36 -07:00
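
For reference, a small sketch of the special form in question: lax.reshape with a `dimensions` argument, which permutes the operand before reshaping (shapes are arbitrary):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)
# Equivalent to transposing with permutation (1, 0) and then reshaping to (3, 2).
y = lax.reshape(x, (3, 2), dimensions=(1, 0))
```
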
Roy Frostig
a113d9f5c9 Merge branch 'master' into parallelize 2019-06-17 09:09:39 -07:00
Peter Hawkins
129e73a258 Use lax.full to create zeros array in gather transpose rule. 2019-06-17 08:19:36 -04:00
Jonathan Heek
077d56529f grouped convolution support 2019-06-17 12:37:19 +02:00
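
A hedged sketch of a grouped convolution via feature_group_count, using the default (batch, feature, spatial) / (out, in, spatial) layouts; sizes are illustrative:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 4, 8))      # (batch, in_channels, width)
w = jnp.ones((4, 2, 3))      # (out_channels, in_channels // groups, width)
y = lax.conv_general_dilated(
    x, w,
    window_strides=(1,),
    padding='SAME',
    feature_group_count=2)   # two groups of two input channels each
```
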
Matthew Johnson
ff29d582e8 make all conv vmap rules generate a single call

also plumb feature_group_count and batch_group_count everywhere
2019-06-16 07:50:19 -07:00
Matthew Johnson
1262ca9b30 improve conv rhs batching, add systematic test 2019-06-15 12:01:20 -07:00
James Bradbury
edfe52035b Fix dot batch rule bug and add test + check 2019-06-12 18:02:01 -07:00
Roy Frostig
49672f79b8 parallelization rule for lax.gather
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-06-06 18:51:25 -07:00
Peter Hawkins
bd389b7fcf Add bug number for integer dots. 2019-06-06 19:04:11 -04:00
Peter Hawkins
15361c4dde Add support for integer dot operations.
Lower to a sum of products for integers since XLA currently lacks support for integer dots.
2019-06-06 17:24:43 -04:00
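
For illustration (values are arbitrary):

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]], dtype=jnp.int32)
b = jnp.array([[5, 6], [7, 8]], dtype=jnp.int32)
c = jnp.dot(a, b)   # lowered to a sum of products on backends lacking integer dot
```
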
Matthew Johnson
1feefd10ac
Merge pull request #820 from google/simplify-gather-shape-rule
simplify gather shape rule
2019-06-05 19:14:27 -07:00
Matthew Johnson
9e49500c54 simplify gather shape rule
Co-authored-by: Roy Frostig <frostig@google.com>
2019-06-05 17:06:42 -07:00
James Bradbury
9bc5f2aecf fix bug in dot batching rule 2019-06-05 15:17:06 -07:00
Roy Frostig
6cec008a71 Merge branch 'master' into parallelize 2019-06-04 15:31:27 -07:00
Peter Hawkins
e63bd4d592 Add domain test to atanh implementation.
Disable some tests on tpu.
2019-05-31 17:22:19 -04:00
Roy Frostig
771564024d Merge branch 'master' into parallelize 2019-05-31 11:41:31 -07:00
Roy Frostig
800bdca858 sketch of parallelization rule for lax.conv_general_dilated 2019-05-31 11:40:51 -07:00
Matthew Johnson
25262edb92 Make pmap lax.psum(1, 'i') and pxla.axis_index('i') work
The implementation mechanism is to use a bit of dynamic context to model
the axis name environment at trace time, and for the environment to
track how an axis name maps to an axis size and the corresponding trace
(i.e. the JaxprTrace instance). With that information, we can lift
special primitives that take axis_name parameters into the trace as
needed without having a data dependence on the input.
2019-05-29 20:13:07 -07:00
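
A hedged sketch of the two primitives inside pmap (axis_index is shown via lax.axis_index, its later public home; the commit adds it as pxla.axis_index):

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()

# psum of the constant 1 over the named axis yields the axis size, with no
# data dependence on the mapped argument.
axis_size = jax.pmap(lambda _: lax.psum(1, 'i'), axis_name='i')(jnp.zeros(n))

# axis_index gives each device its position along the named axis.
idx = jax.pmap(lambda _: lax.axis_index('i'), axis_name='i')(jnp.zeros(n))
```
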
Peter Hawkins
0293ecbb5e Add support for vmap of scatter where indices but not updates are batched. 2019-05-29 17:13:46 -04:00
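
A hedged sketch of the newly supported case, using the jax.ops scatter wrappers of this era: the indices carry the batch dimension while the operand and the update value are shared.

```python
import jax
import jax.numpy as jnp
from jax import ops

x = jnp.zeros(5)
idxs = jnp.array([0, 2, 4])

# Batch only the scatter indices; the update value 1.0 is not batched.
set_one = lambda i: ops.index_update(x, i, 1.0)
batched = jax.vmap(set_one)(idxs)    # shape (3, 5), one updated row per index
```
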
Peter Hawkins
6e1ec38a14 Improve behavior of a number of math functions for extreme inputs.
Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.

Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
2019-05-29 12:51:24 -04:00
Matthew Johnson
778435a90b undo #503 in favor of new literal staging method 2019-05-29 08:12:05 -07:00