13 Commits

Author SHA1 Message Date
Peter Hawkins
68f2cb4491 Implement JVP rule for reduce_prod().
This is sufficient to compute first-order derivatives of a product reduction (although not second-order derivatives because there is no JVP for reduce-window-prod).
2019-05-05 14:37:25 -04:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Peter Hawkins
6b4c74b182 Add batching rule for dynamic_update_slice. 2019-04-30 11:48:53 -04:00
Matthew Johnson
076dd0fd99
Merge pull request #636 from google/custom-vjps
add support for custom VJP definitions
2019-04-23 18:46:47 -07:00
Matthew Johnson
85755820bb add defvjp functions for custom VJPs
c.f. #116, which won't be closed until we add documentation
2019-04-23 17:47:28 -07:00
James Bradbury
55d74d8624 add VJP for lax._select_and_gather_add (3rd-order grad of maxpool) 2019-04-20 17:06:56 -07:00
James Bradbury
b940245730 add VJP for lax._select_and_scatter_add (2nd-order grad of maxpool) 2019-04-20 17:06:35 -07:00
Yutong Zhao
bbf0d5c55e Add brcast to deal with inconsistent shapes. 2019-04-17 20:54:01 -04:00
Yutong Zhao
dfd3d93350 Fix bug in order of y,x in grad of atan2. 2019-04-17 19:57:42 -04:00
Yutong Zhao
36e5ec2189 Add jax tests and fix style. 2019-04-17 19:53:06 -04:00
Yutong Zhao
60539a2612 Implement jvp for atan2 2019-04-17 18:52:22 -04:00
Peter Hawkins
407306293f Update lax documentation to reflect new code organization. 2019-04-15 12:16:14 -04:00
Matthew Johnson
0cf14837c9 make a lax package, revert control flow names (#607)
c.f. #597
pair=skyewm
2019-04-12 16:28:40 -07:00