18 Commits

Author SHA1 Message Date
Matthew Johnson
f8aa563db1 make jax.numpy.array(3) give 0D array, not scalar
the mechanism is to use lax.reshape (which was already there) and avoid
the optimization that skipped actually calling reshape_p.bind

fixes #121
2019-05-20 11:49:09 -07:00
Matthew Johnson
42a1ad4307 change dtype promotion behavior for jit-invariance
Here are two desiderata for jax.numpy dtype promotion behavior:
1. follow what NumPy does
2. be invariant to `@jit`

The latter is much more important, so whenever the two are in tension we
prefer the latter. (Also we already can't do a perfect job following
what NumPy does, e.g. around its value-dependent dtype promotion logic.)

Issue #732 showed our code had a special behavior that essentially
handled a case of the former desideratum but also broke the latter. #732
also showed us (again) that our tests really should cover Python
scalars.

In summary, in this commit:
* revise jax.numpy dtype promotion behavior to be invariant to `@jit`
* add Python scalar types to lax_numpy tests
* simplify and update kron implementation to fix dtype issues
2019-05-19 18:49:16 -07:00
Matthew Johnson
8f9e4b1260 BroadcastedIota needs integer type (fixes #728) 2019-05-17 12:46:11 -07:00
Matthew Johnson
b1fd8e6eb6 add test for DeviceConstant repr 2019-05-17 12:38:45 -07:00
Peter Hawkins
367833bea2 Changes for compatibility with a upcoming Jaxlib update.
Shape.abstract_arrays will only accept dtypes, not scalar type objects.
Add long to the set of types known to abstract_arrays in Python 2.
Make api_test.py accepting of long values in shapes.
2019-05-08 20:32:24 -04:00
Peter Hawkins
68f2cb4491 Implement JVP rule for reduce_prod().
This is sufficient to compute first-order derivatives of a product reduction (although not second-order derivatives because there is no JVP for reduce-window-prod).
2019-05-05 14:37:25 -04:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Peter Hawkins
6b4c74b182 Add batching rule for dynamic_update_slice. 2019-04-30 11:48:53 -04:00
Matthew Johnson
076dd0fd99
Merge pull request #636 from google/custom-vjps
add support for custom VJP definitions
2019-04-23 18:46:47 -07:00
Matthew Johnson
85755820bb add defvjp functions for custom VJPs
c.f. #116, which won't be closed until we add documentation
2019-04-23 17:47:28 -07:00
James Bradbury
55d74d8624 add VJP for lax._select_and_gather_add (3rd-order grad of maxpool) 2019-04-20 17:06:56 -07:00
James Bradbury
b940245730 add VJP for lax._select_and_scatter_add (2nd-order grad of maxpool) 2019-04-20 17:06:35 -07:00
Yutong Zhao
bbf0d5c55e Add brcast to deal with inconsistent shapes. 2019-04-17 20:54:01 -04:00
Yutong Zhao
dfd3d93350 Fix bug in order of y,x in grad of atan2. 2019-04-17 19:57:42 -04:00
Yutong Zhao
36e5ec2189 Add jax tests and fix style. 2019-04-17 19:53:06 -04:00
Yutong Zhao
60539a2612 Implement jvp for atan2 2019-04-17 18:52:22 -04:00
Peter Hawkins
407306293f Update lax documentation to reflect new code organization. 2019-04-15 12:16:14 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00