93 Commits

Author SHA1 Message Date
Peter Hawkins
510a9167c5 Add C++ implementation of pytree logic.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.
2019-07-29 15:06:05 -04:00
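
A minimal sketch of the new import layout described above; the pytree flatten/unflatten API is unchanged, only its backing implementation moves to C++ in jaxlib:

    from jax.lib import xla_client   # jax.lib now mirrors jaxlib
    from jax import tree_util

    # Flatten a nested container into leaves plus a treedef, then rebuild it.
    leaves, treedef = tree_util.tree_flatten({"a": 1, "b": (2, 3)})
    assert tree_util.tree_unflatten(treedef, leaves) == {"a": 1, "b": (2, 3)}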
Peter Hawkins
2369d1fe61 Increase minimum Jaxlib version to 0.1.22.
Remove code that preserves backward compatibility with older jaxlib versions.
2019-07-23 21:45:41 -04:00
Peter Hawkins
1479ae9066 Add a common lax._canonicalize_shape method; use it in methods that accept shapes in lax.
Explicitly convert shape entries to integers using the Python __index__() method.
Implement __index__ on DeviceArrays so shapes like (1, DeviceArray(2)) work.

Fixes bug where np.full accepted floating point shapes; __index__() errors for non-integer inputs, where int() would silently cast and drop information.
2019-07-23 16:19:02 -04:00
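
A hedged sketch of the canonicalization idea; the real helper in lax.py handles more cases (numpy scalars, DeviceArrays via their __index__):

    import operator

    def canonicalize_shape(shape):
      # operator.index calls __index__, which raises TypeError for non-integers
      # instead of silently truncating the way int() would.
      return tuple(operator.index(dim) for dim in shape)

    canonicalize_shape((1, 2))    # -> (1, 2)
    canonicalize_shape((1, 2.5))  # -> TypeError, catching the np.full bug above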
Peter Hawkins
f64332b394 Remove assertions in scatter/dynamic_update_slice JVP rules that test whether index tangents are symbolically zero.
Since indices are integers, their tangents should be zero anyway, and symbolic zeros should always be treated as an optimization rather than a necessary precondition.
2019-07-23 14:18:07 -04:00
Peter Hawkins
0850318a83 Add support for mixing basic and advanced indexing in the same scatter operation. 2019-07-14 11:55:26 -04:00
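
For example, an index_add whose index expression mixes a basic slice with an advanced integer-array index (shapes illustrative):

    import jax.numpy as jnp
    from jax import ops

    x = jnp.zeros((3, 4, 5))
    rows = jnp.array([0, 2])
    # Basic slice on axis 0 combined with an advanced index on axis 1,
    # all inside one scatter.
    y = ops.index_add(x, ops.index[1:3, rows], jnp.ones((2, 2, 5)))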
Peter Hawkins
05ff396716 Add batching rule for reduce_window_p. Allows batching of np.cumprod. 2019-07-13 10:22:26 -04:00
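
For instance, a per-row cumulative product now batches cleanly (cumprod was implemented in terms of reduce_window at the time):

    import jax
    import jax.numpy as jnp

    x = jnp.arange(1.0, 7.0).reshape(2, 3)
    row_cumprods = jax.vmap(jnp.cumprod)(x)   # [[1, 2, 6], [4, 20, 120]]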
Matthew Johnson
79668ae4ed fix reduce_window batching rule 2019-07-06 11:58:33 -07:00
Matthew Johnson
ddf7f69cad fix select broadcasting rule 2019-07-06 11:52:24 -07:00
Matthew Johnson
febad2d863 fix broadcast_in_dim batching rule 2019-07-06 11:47:50 -07:00
Matthew Johnson
ccb1760f49 add a lot of systematic vmap tests 2019-07-06 11:28:15 -07:00
Matthew Johnson
a5e86ae128 enable soft_pmap device persistence
Previously soft_pmap didn't allow for sharded device persistence because
it performed reshapes on the input and output of the underlying pmap
computation corresponding to splitting out and merging together the
hardware-mapped and software-mapped axes, respectively. These reshapes
forced the ShardedDeviceArray produced by the pmap computation to be
collected into a (single-device-backed) DeviceArray.

The approach in this commit is to make reshape smarter about
ShardedDeviceArrays so that axis-merging logical reshapes don't force
collection (i.e. don't force re-layout). Instead they now produce a new
ShardedDeviceArray subclass called a ChunkedDeviceArray, which
represents the same logical reshape result but without data movement.

One way to think about the key difference between ShardedDeviceArray and
ChunkedDeviceArray is that when forced the former collects its shards
together using onp.stack while the latter collects its shards with
onp.concatenate. The leading letter of each name is meant to remind us
of that difference (s for stack, c for concatenate).

ChunkedDeviceArrays can be turned back into ShardedDeviceArrays under
particular reshapes, namely reshapes that split the hardware-mapped axis
back out into the leading dimension. This way a sequence of soft_pmapped
computations can maintain device persistence (i.e. not force collection).
Every other operation forces collection, just like it does for
ShardedDeviceArrays.
2019-07-06 10:21:59 -07:00
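
The stack-vs-concatenate distinction, illustrated with plain numpy arrays standing in for per-device shards:

    import numpy as onp

    shards = [onp.ones((2, 3)), onp.ones((2, 3))]

    onp.stack(shards).shape        # (2, 2, 3): ShardedDeviceArray-style collection
    onp.concatenate(shards).shape  # (4, 3):    ChunkedDeviceArray-style collection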
Matthew Johnson
db52d42597 also fix lax.complex jvp, enable test 2019-07-05 14:39:32 -07:00
Matthew Johnson
93841df822 fix lax.imag jvp and enable test, fixes #979 2019-07-05 14:32:04 -07:00
Peter Hawkins
a06ba06f97 Update comments. 2019-07-02 13:23:05 -04:00
Peter Hawkins
165df6204b Simplify reduce-precision logic.
Enable TPU gradient tests only up to order 1. The first-order JVP of reduce-window tests select_and_scatter_add, which is the part changed by this PR.
2019-07-02 11:34:49 -04:00
Peter Hawkins
40560d2c9a Refactor select_and_gather_add implementation to improve readability.
Change implementation to use ReducePrecision to perform half-word reductions.
2019-07-01 22:26:36 -04:00
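
A rough numpy illustration of the word-packing idea behind select_and_gather_add. The real kernel handles sign/NaN ordering, and on platforms without 64-bit words it first reduces both halves to 16 bits via XLA's ReducePrecision; this toy version assumes non-negative float32 inputs and a 64-bit packed word:

    import numpy as onp

    def pack(operand, source):
      # Comparison value in the high 32 bits, payload in the low 32 bits, so a
      # single max-reduction over the packed words selects both together.
      hi = operand.view(onp.uint32).astype(onp.uint64) << onp.uint64(32)
      lo = source.view(onp.uint32).astype(onp.uint64)
      return hi | lo

    operand = onp.array([0.5, 2.0, 1.5], dtype=onp.float32)
    source = onp.array([10.0, 20.0, 30.0], dtype=onp.float32)
    winner = pack(operand, source).max()
    payload = (winner & onp.uint64(0xFFFFFFFF)).astype(onp.uint32).view(onp.float32)  # 20.0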
Peter Hawkins
db369091a2 Add support for higher derivatives of reduce-window-min/max at reduced precision. On CPU/GPU this means support for float64 derivatives, and on TPU this means support for float32 derivatives.
Warn if we are forced to be imprecise.
2019-06-28 20:27:10 -04:00
Peter Hawkins
3e914e17b0 Improve documentation for precision. 2019-06-28 14:06:24 -04:00
Peter Hawkins
bca27fea8b Simplify precision specification: only allow a single precision for an entire operator. 2019-06-28 12:48:44 -04:00
Peter Hawkins
0af9da7662 Add precision option to lax dot and conv APIs.
Set a default precision of "highest" in LU decomposition.
Enable a number of dot and conv tests on TPU under highest precision.
Enable linalg tests that use LU decomposition on TPU.
2019-06-28 10:00:39 -04:00
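
A usage sketch, with the precision option spelled as it is in current JAX:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((128, 128), jnp.float32)
    y = jnp.ones((128, 128), jnp.float32)
    # Request the most precise algorithm; this matters mainly on TPU, where the
    # default dot/conv use lower-precision multiplies.
    z = lax.dot(x, y, precision=lax.Precision.HIGHEST)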
Peter Hawkins
3b4521b1f6 Enable convolutions for non float32 types. 2019-06-27 17:17:49 -04:00
Peter Hawkins
990c2df123 Implement a pure Python LU decomposition that can be used on platforms where we do not otherwise have a better implementation.
Restructure xla.lower_fun and trace_unwrapped_to_jaxpr so the instantiate option can be passed to them, separately from any function arguments.
2019-06-27 14:50:29 -04:00
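
A hedged sketch of an unblocked, non-pivoting LU in jax.numpy terms; the actual fallback also does partial pivoting and returns the permutation (and .at[] is the modern spelling of indexed updates):

    import jax.numpy as jnp

    def lu_unblocked(a):
      # Packs L (unit diagonal, strictly below) and U (on and above the diagonal)
      # into one matrix, eliminating one column at a time.
      n = a.shape[0]
      for k in range(n - 1):
        col = a[k + 1:, k] / a[k, k]                 # multipliers below the pivot
        a = a.at[k + 1:, k].set(col)
        a = a.at[k + 1:, k + 1:].add(-col[:, None] * a[k, k + 1:][None, :])
      return a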
Peter Hawkins
c7afa1eb34
Merge pull request #930 from hawkinsp/master
Merge reduce-window batching rules. Add rule for reduce_window_min.
2019-06-26 13:40:12 -04:00
Peter Hawkins
755d2818e8 Merge reduce-window batching rules. Add batching rule for reduce_window_min. 2019-06-26 10:19:42 -04:00
Matthew Johnson
a663486148 fixes from rebase onto master 2019-06-24 19:45:18 -07:00
Matthew Johnson
40bbf068d8 fix broadcasting papply rule, move to lax_parallel 2019-06-24 19:39:12 -07:00
Matthew Johnson
d64188bcb6 del serial_pmap, simpler papply, add parallelize
The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.

The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. It is also
now only needed for testing, since parallelize (which essentially
composes a soft_pmap with a papply) is likely to be the primary
user-facing API.

This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.

Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
  doesn't need to perform communication)
* fixed misc bugs encountered along the way
* a few lines cherry-picked from frostig@ branch, namely the fixed
  broadcasting_papply rule and plumbing the `size` argument to papply
  rules
* removed the psplit and psplit_like primitives and replaced them with
  calls to all_to_all where needed
2019-06-24 19:38:26 -07:00
Roy Frostig
33b01733a9
Merge pull request #916 from google/parallelize
parallelization work-in-progress
2019-06-24 16:08:14 -07:00
Roy Frostig
15bc966567 Merge branch 'master' into parallelize 2019-06-24 11:32:59 -07:00
Peter Hawkins
c1bec691c5 Avoid instantiating zeros in dynamic_slice/gather transpose rules. 2019-06-24 13:44:49 -04:00
Justin Lebar
d5ba04b79e Add jax.ops.index_min/max.
These are analogous to index_add.
2019-06-21 19:33:34 -07:00
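
Usage mirrors index_add (these helpers were later superseded by the x.at[...].min/.max syntax):

    import jax.numpy as jnp
    from jax import ops

    x = jnp.full((5,), 3.0)
    lo = ops.index_min(x, ops.index[1:3], 1.0)   # keeps the elementwise minimum
    hi = ops.index_max(x, ops.index[1:3], 7.0)   # keeps the elementwise maximum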
John Schulman
8b163628fd map stop_gradient over data structures; otherwise it is silently a no-op 2019-06-20 16:23:13 -07:00
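
The idea, roughly (the actual change lives inside jax.lax.stop_gradient itself):

    from jax import lax, tree_util

    def stop_gradient_tree(x):
      # Apply stop_gradient leaf-by-leaf so nested containers (dicts, tuples, ...)
      # are handled instead of being passed through untouched.
      return tree_util.tree_map(lax.stop_gradient, x)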
Matthew Johnson
8bc4e379f5 make DeviceArray.__hash__ raise an error
Fixes #883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.

Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
2019-06-19 10:12:13 -07:00
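
Concretely, hashing a DeviceArray now fails loudly:

    import jax.numpy as jnp

    x = jnp.arange(3)
    try:
      hash(x)
    except TypeError:
      pass  # DeviceArrays are no longer hashable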
Peter Hawkins
04676e4e95 Use _dtype instead of .dtype in dynamic slice transpose rule. 2019-06-18 10:25:10 -04:00
Peter Hawkins
c8b946e1fd Use lax.full instead of broadcast in transpose rule for dynamic_slice. 2019-06-18 10:18:33 -04:00
Peter Hawkins
b4acfe0640
Merge pull request #868 from hawkinsp/master
Implement batching for np.linalg.solve
2019-06-18 05:58:10 -06:00
Matthew Johnson
6dd5423666 manually fuse a transpose into a reshape 2019-06-17 19:39:14 -07:00
Matthew Johnson
3ce87f7874
Merge pull request #859 from jheek/grouped-convs
Grouped convs
2019-06-17 19:25:19 -07:00
Matthew Johnson
a6e374ddb9
Merge pull request #858 from google/improve-conv-batching
improve efficiency of conv batching rules (i.e. vmap rules)
2019-06-17 19:23:07 -07:00
Peter Hawkins
991a5a9f4c Use _dtype instead of .dtype in dynamic slice rule. 2019-06-17 21:39:34 -04:00
Peter Hawkins
4e872b96b9 Fix type mismatch with int32-type indices under a jit with 64-bit types enabled. 2019-06-17 21:18:27 -04:00
Peter Hawkins
0190684ee2
Merge pull request #866 from hawkinsp/master
Implement np.ix_, for non-bool inputs.
2019-06-17 18:39:45 -06:00
Peter Hawkins
ec685bf8ae Implement np.ix_, for non-bool inputs. 2019-06-17 17:08:27 -04:00
Jonathan Heek
e3462fd8b1 write out batch_feature_groups to simplify and correct implementation 2019-06-17 21:32:03 +02:00
Matthew Johnson
a1219f10e5
Merge pull request #865 from google/fix-reshape-grad-bug
fix special-form reshape transpose bug (and add tests)
2019-06-17 12:29:45 -07:00
Matthew Johnson
fef68deef6 fix reshape transpose bug (and add tests)
This version of reshape (taking a `dimensions` argument, which
effectively fuses in a transpose) seems only to be used in the JVP rule
for lax._reduce_prod (basically np.product), but its transpose rule was
totally busted and untested.
2019-06-17 11:54:36 -07:00
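
For reference, this is the lax.reshape variant whose optional dimensions argument permutes the input before reshaping:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(6.0).reshape(2, 3)
    # Equivalent to reshaping x.T: the transpose is fused into the reshape.
    y = lax.reshape(x, new_sizes=(6,), dimensions=(1, 0))   # [0., 3., 1., 4., 2., 5.]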
Roy Frostig
a113d9f5c9 Merge branch 'master' into parallelize 2019-06-17 09:09:39 -07:00
Peter Hawkins
129e73a258 Use lax.full to create zeros array in gather transpose rule. 2019-06-17 08:19:36 -04:00
Jonathan Heek
077d56529f grouped convolution support 2019-06-17 12:37:19 +02:00
Matthew Johnson
ff29d582e8 make all conv vmap rules generate a single call

also plumb feature_group_count and batch_group_count everywhere
2019-06-16 07:50:19 -07:00
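
A grouped-convolution sketch using feature_group_count (shapes illustrative, default NCHW/OIHW layouts):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((1, 4, 8, 8))    # NCHW: 4 input channels
    w = jnp.ones((6, 2, 3, 3))    # OIHW: 6 output channels, 4/2 = 2 input channels per group
    # With feature_group_count=2 the input channels split into 2 groups, each
    # convolved with its own half of the filters.
    y = lax.conv_general_dilated(x, w, window_strides=(1, 1), padding="SAME",
                                 feature_group_count=2)   # -> shape (1, 6, 8, 8)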