119 Commits

Author SHA1 Message Date
Matthew Johnson
8bc4e379f5 make DeviceArray.__hash__ raise an error
Fixes #883 by adjusting the caching logic we use not to rely on
DeviceArray being hashable, also closing a long-standing TODO.

Also fixed a minor bug in lax.py which caused scalar DeviceArrays to
appear in the padding params of some convolutions (from using `max`
instead of `_max` in lax.py).
2019-06-19 10:12:13 -07:00
Matthew Johnson
eb01b8bfef improve linearize error message
fixes #871
2019-06-18 09:31:50 -07:00
Matthew Johnson
7c308dca29 implement reviewer suggestions 2019-06-11 06:45:09 -07:00
Matthew Johnson
121d78129b docstring improvements from @skye comments 2019-06-06 10:12:07 -07:00
Matthew Johnson
12b90ddf68 fix typo 2019-06-05 19:19:34 -07:00
Matthew Johnson
a6c41a323c finish drafting defvjp/defjvp docstrings 2019-06-05 19:13:33 -07:00
Matthew Johnson
cfaa49f884 improve custom_gradient docstring 2019-06-05 18:02:15 -07:00
Matthew Johnson
948ec8fbe8 add docstrings for defjvp and defjvp2 2019-06-05 17:56:18 -07:00
Matthew Johnson
ab20f0292c add docstring for defjvp_all 2019-06-05 17:34:14 -07:00
Matthew Johnson
720dec4072 add custom_gradient 2019-06-05 13:48:04 -07:00
Matthew Johnson
372d60bb08 add docstring to custom_transforms 2019-06-05 13:20:44 -07:00
Matthew Johnson
35e5e64416 make custom_transforms handle pytrees, add api.defvjp
With advice from @dougalm!
2019-06-05 12:29:25 -07:00
Matthew Johnson
dda95df519 fix duck typing in jax.eval_shape (cf. #798) 2019-06-01 09:50:24 -07:00
Matthew Johnson
11c512a194 add jax.eval_shape, fixes #798 2019-06-01 09:36:46 -07:00
Matthew Johnson
93e1143373 improve disable_jit docstring 2019-06-01 08:30:25 -07:00
Matthew Johnson
5a049feea9 further improve grad-of-nonscalar error message 2019-05-28 21:10:09 -07:00
Matthew Johnson
743727325f improve error message for grad of non-scalar funs 2019-05-28 20:25:38 -07:00
Matthew Johnson
a193b3592b when xla_computation sees no kwargs, don't make () 2019-05-22 14:39:34 -07:00
Matthew Johnson
b6031ffdd7 avoid packing leaf outputs for jit/pmap funs 2019-05-17 07:36:52 -07:00
Matthew Johnson
52a2fb3280 typo fixes in pmap docstring
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-05-15 08:25:32 -07:00
Matthew Johnson
38c6a8e899 add pmap docstring
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-05-15 08:13:30 -07:00
Matthew Johnson
29629931a1
Merge pull request #704 from google/differentiable-scan
Differentiable scan!
2019-05-13 10:26:09 -07:00
Matthew Johnson
65202821df improve core.typed_jaxpr arg typechecks 2019-05-11 10:45:14 -07:00
Matthew Johnson
5aea10c7b3 make static_argnums cache on value when possible
fixes #691
2019-05-09 20:00:24 -07:00
Matthew Johnson
19e0f8de45 fix tuple unpacking problems 2019-05-06 22:43:31 -07:00
Matthew Johnson
690301357d improve pmap error messages 2019-05-06 16:18:34 -07:00
Matthew Johnson
bf6c15b59a update pmap to flatten correctly (was a perf bug)
also temporarily avoid DeviceTuples in optimizer states
2019-05-06 12:09:54 -07:00
Matthew Johnson
642d2dc802 revies optimizers api, fix misc bugs
* add more optimizers numerical tests
* update examples and readme with new optimziers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
2019-05-03 12:44:52 -07:00
Matthew Johnson
15a4554ffb flatten out pytrees in jit at the api.py level 2019-05-03 11:39:37 -07:00
Matthew Johnson
7c5d683915 revise sharded result handling, misc cleanup 2019-05-03 08:06:55 -07:00
Matthew Johnson
87a150e567 add a tree_util.py module-level docstring 2019-05-02 08:02:01 -07:00
Peter Hawkins
43a31d0125 Update API documentation to help clarify https://github.com/google/jax/issues/648 . 2019-04-30 10:04:36 -04:00
Matthew Johnson
0cc8d7c2b1 update docstrings to fix #637 2019-04-23 18:21:33 -07:00
Matthew Johnson
d7096a42c5
make jacrev work w/ complex inputs, update errors (#610)
* make jacrev work w/ complex inputs, update errors

* fix up complex handling in jacfwd and jacrev
2019-04-13 13:22:45 -07:00
Matthew Johnson
849ea87b33 tree-map the real dtype check in api.py 2019-04-12 13:29:07 -07:00
Matthew Johnson
18671fa027 add error checks so that #603 isn't silent fail 2019-04-12 12:01:19 -07:00
Matthew Johnson
de2a5f725d add warning, fix typo in kwargs test and bug 2019-04-11 06:58:09 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Matthew Johnson
054d210a32 fix typo in xla_computation 2019-04-04 17:40:48 -07:00
Matthew Johnson
31e35b204a make np.reshape reflect on argument method
Reshapes should be cheap, but because `np.reshape` would always call
`lax.reshape` regardless of whether it was given a raw ndarray or one of
our DeviceArrays, it would sometimes copy ndarray data into a
DeviceArray. Our general policy is always to copy data to the device
(and lazily leave it there until the host needs it), but this policy
fell down here because of doing a reshape on data before a `pmap`'d
computation: the op-by-op `np.reshape` call put all the data on one
device, then the following `pmap` function had to copy everything back
to the host then re-distribute it to multiple devices. (The location of
what logical shards need to go on which device is computation-dependent,
so it's not something we can reliably do before actually getting to
execute the specific `pmap` function of interest.)

This commit makes a simple change in the `jax.numpy` layer to make
`np.reshape(x, shape)` try calling `x.reshape(shape)`, so that when `x`
is an ndarray it will stay an ndarray (without any transfer). This
change is not in the `lax` layer so that the `lax` policy can stay
simple (always copy to device). We might revise these decisions in the
future, and for now they're just under-the-hood optimizations, with the
ability for a user to directly call `onp` or `lax` if they want to be
careful about where data lives.

This commit also changed `jax.replicate` to replicate (with
`onp.broadcast_to`, which uses stride tricks instead of allocating more
memory) data to have a leading axis of size `device_count`. The previous
solution, based on `pmap`ing a function with a lexical closure, caused
re-compilation on every call.
2019-04-04 11:25:23 -07:00
Matthew Johnson
61ce283f3e graphviz: concat strings only at the end 2019-04-02 21:18:20 -07:00
Matthew Johnson
51d2722185 add graphviz-dumping function 2019-04-02 18:34:19 -07:00
Matthew Johnson
f17d31fdf2 rename xla_pcall -> xla_pmap 2019-04-01 17:21:50 -07:00
Matthew Johnson
aa6cebff44 fix typo in vmap (fixes #536) 2019-03-29 08:03:58 -07:00
Matthew Johnson
18bc5936a7
Merge pull request #525 from google/jarrett-jvps-2
add a function for elementwise jacobian accumulation
2019-03-25 18:37:05 -07:00
Matthew Johnson
850e8a756a fix typo in linearize docstring 2019-03-25 11:31:44 -07:00
Matthew Johnson
5704624cf9 attempt to fix readtheddocs.io formatting 2019-03-25 11:29:44 -07:00
Matthew Johnson
595a9800d4 dedent code 2019-03-25 11:11:57 -07:00
Matthew Johnson
54ac87957c add clarification about linearize vs jvp+vmap 2019-03-25 11:03:03 -07:00
Matthew Johnson
a169e534a8 add docstring for jax.linearize (fixes #526) 2019-03-25 10:37:24 -07:00