59 Commits

Author SHA1 Message Date
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
3c37a3260a Update linearize to no-tuple version 2019-08-21 07:01:07 -07:00
Dougal Maclaurin
20fad746a8 De-dup equations with multiple lhs vars when creating a jaxpr
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Jamie Townsend
9b19afd4d9 Implement Jaxpr __hash__
This means that primitives like scatter, which have a Jaxpr in their
**params, will get cache hits appropriately.
2019-08-21 14:27:23 +01:00
Brian Patton
d07107af5a
Makes the Tracer object weakref-able 2019-08-16 17:18:44 -05:00
Peter Hawkins
6dc730a5f4 Make JAX tracer state thread-local. Allows performing traces in separate threads.
Using threading within a traced context still won't work, but that is perhaps less important than the ability to call JIT-ted computations from separate threads.

(Revives https://github.com/google/jax/pull/734.)
2019-08-09 13:55:20 -04:00
Jamie Townsend
21a69884fd call_wrapped in core.call_impl 2019-07-22 17:09:03 +01:00
Matthew Johnson
5aef18f897 improve literal hashing logic
This fixes a bug where scalar ndarray literals with different dtypes
could hash to the same value. It also makes scalar DeviceArray literals
hashable after #884.
2019-06-19 10:32:55 -07:00
Matthew Johnson
b53bccc5d0 make more literals nontrivially hashable 2019-06-18 21:51:51 -07:00
Matthew Johnson
221426fadc de-duplicate constants staged into jaxprs
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-06-18 08:09:37 -07:00
Matthew Johnson
778435a90b undo #503 in favor of new literal staging method 2019-05-29 08:12:05 -07:00
Matthew Johnson
310103f578 try a tweak on Literal for more cache hits 2019-05-28 22:50:52 -07:00
Matthew Johnson
9c931ddebe allow more types to be jaxpr literals, fixes #772 2019-05-28 22:38:06 -07:00
Matthew Johnson
d27bc0a129 add literals to jaxprs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-13 08:48:13 -07:00
Matthew Johnson
6e9718a229 add pretty-printing to TypedJaxpr 2019-05-11 13:28:47 -07:00
Matthew Johnson
65202821df improve core.typed_jaxpr arg typechecks 2019-05-11 10:45:14 -07:00
Matthew Johnson
4fcd96f926 make tests pass with skip_checks = False 2019-05-10 22:07:54 -07:00
Matthew Johnson
29e67f0119 scan bug fixed, other cleanup 2019-05-10 15:52:12 -07:00
Matthew Johnson
5cfa18015c fix things we broke on the path to scan 2019-05-10 14:00:21 -07:00
Matthew Johnson
360e39756f must guarantee progress on lattice...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-09 21:55:38 -07:00
Matthew Johnson
085f06e4b6 add some PartialVal invariants
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-09 21:55:24 -07:00
Matthew Johnson
c08b9fee47 remove const_env from check_jaxpr, add scan trans 2019-05-08 17:41:36 -07:00
Matthew Johnson
15d783a836 Merge remote-tracking branch 'origin/master' into differentiable-scan 2019-05-08 13:42:44 -07:00
Matthew Johnson
444cda493a add underscores, rename scan_initial -> scan 2019-05-08 13:41:27 -07:00
Matthew Johnson
e736a0a9a1 cleanup: remove call_initial, add xla pat_fmap 2019-05-08 13:41:27 -07:00
Matthew Johnson
4c2ec3e442 ship it
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:25 -07:00
Matthew Johnson
0988f6d8d5 pattern unpacking at jaxpr top-level (pair w/ @dougalm)
next step is to handle that new complexity in our jaxpr munging...

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:21 -07:00
Matthew Johnson
a17f8e4ca8 add jaxpr eqn structured input, transpose progress
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:19 -07:00
Matthew Johnson
1c9035efca start scan transpose, but "nonlinear pack"!!
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:17 -07:00
Matthew Johnson
6736823021 victory! patial eval of scan (+ linearize!)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:15 -07:00
Matthew Johnson
d03cdc6397 introduce typedjaxpr to carry around literals etc
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:10 -07:00
Matthew Johnson
19e0f8de45 fix tuple unpacking problems 2019-05-06 22:43:31 -07:00
Matthew Johnson
bf6c15b59a update pmap to flatten correctly (was a perf bug)
also temporarily avoid DeviceTuples in optimizer states
2019-05-06 12:09:54 -07:00
Matthew Johnson
642d2dc802 revies optimizers api, fix misc bugs
* add more optimizers numerical tests
* update examples and readme with new optimziers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
2019-05-03 12:44:52 -07:00
Matthew Johnson
f95f1c8dda fix bugs, make tests pass with skip_checks = False 2019-05-03 12:01:12 -07:00
Matthew Johnson
8e96e2f6df revert incorrect change to core.valid_jaxtype 2019-05-03 08:24:24 -07:00
Matthew Johnson
7c5d683915 revise sharded result handling, misc cleanup 2019-05-03 08:06:55 -07:00
Matthew Johnson
3f638d3a40 make JaxTuple not subclass tuple, add docstrings 2019-05-01 19:32:48 -07:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Matthew Johnson
acd9276f0d add __bool__ to jaxtuples / abstracttuples 2019-03-02 21:43:40 -08:00
Matthew Johnson
a20e8982fa completed scan (PAIR=hawkinsp@) 2019-03-02 21:27:52 -08:00
Matthew Johnson
45c41d9e58 fix typo in abstract_eval NotImplementedError 2019-02-22 08:13:46 -08:00
Matthew Johnson
4c1fc9cfbd peval.py works again (some paired w/ @dougalm) 2019-02-22 07:53:28 -08:00
Matthew Johnson
a58c315463
Merge pull request #388 from alexalemi/invert
__invert__ doesn't take an argument.
2019-02-15 22:22:58 -08:00
Alex Alemi
d8b3694bfb
__invert__ doesn't take an argument. 2019-02-15 14:09:06 -08:00
Dougal Maclaurin
ce74bc55ce Handle closed-over tracers in while loop cond and body functions 2019-02-06 12:58:32 -05:00
Matthew Johnson
1e84a3a0fb make tuple unpacking cause a full_lower 2019-01-07 16:47:13 -08:00
Matthew Johnson
f971415218 add tie_in and full primitives (constant creation) 2018-12-18 09:16:59 -08:00