24 Commits

Author SHA1 Message Date
Peter Hawkins
6e1ec38a14 Improve behavior of a number of math functions for extreme inputs.
Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.

Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
2019-05-29 12:51:24 -04:00
Matthew Johnson
778435a90b undo #503 in favor of new literal staging method 2019-05-29 08:12:05 -07:00
Peter Hawkins
cfdf1cd3e9 Propagate symbolic zeros instead of instantiating them. 2019-05-28 15:41:27 -04:00
Peter Hawkins
7ee59e96d3 Instantiate symbolic zeros in the scatter-add transpose rule.
Fixes #776.
2019-05-28 10:30:58 -04:00
Peter Hawkins
9e68d9114e Use more numerically stable formulation of tanh gradient. 2019-05-25 09:59:10 -04:00
Matthew Johnson
ec6e39b9b2 fix negative index handling in index_take
fixes #751
2019-05-21 19:18:29 -07:00
Matthew Johnson
f8aa563db1 make jax.numpy.array(3) give 0D array, not scalar
the mechanism is to use lax.reshape (which was already there) and avoid
the optimization that skipped actually calling reshape_p.bind

fixes #121
2019-05-20 11:49:09 -07:00
Matthew Johnson
42a1ad4307 change dtype promotion behavior for jit-invariance
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`

The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)

Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.

In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
2019-05-19 18:49:16 -07:00
Matthew Johnson
8f9e4b1260 BroadcastedIota needs integer type (fixes #728) 2019-05-17 12:46:11 -07:00
Matthew Johnson
b1fd8e6eb6 add test for DeviceConstant repr 2019-05-17 12:38:45 -07:00
Peter Hawkins
367833bea2 Changes for compatibility with a upcoming Jaxlib update.
Shape.abstract_arrays will only accept dtypes, not scalar type objects.
Add long to the set of types known to abstract_arrays in Python 2.
Make api_test.py accepting of long values in shapes.
2019-05-08 20:32:24 -04:00
Peter Hawkins
68f2cb4491 Implement JVP rule for reduce_prod().
This is sufficient to compute first-order derivatives of a product reduction (although not second-order derivatives because there is no JVP for reduce-window-prod).
2019-05-05 14:37:25 -04:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Peter Hawkins
6b4c74b182 Add batching rule for dynamic_update_slice. 2019-04-30 11:48:53 -04:00
Matthew Johnson
076dd0fd99
Merge pull request #636 from google/custom-vjps
add support for custom VJP definitions
2019-04-23 18:46:47 -07:00
Matthew Johnson
85755820bb add defvjp functions for custom VJPs
c.f. #116, which won't be closed until we add documentation
2019-04-23 17:47:28 -07:00
James Bradbury
55d74d8624 add VJP for lax._select_and_gather_add (3rd-order grad of maxpool) 2019-04-20 17:06:56 -07:00
James Bradbury
b940245730 add VJP for lax._select_and_scatter_add (2nd-order grad of maxpool) 2019-04-20 17:06:35 -07:00
Yutong Zhao
bbf0d5c55e Add brcast to deal with inconsistent shapes. 2019-04-17 20:54:01 -04:00
Yutong Zhao
dfd3d93350 Fix bug in order of y,x in grad of atan2. 2019-04-17 19:57:42 -04:00
Yutong Zhao
36e5ec2189 Add jax tests and fix style. 2019-04-17 19:53:06 -04:00
Yutong Zhao
60539a2612 Implement jvp for atan2 2019-04-17 18:52:22 -04:00
Peter Hawkins
407306293f Update lax documentation to reflect new code organization. 2019-04-15 12:16:14 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00