Matthew Johnson
642d2dc802
revies optimizers api, fix misc bugs
...
* add more optimizers numerical tests
* update examples and readme with new optimziers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
2019-05-03 12:44:52 -07:00
Matthew Johnson
f95f1c8dda
fix bugs, make tests pass with skip_checks = False
2019-05-03 12:01:12 -07:00
Matthew Johnson
8e96e2f6df
revert incorrect change to core.valid_jaxtype
2019-05-03 08:24:24 -07:00
Matthew Johnson
7c5d683915
revise sharded result handling, misc cleanup
2019-05-03 08:06:55 -07:00
Matthew Johnson
3f638d3a40
make JaxTuple not subclass tuple, add docstrings
2019-05-01 19:32:48 -07:00
Matthew Johnson
055521fa8e
add DeviceTuples for device-persistent tuples
2019-04-30 17:15:10 -07:00
Matthew Johnson
9c2e1c35b1
prevent jit from treating keyword args as static
...
fixes #523
2019-04-10 22:09:14 -07:00
Matthew Johnson
acd9276f0d
add __bool__ to jaxtuples / abstracttuples
2019-03-02 21:43:40 -08:00
Matthew Johnson
a20e8982fa
completed scan (PAIR=hawkinsp@)
2019-03-02 21:27:52 -08:00
Matthew Johnson
45c41d9e58
fix typo in abstract_eval NotImplementedError
2019-02-22 08:13:46 -08:00
Matthew Johnson
4c1fc9cfbd
peval.py works again (some paired w/ @dougalm)
2019-02-22 07:53:28 -08:00
Matthew Johnson
a58c315463
Merge pull request #388 from alexalemi/invert
...
__invert__ doesn't take an argument.
2019-02-15 22:22:58 -08:00
Alex Alemi
d8b3694bfb
__invert__ doesn't take an argument.
2019-02-15 14:09:06 -08:00
Dougal Maclaurin
ce74bc55ce
Handle closed-over tracers in while loop cond and body functions
2019-02-06 12:58:32 -05:00
Matthew Johnson
1e84a3a0fb
make tuple unpacking cause a full_lower
2019-01-07 16:47:13 -08:00
Matthew Johnson
f971415218
add tie_in and full primitives (constant creation)
2018-12-18 09:16:59 -08:00
Matthew Johnson
bfe653c6b0
Tracer.__len__ should reflect on abstract value
...
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Peter Hawkins
0d4eb6c1e1
Make JAX flake8-clean.
...
Fixes #1 .
2018-12-13 15:29:39 -05:00
Matthew Johnson
7198e09465
enable skip_checks for merging to master
2018-12-11 13:22:07 -05:00
Dougal Maclaurin
1350db2b79
Added higher-order differentiation checks in lax_test and fixed some bugs. Conv tests currently failing.
2018-12-11 13:22:07 -05:00
Matthew Johnson
2ae9a2bc35
source sync
...
PiperOrigin-RevId: 222461242
2018-11-21 20:32:16 -08:00
Peter Hawkins
5e60639bc5
source sync
...
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113
source sync
...
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
46c6a9170f
sync updates
2018-11-19 07:47:59 -08:00
Matthew Johnson
a30e858e59
populating source tree
2018-11-17 18:03:33 -08:00