41 Commits

Author SHA1 Message Date
Matthew Johnson
a6e374ddb9
Merge pull request #858 from google/improve-conv-batching
improve efficiency of conv batching rules (i.e. vmap rules)
2019-06-17 19:23:07 -07:00
Peter Hawkins
0190684ee2
Merge pull request #866 from hawkinsp/master
Implement np.ix_, for non-bool inputs.
2019-06-17 18:39:45 -06:00
Peter Hawkins
ec685bf8ae Implement np.ix_, for non-bool inputs. 2019-06-17 17:08:27 -04:00
Matthew Johnson
a1219f10e5
Merge pull request #865 from google/fix-reshape-grad-bug
fix special-form reshape transpose bug (and add tests)
2019-06-17 12:29:45 -07:00
Matthew Johnson
fef68deef6 fix reshape transpose bug (and add tests)
This version of reshape (taking a `dimensions` argument, which
effectively fuses in a transpose) seems only to be used in the JVP rule
for lax._reduce_prod (basically np.product), but its transpose rule was
totally busted and untested.
2019-06-17 11:54:36 -07:00
Peter Hawkins
129e73a258 Use lax.full to create zeros array in gather transpose rule. 2019-06-17 08:19:36 -04:00
Matthew Johnson
ff29d582e8 t # This is a combination of 2 commits.
make all conv vmap rules generate a single call

also plumb feature_group_count and batch_group_count everywhere
2019-06-16 07:50:19 -07:00
Matthew Johnson
1262ca9b30 improve conv rhs batching, add systematic test 2019-06-15 12:01:20 -07:00
James Bradbury
edfe52035b Fix dot batch rule bug and add test + check 2019-06-12 18:02:01 -07:00
Peter Hawkins
bd389b7fcf Add bug number for integer dots. 2019-06-06 19:04:11 -04:00
Peter Hawkins
15361c4dde Add support for integer dot operations.
Lower to a sum of products for integers since XLA currently lacks support for integer dots.
2019-06-06 17:24:43 -04:00
Matthew Johnson
1feefd10ac
Merge pull request #820 from google/simplify-gather-shape-rule
simplify gather shape rule
2019-06-05 19:14:27 -07:00
Matthew Johnson
9e49500c54 simplify gather shape rule
Co-authored-by: Roy Frostig <frostig@google.com>
2019-06-05 17:06:42 -07:00
James Bradbury
9bc5f2aecf fix bug in dot batching rule 2019-06-05 15:17:06 -07:00
Peter Hawkins
e63bd4d592 Add domain test to atanh implementation.
Disable some tests on tpu.
2019-05-31 17:22:19 -04:00
Matthew Johnson
25262edb92 Make pmap lax.psum(1, 'i') and pxla.axis_index('i') work
The implementation mechanism is to use a bit of dynamic context to model
the axis name environment at trace time, and for the environment to
track how an axis name maps to an axis size and the corresponding trace
(i.e. the JaxprTrace instance). With that information, we can lift
special primitives that take axis_name parameters into the trace as
needed without having a data dependence on the input.
2019-05-29 20:13:07 -07:00
Peter Hawkins
0293ecbb5e Add support for vmap of scatter where indices but not updates are batched. 2019-05-29 17:13:46 -04:00
Peter Hawkins
6e1ec38a14 Improve behavior of a number of math functions for extreme inputs.
Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.

Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
2019-05-29 12:51:24 -04:00
Matthew Johnson
778435a90b undo #503 in favor of new literal staging method 2019-05-29 08:12:05 -07:00
Peter Hawkins
cfdf1cd3e9 Propagate symbolic zeros instead of instantiating them. 2019-05-28 15:41:27 -04:00
Peter Hawkins
7ee59e96d3 Instantiate symbolic zeros in the scatter-add transpose rule.
Fixes #776.
2019-05-28 10:30:58 -04:00
Peter Hawkins
9e68d9114e Use more numerically stable formulation of tanh gradient. 2019-05-25 09:59:10 -04:00
Matthew Johnson
ec6e39b9b2 fix negative index handling in index_take
fixes #751
2019-05-21 19:18:29 -07:00
Matthew Johnson
f8aa563db1 make jax.numpy.array(3) give 0D array, not scalar
the mechanism is to use lax.reshape (which was already there) and avoid
the optimization that skipped actually calling reshape_p.bind

fixes #121
2019-05-20 11:49:09 -07:00
Matthew Johnson
42a1ad4307 change dtype promotion behavior for jit-invariance
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`

The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)

Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.

In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
2019-05-19 18:49:16 -07:00
Matthew Johnson
8f9e4b1260 BroadcastedIota needs integer type (fixes #728) 2019-05-17 12:46:11 -07:00
Matthew Johnson
b1fd8e6eb6 add test for DeviceConstant repr 2019-05-17 12:38:45 -07:00
Peter Hawkins
367833bea2 Changes for compatibility with a upcoming Jaxlib update.
Shape.abstract_arrays will only accept dtypes, not scalar type objects.
Add long to the set of types known to abstract_arrays in Python 2.
Make api_test.py accepting of long values in shapes.
2019-05-08 20:32:24 -04:00
Peter Hawkins
68f2cb4491 Implement JVP rule for reduce_prod().
This is sufficient to compute first-order derivatives of a product reduction (although not second-order derivatives because there is no JVP for reduce-window-prod).
2019-05-05 14:37:25 -04:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Peter Hawkins
6b4c74b182 Add batching rule for dynamic_update_slice. 2019-04-30 11:48:53 -04:00
Matthew Johnson
076dd0fd99
Merge pull request #636 from google/custom-vjps
add support for custom VJP definitions
2019-04-23 18:46:47 -07:00
Matthew Johnson
85755820bb add defvjp functions for custom VJPs
c.f. #116, which won't be closed until we add documentation
2019-04-23 17:47:28 -07:00
James Bradbury
55d74d8624 add VJP for lax._select_and_gather_add (3rd-order grad of maxpool) 2019-04-20 17:06:56 -07:00
James Bradbury
b940245730 add VJP for lax._select_and_scatter_add (2nd-order grad of maxpool) 2019-04-20 17:06:35 -07:00
Yutong Zhao
bbf0d5c55e Add brcast to deal with inconsistent shapes. 2019-04-17 20:54:01 -04:00
Yutong Zhao
dfd3d93350 Fix bug in order of y,x in grad of atan2. 2019-04-17 19:57:42 -04:00
Yutong Zhao
36e5ec2189 Add jax tests and fix style. 2019-04-17 19:53:06 -04:00
Yutong Zhao
60539a2612 Implement jvp for atan2 2019-04-17 18:52:22 -04:00
Peter Hawkins
407306293f Update lax documentation to reflect new code organization. 2019-04-15 12:16:14 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00