114 Commits

Author SHA1 Message Date
Peter Hawkins
4e6523dbae
Merge pull request #408 from hawkinsp/master
Add some doc strings to lax primitives.
2019-02-19 14:13:17 -05:00
Peter Hawkins
81d43ee083 Use neg(t) rather than -t in sub_p transpose rule. 2019-02-19 11:59:37 -05:00
Peter Hawkins
8009bf8ded Add transpose rule for sub_p. 2019-02-19 11:46:22 -05:00
Peter Hawkins
c5aa87f4f1 Add some doc strings to lax primitives.
Since lax is a semipublic API, its public methods need at least minimal documentation. Many of the docstrings added in this PR are somewhat redundant, but at least a few contain useful information, and the documentation reads better with at least some minimal text for each function.

Hide some methods that shouldn't be public from the lax API docs.
2019-02-19 11:30:31 -05:00
Peter Hawkins
97fb6f19b1 Fix formatting of lax.while_loop and lax.fori_loop doc comments. 2019-02-18 16:25:12 -05:00
Matthew Johnson
13834ee4f5 add "yet" to while_loop rev-autodiff statement 2019-02-18 12:58:35 -08:00
Matthew Johnson
5639876405 fix typo 2019-02-18 12:44:36 -08:00
Matthew Johnson
042a20d2da improve loop construct docs, remove foreach_loop 2019-02-18 12:41:07 -08:00
Matthew Johnson
60865a5bb5 fix broken dot batch rule case 2019-02-17 09:34:49 -08:00
Matthew Johnson
3bf5f3326f
Merge pull request #383 from google/issue347
fix nan handling in pow jvp (fixes #347)
2019-02-16 08:10:18 -08:00
Matthew Johnson
6a9b741ebc add comment in pow_jvp_lhs about calling _safe_mul 2019-02-16 08:08:04 -08:00
Matthew Johnson
58749c0a13 add lax._safe_mul with 0*inf=0, used in pow jvp 2019-02-15 18:33:24 -08:00
Matthew Johnson
1cbf49a404 wip 2019-02-15 15:44:49 -08:00
Peter Hawkins
2292681128 Fix dimension numbers in LHS transpose rule for conv_general_dilated.
Fixes #380.
2019-02-15 12:54:02 -05:00
Matthew Johnson
98dcf264e9 fix nan handling in pow jvp (fixes #347) 2019-02-15 07:04:57 -08:00
Matthew Johnson
2ee457ecab
Merge pull request #369 from google/loop-improvements
add random.fold_in, update mnist_vae.py w/ loop improvements
2019-02-13 10:20:03 -08:00
Matthew Johnson
89dc3eb88e rename lax._while_loop -> lax.while_loop 2019-02-13 09:56:53 -08:00
Matthew Johnson
78fd9e1a10 debug cholesky grad, remove stale dot_general check 2019-02-13 09:18:28 -08:00
Matthew Johnson
8df660e9ea use more _const and _constant_like helpers 2019-02-13 08:25:11 -08:00
Matthew Johnson
cad7db762b improve numpy dtype promo logic on Python scalars 2019-02-13 08:06:37 -08:00
Matthew Johnson
adaea811fc fix transpose batching rule bug, add tests 2019-02-12 07:26:32 -08:00
Matthew Johnson
f8b48cb1c4 add comment explaining scatter batching rule logic 2019-02-11 12:46:17 -08:00
Matthew Johnson
e3b9df14a8 complete scatter batching rule 2019-02-11 11:40:08 -08:00
Matthew Johnson
65c023d231 start adding scatter batching rule 2019-02-11 11:36:09 -08:00
Matthew Johnson
90d92a5a5c fix gather batching rule bug 2019-02-11 11:30:44 -08:00
Matthew Johnson
d5ee720aea more testing of gather batching rule 2019-02-11 11:21:29 -08:00
Matthew Johnson
cccc0304fd finish gather batching rule, pair w/ @hawkinsp 2019-02-11 09:30:24 -08:00
Matthew Johnson
6dfe2d6e36 add numpy indexing batching tests 2019-02-11 09:30:21 -08:00
Matthew Johnson
b53eb241f7 gather passing all operand vmap tests 2019-02-11 09:30:13 -08:00
Matthew Johnson
b6cb3509cd progress on a gather vmap rule, PAIR=hawkinsp 2019-02-10 08:06:50 -08:00
Matthew Johnson
cde5c925fd start to sketch out gather batching rule (WIP) 2019-02-10 08:06:50 -08:00
Dougal Maclaurin
ce74bc55ce Handle closed-over tracers in while loop cond and body functions 2019-02-06 12:58:32 -05:00
Matthew Johnson
1636d058df fix lax.full handling of DeviceConstant scalars
fixes #330
2019-02-06 09:23:34 -08:00
Matthew Johnson
bf7a438c94 add more special cases of select batching rule 2019-02-03 14:00:51 -08:00
Matthew Johnson
44cffd0053
Merge pull request #310 from google/issue292
improve error messages for lax.slice/index funs
2019-02-03 13:48:11 -08:00
Matthew Johnson
583b654769 add an efficient special case to select batch rule 2019-02-03 10:01:06 -08:00
Matthew Johnson
5344e7aea0 add lax.select broadcasting tests, improve rule 2019-02-03 09:52:33 -08:00
Matthew Johnson
fe96c15d49 generalize select batch rule (fixes #311) 2019-02-03 09:27:03 -08:00
Matthew Johnson
0afb6202c9 improve error messages for lax.slice/index funs
c.f. #292
2019-02-02 21:41:06 -08:00
Matthew Johnson
055beb9037 Merge branch 'master' into pjit 2019-02-02 13:30:54 -08:00
Matthew Johnson
f5cffd722a delete more dead index_take code 2019-02-02 12:17:11 -08:00
Matthew Johnson
9f3060a0e6 index_take in terms of gather, delete index_untake
(c.f. #304)
2019-02-02 09:22:37 -08:00
Matthew Johnson
f69dda9641 fix merging issue 2019-02-01 17:39:49 -08:00
Matthew Johnson
08dc6994f5 partial progress 2019-02-01 17:05:49 -08:00
Peter Hawkins
5517347cc9 Reexpose reduce_window_shape_tuple since it has external users.
Fix accidental removal of rev() batching rule.
2019-02-01 16:29:53 -05:00
Peter Hawkins
09201c72bc Prefix most rules in lax module with underscores to improve generated doc readability.
Underscore-prefixed functions are automatically hidden from generated documentation. `lax` is a semi-public API, so this is a first step towards making its documentation useful.
2019-02-01 16:03:45 -05:00
Peter Hawkins
66c7a4248a
Merge pull request #303 from hawkinsp/minmax
Fix gradient for `np.amin` and `np.amax`.
2019-02-01 14:24:22 -05:00
Peter Hawkins
fb659e22b9 Fix gradient for np.amin and np.amax.
The JVP rule for `lax.reduce` depends on being able to identify the reducer as a monoid reducer. To get the correct behavior on complex numbers, `np.{amin,amax}` passed a non-standard reducer that compared complex numbers lexicographically as (real, imaginary) pairs. However, this prevented the gradient rule from identifying the reducer.

Instead, change the `lax.min` and `lax.max` to use the Numpy semantics when comparing complex numbers, and change `np.amin` and `np.amax` to use them.

Move the `np._broadcast_shapes` helper into `lax.py` as `lax.broadcast_shapes`.
2019-02-01 11:53:12 -05:00
Matthew Johnson
670f14a2ee
Merge pull request #300 from alexalemi/rev_batching
Rev batching
2019-02-01 07:29:13 -08:00
Alex Alemi
a9b221a1d3 Add batching rule for rev. 2019-01-31 21:47:05 -08:00