Matthew Johnson
6dfe2d6e36
add numpy indexing batching tests
2019-02-11 09:30:21 -08:00
Matthew Johnson
b53eb241f7
gather passing all operand vmap tests
2019-02-11 09:30:13 -08:00
Matthew Johnson
b6cb3509cd
progress on a gather vmap rule, PAIR=hawkinsp
2019-02-10 08:06:50 -08:00
Matthew Johnson
cde5c925fd
start to sketch out gather batching rule (WIP)
2019-02-10 08:06:50 -08:00
Dougal Maclaurin
ce74bc55ce
Handle closed-over tracers in while loop cond and body functions
2019-02-06 12:58:32 -05:00
Matthew Johnson
1636d058df
fix lax.full handling of DeviceConstant scalars
...
fixes #330
2019-02-06 09:23:34 -08:00
Matthew Johnson
bf7a438c94
add more special cases of select batching rule
2019-02-03 14:00:51 -08:00
Matthew Johnson
44cffd0053
Merge pull request #310 from google/issue292
...
improve error messages for lax.slice/index funs
2019-02-03 13:48:11 -08:00
Matthew Johnson
583b654769
add an efficient special case to select batch rule
2019-02-03 10:01:06 -08:00
Matthew Johnson
5344e7aea0
add lax.select broadcasting tests, improve rule
2019-02-03 09:52:33 -08:00
Matthew Johnson
fe96c15d49
generalize select batch rule ( fixes #311 )
2019-02-03 09:27:03 -08:00
Matthew Johnson
0afb6202c9
improve error messages for lax.slice/index funs
...
c.f. #292
2019-02-02 21:41:06 -08:00
Matthew Johnson
055beb9037
Merge branch 'master' into pjit
2019-02-02 13:30:54 -08:00
Matthew Johnson
f5cffd722a
delete more dead index_take code
2019-02-02 12:17:11 -08:00
Matthew Johnson
9f3060a0e6
index_take in terms of gather, delete index_untake
...
(c.f. #304 )
2019-02-02 09:22:37 -08:00
Matthew Johnson
f69dda9641
fix merging issue
2019-02-01 17:39:49 -08:00
Matthew Johnson
08dc6994f5
partial progress
2019-02-01 17:05:49 -08:00
Peter Hawkins
5517347cc9
Reexpose reduce_window_shape_tuple since it has external users.
...
Fix accidental removal of rev() batching rule.
2019-02-01 16:29:53 -05:00
Peter Hawkins
09201c72bc
Prefix most rules in lax module with underscores to improve generated doc readability.
...
Underscore-prefixed functions are automatically hidden from generated documentation. `lax` is a semi-public API, so this is a first step towards making its documentation useful.
2019-02-01 16:03:45 -05:00
Peter Hawkins
66c7a4248a
Merge pull request #303 from hawkinsp/minmax
...
Fix gradient for `np.amin` and `np.amax`.
2019-02-01 14:24:22 -05:00
Peter Hawkins
fb659e22b9
Fix gradient for np.amin
and np.amax
.
...
The JVP rule for `lax.reduce` depends on being able to identify the reducer as a monoid reducer. To get the correct behavior on complex numbers, `np.{amin,amax}` passed a non-standard reducer that compared complex numbers lexicographically as (real, imaginary) pairs. However, this prevented the gradient rule from identifying the reducer.
Instead, change the `lax.min` and `lax.max` to use the Numpy semantics when comparing complex numbers, and change `np.amin` and `np.amax` to use them.
Move the `np._broadcast_shapes` helper into `lax.py` as `lax.broadcast_shapes`.
2019-02-01 11:53:12 -05:00
Matthew Johnson
670f14a2ee
Merge pull request #300 from alexalemi/rev_batching
...
Rev batching
2019-02-01 07:29:13 -08:00
Alex Alemi
a9b221a1d3
Add batching rule for rev.
2019-01-31 21:47:05 -08:00
Peter Hawkins
26f85310e5
Implement np.{cumsum,cumprod,nancumsum,nancumprod}.
2019-01-31 18:56:06 -05:00
Matthew Johnson
c293b3775b
add basic lax.stop_gradient primitive
2019-01-30 10:39:35 -08:00
Jonas Rauber
334581d5e9
fixed TypeError caused by body_fun of foreach loop
2019-01-29 17:31:10 +01:00
Matthew Johnson
4db50bc459
Merge pull request #288 from sschoenholz/patch-1
...
Added batching rules for convolutions + pooling.
2019-01-28 20:28:39 -08:00
sschoenholz
3b8d43cefa
Misc. small fixes.
2019-01-28 19:08:05 -08:00
Peter Hawkins
5dc15868ee
Make translation rule for select_and_gather_add work even when --jax_enable_x64 is disabled.
...
Add support for constants whose types are not canonicalized by passing an optional flag to constant factories.
I am not entirely happy with the type canonicalization approach, but it seems good enough for this specific use case.
2019-01-28 18:41:27 -05:00
sschoenholz
a15bad401f
Added batching rules for convolutions + pooling.
...
Added batching rules:
conv_general_dilated_batch_rule
select_and_scatter_add_batch_rule
reduce_window_max_batch_rule
reduce_window_sum_batch_rule
2019-01-28 14:33:57 -08:00
Peter Hawkins
af69d341a7
Merge pull request #286 from hawkinsp/maxpool
...
Implement translation rule for select_and_gather_add (issue #274 ).
2019-01-28 15:44:03 -05:00
Peter Hawkins
9f84455fb2
Check for jax_enable_x64 in select_and_gather_add translation rule.
2019-01-28 15:10:58 -05:00
Peter Hawkins
f76134e460
Implement transpose rule for select_and_gather_add (issue #274 ).
...
There are a couple of caveats that mean that we shouldn't close the issue yet:
a) we need a jaxlib update to generalize the ReduceWindow support in the XLA/CPU backend.
b) jax_enable_x64 must be set, otherwise 64-bit types aren't available and bad things may happen. We should probably removed type-squashing from the JaxComputationBuilder class.
2019-01-28 14:29:17 -05:00
Matthew Johnson
c7ce442084
initial pmap/pxla code, pair-coded w/ @dougalm
2019-01-24 13:28:54 -08:00
Peter Hawkins
62d946123c
Use complexfloating instead of complex to suppress NumPy warning.
...
Fixes #255 .
2019-01-17 13:41:40 -05:00
Peter Hawkins
4792b9bed3
Merge pull request #231 from hawkinsp/complex
...
Add preliminary support for np.complex128.
2019-01-15 11:44:27 -05:00
Peter Hawkins
5fac477a8a
Fix bug in scatter transpose rule.
...
Add some simple gather and scatter tests.
2019-01-14 14:33:40 -05:00
Peter Hawkins
39257b2442
Work on scatter JVP/transpose.
2019-01-14 10:28:35 -05:00
Peter Hawkins
9812bea1ef
Merge remote-tracking branch 'google/master' into gather_scatter
2019-01-14 08:24:01 -05:00
Matthew Johnson
54886bd310
add sort_key_val batching rule ( fixes #221 )
2019-01-13 11:56:36 -08:00
Matthew Johnson
5a4713f108
add tests for np.sort (c.f. #221 )
2019-01-13 11:56:36 -08:00
Peter Hawkins
d43c65dcd8
Add preliminary support for np.complex128.
...
Only lightly tested.
2019-01-11 18:22:43 -05:00
Peter Hawkins
65efd45fc9
Test more Numpy ops for complex types.
...
Fix a number of ops that did not handle complex numbers the same way as regular numpy.
2019-01-11 14:49:42 -05:00
Peter Hawkins
adaa344400
More progress on scatter/gather.
2019-01-09 10:58:44 -05:00
Peter Hawkins
5e48420f12
First attempt at lax.gather and lax.scatter.
2019-01-08 21:34:48 -05:00
Matthew Johnson
280b3fe2fc
python3 likes list(map(...))
2019-01-07 16:54:51 -08:00
Matthew Johnson
47eb8fa17c
add another concreteness check to lax.iota
2019-01-07 12:38:46 -08:00
Matthew Johnson
df87d5ce43
make lax.full require concrete shapes
...
improves error message for #204
2019-01-07 12:31:01 -08:00
Matthew Johnson
0f7c7c4eab
generalize jacfwd and jacrev to handle pytrees
2019-01-06 12:49:41 -08:00
Matthew Johnson
ad4322c5da
playing around with flattening functions
2019-01-06 12:49:35 -08:00