87 Commits

Author SHA1 Message Date
Matthew Johnson
6dfe2d6e36 add numpy indexing batching tests 2019-02-11 09:30:21 -08:00
Matthew Johnson
b53eb241f7 gather passing all operand vmap tests 2019-02-11 09:30:13 -08:00
Matthew Johnson
b6cb3509cd progress on a gather vmap rule, PAIR=hawkinsp 2019-02-10 08:06:50 -08:00
Matthew Johnson
cde5c925fd start to sketch out gather batching rule (WIP) 2019-02-10 08:06:50 -08:00
Dougal Maclaurin
ce74bc55ce Handle closed-over tracers in while loop cond and body functions 2019-02-06 12:58:32 -05:00
Matthew Johnson
1636d058df fix lax.full handling of DeviceConstant scalars
fixes #330
2019-02-06 09:23:34 -08:00
Matthew Johnson
bf7a438c94 add more special cases of select batching rule 2019-02-03 14:00:51 -08:00
Matthew Johnson
44cffd0053
Merge pull request #310 from google/issue292
improve error messages for lax.slice/index funs
2019-02-03 13:48:11 -08:00
Matthew Johnson
583b654769 add an efficient special case to select batch rule 2019-02-03 10:01:06 -08:00
Matthew Johnson
5344e7aea0 add lax.select broadcasting tests, improve rule 2019-02-03 09:52:33 -08:00
Matthew Johnson
fe96c15d49 generalize select batch rule (fixes #311) 2019-02-03 09:27:03 -08:00
Matthew Johnson
0afb6202c9 improve error messages for lax.slice/index funs
c.f. #292
2019-02-02 21:41:06 -08:00
Matthew Johnson
055beb9037 Merge branch 'master' into pjit 2019-02-02 13:30:54 -08:00
Matthew Johnson
f5cffd722a delete more dead index_take code 2019-02-02 12:17:11 -08:00
Matthew Johnson
9f3060a0e6 index_take in terms of gather, delete index_untake
(c.f. #304)
2019-02-02 09:22:37 -08:00
Matthew Johnson
f69dda9641 fix merging issue 2019-02-01 17:39:49 -08:00
Matthew Johnson
08dc6994f5 partial progress 2019-02-01 17:05:49 -08:00
Peter Hawkins
5517347cc9 Reexpose reduce_window_shape_tuple since it has external users.
Fix accidental removal of rev() batching rule.
2019-02-01 16:29:53 -05:00
Peter Hawkins
09201c72bc Prefix most rules in lax module with underscores to improve generated doc readability.
Underscore-prefixed functions are automatically hidden from generated documentation. `lax` is a semi-public API, so this is a first step towards making its documentation useful.
2019-02-01 16:03:45 -05:00
Peter Hawkins
66c7a4248a
Merge pull request #303 from hawkinsp/minmax
Fix gradient for `np.amin` and `np.amax`.
2019-02-01 14:24:22 -05:00
Peter Hawkins
fb659e22b9 Fix gradient for np.amin and np.amax.
The JVP rule for `lax.reduce` depends on being able to identify the reducer as a monoid reducer. To get the correct behavior on complex numbers, `np.{amin,amax}` passed a non-standard reducer that compared complex numbers lexicographically as (real, imaginary) pairs. However, this prevented the gradient rule from identifying the reducer.

Instead, change the `lax.min` and `lax.max` to use the Numpy semantics when comparing complex numbers, and change `np.amin` and `np.amax` to use them.

Move the `np._broadcast_shapes` helper into `lax.py` as `lax.broadcast_shapes`.
2019-02-01 11:53:12 -05:00
Matthew Johnson
670f14a2ee
Merge pull request #300 from alexalemi/rev_batching
Rev batching
2019-02-01 07:29:13 -08:00
Alex Alemi
a9b221a1d3 Add batching rule for rev. 2019-01-31 21:47:05 -08:00
Peter Hawkins
26f85310e5 Implement np.{cumsum,cumprod,nancumsum,nancumprod}. 2019-01-31 18:56:06 -05:00
Matthew Johnson
c293b3775b add basic lax.stop_gradient primitive 2019-01-30 10:39:35 -08:00
Jonas Rauber
334581d5e9 fixed TypeError caused by body_fun of foreach loop 2019-01-29 17:31:10 +01:00
Matthew Johnson
4db50bc459
Merge pull request #288 from sschoenholz/patch-1
Added batching rules for convolutions + pooling.
2019-01-28 20:28:39 -08:00
sschoenholz
3b8d43cefa
Misc. small fixes. 2019-01-28 19:08:05 -08:00
Peter Hawkins
5dc15868ee Make translation rule for select_and_gather_add work even when --jax_enable_x64 is disabled.
Add support for constants whose types are not canonicalized by passing an optional flag to constant factories.

I am not entirely happy with the type canonicalization approach, but it seems good enough for this specific use case.
2019-01-28 18:41:27 -05:00
sschoenholz
a15bad401f
Added batching rules for convolutions + pooling.
Added batching rules:
conv_general_dilated_batch_rule
select_and_scatter_add_batch_rule
reduce_window_max_batch_rule
reduce_window_sum_batch_rule
2019-01-28 14:33:57 -08:00
Peter Hawkins
af69d341a7
Merge pull request #286 from hawkinsp/maxpool
Implement translation rule for select_and_gather_add (issue #274).
2019-01-28 15:44:03 -05:00
Peter Hawkins
9f84455fb2 Check for jax_enable_x64 in select_and_gather_add translation rule. 2019-01-28 15:10:58 -05:00
Peter Hawkins
f76134e460 Implement transpose rule for select_and_gather_add (issue #274).
There are a couple of caveats that mean that we shouldn't close the issue yet:
a) we need a jaxlib update to generalize the ReduceWindow support in the XLA/CPU backend.
b) jax_enable_x64 must be set, otherwise 64-bit types aren't available and bad things may happen. We should probably removed type-squashing from the JaxComputationBuilder class.
2019-01-28 14:29:17 -05:00
Matthew Johnson
c7ce442084 initial pmap/pxla code, pair-coded w/ @dougalm 2019-01-24 13:28:54 -08:00
Peter Hawkins
62d946123c Use complexfloating instead of complex to suppress NumPy warning.
Fixes #255.
2019-01-17 13:41:40 -05:00
Peter Hawkins
4792b9bed3
Merge pull request #231 from hawkinsp/complex
Add preliminary support for np.complex128.
2019-01-15 11:44:27 -05:00
Peter Hawkins
5fac477a8a Fix bug in scatter transpose rule.
Add some simple gather and scatter tests.
2019-01-14 14:33:40 -05:00
Peter Hawkins
39257b2442 Work on scatter JVP/transpose. 2019-01-14 10:28:35 -05:00
Peter Hawkins
9812bea1ef Merge remote-tracking branch 'google/master' into gather_scatter 2019-01-14 08:24:01 -05:00
Matthew Johnson
54886bd310 add sort_key_val batching rule (fixes #221) 2019-01-13 11:56:36 -08:00
Matthew Johnson
5a4713f108 add tests for np.sort (c.f. #221) 2019-01-13 11:56:36 -08:00
Peter Hawkins
d43c65dcd8 Add preliminary support for np.complex128.
Only lightly tested.
2019-01-11 18:22:43 -05:00
Peter Hawkins
65efd45fc9 Test more Numpy ops for complex types.
Fix a number of ops that did not handle complex numbers the same way as regular numpy.
2019-01-11 14:49:42 -05:00
Peter Hawkins
adaa344400 More progress on scatter/gather. 2019-01-09 10:58:44 -05:00
Peter Hawkins
5e48420f12 First attempt at lax.gather and lax.scatter. 2019-01-08 21:34:48 -05:00
Matthew Johnson
280b3fe2fc python3 likes list(map(...)) 2019-01-07 16:54:51 -08:00
Matthew Johnson
47eb8fa17c add another concreteness check to lax.iota 2019-01-07 12:38:46 -08:00
Matthew Johnson
df87d5ce43 make lax.full require concrete shapes
improves error message for #204
2019-01-07 12:31:01 -08:00
Matthew Johnson
0f7c7c4eab generalize jacfwd and jacrev to handle pytrees 2019-01-06 12:49:41 -08:00
Matthew Johnson
ad4322c5da playing around with flattening functions 2019-01-06 12:49:35 -08:00