31 Commits

Author SHA1 Message Date
Peter Hawkins
15361c4dde Add support for integer dot operations.
Lower to a sum of products for integers since XLA currently lacks support for integer dots.
2019-06-06 17:24:43 -04:00
Matthew Johnson
1feefd10ac
Merge pull request #820 from google/simplify-gather-shape-rule
simplify gather shape rule
2019-06-05 19:14:27 -07:00
Matthew Johnson
9e49500c54 simplify gather shape rule
Co-authored-by: Roy Frostig <frostig@google.com>
2019-06-05 17:06:42 -07:00
James Bradbury
9bc5f2aecf fix bug in dot batching rule 2019-06-05 15:17:06 -07:00
Peter Hawkins
e63bd4d592 Add domain test to atanh implementation.
Disable some tests on tpu.
2019-05-31 17:22:19 -04:00
Matthew Johnson
25262edb92 Make pmap lax.psum(1, 'i') and pxla.axis_index('i') work
The implementation mechanism is to use a bit of dynamic context to model
the axis name environment at trace time, and for the environment to
track how an axis name maps to an axis size and the corresponding trace
(i.e. the JaxprTrace instance). With that information, we can lift
special primitives that take axis_name parameters into the trace as
needed without having a data dependence on the input.
2019-05-29 20:13:07 -07:00
Peter Hawkins
0293ecbb5e Add support for vmap of scatter where indices but not updates are batched. 2019-05-29 17:13:46 -04:00
Peter Hawkins
6e1ec38a14 Improve behavior of a number of math functions for extreme inputs.
Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.

Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
2019-05-29 12:51:24 -04:00
Matthew Johnson
778435a90b undo #503 in favor of new literal staging method 2019-05-29 08:12:05 -07:00
Peter Hawkins
cfdf1cd3e9 Propagate symbolic zeros instead of instantiating them. 2019-05-28 15:41:27 -04:00
Peter Hawkins
7ee59e96d3 Instantiate symbolic zeros in the scatter-add transpose rule.
Fixes #776.
2019-05-28 10:30:58 -04:00
Peter Hawkins
9e68d9114e Use more numerically stable formulation of tanh gradient. 2019-05-25 09:59:10 -04:00
Matthew Johnson
ec6e39b9b2 fix negative index handling in index_take
fixes #751
2019-05-21 19:18:29 -07:00
Matthew Johnson
f8aa563db1 make jax.numpy.array(3) give 0D array, not scalar
the mechanism is to use lax.reshape (which was already there) and avoid
the optimization that skipped actually calling reshape_p.bind

fixes #121
2019-05-20 11:49:09 -07:00
Matthew Johnson
42a1ad4307 change dtype promotion behavior for jit-invariance
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`

The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)

Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.

In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
2019-05-19 18:49:16 -07:00
Matthew Johnson
8f9e4b1260 BroadcastedIota needs integer type (fixes #728) 2019-05-17 12:46:11 -07:00
Matthew Johnson
b1fd8e6eb6 add test for DeviceConstant repr 2019-05-17 12:38:45 -07:00
Peter Hawkins
367833bea2 Changes for compatibility with a upcoming Jaxlib update.
Shape.abstract_arrays will only accept dtypes, not scalar type objects.
Add long to the set of types known to abstract_arrays in Python 2.
Make api_test.py accepting of long values in shapes.
2019-05-08 20:32:24 -04:00
Peter Hawkins
68f2cb4491 Implement JVP rule for reduce_prod().
This is sufficient to compute first-order derivatives of a product reduction (although not second-order derivatives because there is no JVP for reduce-window-prod).
2019-05-05 14:37:25 -04:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Peter Hawkins
6b4c74b182 Add batching rule for dynamic_update_slice. 2019-04-30 11:48:53 -04:00
Matthew Johnson
076dd0fd99
Merge pull request #636 from google/custom-vjps
add support for custom VJP definitions
2019-04-23 18:46:47 -07:00
Matthew Johnson
85755820bb add defvjp functions for custom VJPs
c.f. #116, which won't be closed until we add documentation
2019-04-23 17:47:28 -07:00
James Bradbury
55d74d8624 add VJP for lax._select_and_gather_add (3rd-order grad of maxpool) 2019-04-20 17:06:56 -07:00
James Bradbury
b940245730 add VJP for lax._select_and_scatter_add (2nd-order grad of maxpool) 2019-04-20 17:06:35 -07:00
Yutong Zhao
bbf0d5c55e Add brcast to deal with inconsistent shapes. 2019-04-17 20:54:01 -04:00
Yutong Zhao
dfd3d93350 Fix bug in order of y,x in grad of atan2. 2019-04-17 19:57:42 -04:00
Yutong Zhao
36e5ec2189 Add jax tests and fix style. 2019-04-17 19:53:06 -04:00
Yutong Zhao
60539a2612 Implement jvp for atan2 2019-04-17 18:52:22 -04:00
Peter Hawkins
407306293f Update lax documentation to reflect new code organization. 2019-04-15 12:16:14 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00