25 Commits

Author SHA1 Message Date
Matthew Johnson
642d2dc802 revies optimizers api, fix misc bugs
* add more optimizers numerical tests
* update examples and readme with new optimziers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
2019-05-03 12:44:52 -07:00
Matthew Johnson
f95f1c8dda fix bugs, make tests pass with skip_checks = False 2019-05-03 12:01:12 -07:00
Matthew Johnson
8e96e2f6df revert incorrect change to core.valid_jaxtype 2019-05-03 08:24:24 -07:00
Matthew Johnson
7c5d683915 revise sharded result handling, misc cleanup 2019-05-03 08:06:55 -07:00
Matthew Johnson
3f638d3a40 make JaxTuple not subclass tuple, add docstrings 2019-05-01 19:32:48 -07:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Matthew Johnson
acd9276f0d add __bool__ to jaxtuples / abstracttuples 2019-03-02 21:43:40 -08:00
Matthew Johnson
a20e8982fa completed scan (PAIR=hawkinsp@) 2019-03-02 21:27:52 -08:00
Matthew Johnson
45c41d9e58 fix typo in abstract_eval NotImplementedError 2019-02-22 08:13:46 -08:00
Matthew Johnson
4c1fc9cfbd peval.py works again (some paired w/ @dougalm) 2019-02-22 07:53:28 -08:00
Matthew Johnson
a58c315463
Merge pull request #388 from alexalemi/invert
__invert__ doesn't take an argument.
2019-02-15 22:22:58 -08:00
Alex Alemi
d8b3694bfb
__invert__ doesn't take an argument. 2019-02-15 14:09:06 -08:00
Dougal Maclaurin
ce74bc55ce Handle closed-over tracers in while loop cond and body functions 2019-02-06 12:58:32 -05:00
Matthew Johnson
1e84a3a0fb make tuple unpacking cause a full_lower 2019-01-07 16:47:13 -08:00
Matthew Johnson
f971415218 add tie_in and full primitives (constant creation) 2018-12-18 09:16:59 -08:00
Matthew Johnson
bfe653c6b0 Tracer.__len__ should reflect on abstract value
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Peter Hawkins
0d4eb6c1e1 Make JAX flake8-clean.
Fixes #1.
2018-12-13 15:29:39 -05:00
Matthew Johnson
7198e09465 enable skip_checks for merging to master 2018-12-11 13:22:07 -05:00
Dougal Maclaurin
1350db2b79 Added higher-order differentiation checks in lax_test and fixed some bugs. Conv tests currently failing. 2018-12-11 13:22:07 -05:00
Matthew Johnson
2ae9a2bc35 source sync
PiperOrigin-RevId: 222461242
2018-11-21 20:32:16 -08:00
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
46c6a9170f sync updates 2018-11-19 07:47:59 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00