53 Commits

Author SHA1 Message Date
Peter Hawkins
6dc730a5f4 Make JAX tracer state thread-local. Allows performing traces in separate threads.
Using threading within a traced context still won't work, but that is perhaps less important than the ability to call JIT-ted computations from separate threads.

(Revives https://github.com/google/jax/pull/734.)
2019-08-09 13:55:20 -04:00
Jamie Townsend
21a69884fd call_wrapped in core.call_impl 2019-07-22 17:09:03 +01:00
Matthew Johnson
5aef18f897 improve literal hashing logic
This fixes a bug where scalar ndarray literals with different dtypes
could hash to the same value. It also makes scalar DeviceArray literals
hashable after #884.
2019-06-19 10:32:55 -07:00
Matthew Johnson
b53bccc5d0 make more literals nontrivially hashable 2019-06-18 21:51:51 -07:00
Matthew Johnson
221426fadc de-duplicate constants staged into jaxprs
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-06-18 08:09:37 -07:00
Matthew Johnson
778435a90b undo #503 in favor of new literal staging method 2019-05-29 08:12:05 -07:00
Matthew Johnson
310103f578 try a tweak on Literal for more cache hits 2019-05-28 22:50:52 -07:00
Matthew Johnson
9c931ddebe allow more types to be jaxpr literals, fixes #772 2019-05-28 22:38:06 -07:00
Matthew Johnson
d27bc0a129 add literals to jaxprs
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-13 08:48:13 -07:00
Matthew Johnson
6e9718a229 add pretty-printing to TypedJaxpr 2019-05-11 13:28:47 -07:00
Matthew Johnson
65202821df improve core.typed_jaxpr arg typechecks 2019-05-11 10:45:14 -07:00
Matthew Johnson
4fcd96f926 make tests pass with skip_checks = False 2019-05-10 22:07:54 -07:00
Matthew Johnson
29e67f0119 scan bug fixed, other cleanup 2019-05-10 15:52:12 -07:00
Matthew Johnson
5cfa18015c fix things we broke on the path to scan 2019-05-10 14:00:21 -07:00
Matthew Johnson
360e39756f must guarantee progress on lattice...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-09 21:55:38 -07:00
Matthew Johnson
085f06e4b6 add some PartialVal invariants
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-09 21:55:24 -07:00
Matthew Johnson
c08b9fee47 remove const_env from check_jaxpr, add scan trans 2019-05-08 17:41:36 -07:00
Matthew Johnson
15d783a836 Merge remote-tracking branch 'origin/master' into differentiable-scan 2019-05-08 13:42:44 -07:00
Matthew Johnson
444cda493a add underscores, rename scan_initial -> scan 2019-05-08 13:41:27 -07:00
Matthew Johnson
e736a0a9a1 cleanup: remove call_initial, add xla pat_fmap 2019-05-08 13:41:27 -07:00
Matthew Johnson
4c2ec3e442 ship it
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:25 -07:00
Matthew Johnson
0988f6d8d5 pattern unpacking at jaxpr top-level (pair w/ @dougalm)
next step is to handle that new complexity in our jaxpr munging...

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:21 -07:00
Matthew Johnson
a17f8e4ca8 add jaxpr eqn structured input, transpose progress
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:19 -07:00
Matthew Johnson
1c9035efca start scan transpose, but "nonlinear pack"!!
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:17 -07:00
Matthew Johnson
6736823021 victory! patial eval of scan (+ linearize!)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:15 -07:00
Matthew Johnson
d03cdc6397 introduce typedjaxpr to carry around literals etc
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:10 -07:00
Matthew Johnson
19e0f8de45 fix tuple unpacking problems 2019-05-06 22:43:31 -07:00
Matthew Johnson
bf6c15b59a update pmap to flatten correctly (was a perf bug)
also temporarily avoid DeviceTuples in optimizer states
2019-05-06 12:09:54 -07:00
Matthew Johnson
642d2dc802 revies optimizers api, fix misc bugs
* add more optimizers numerical tests
* update examples and readme with new optimziers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
2019-05-03 12:44:52 -07:00
Matthew Johnson
f95f1c8dda fix bugs, make tests pass with skip_checks = False 2019-05-03 12:01:12 -07:00
Matthew Johnson
8e96e2f6df revert incorrect change to core.valid_jaxtype 2019-05-03 08:24:24 -07:00
Matthew Johnson
7c5d683915 revise sharded result handling, misc cleanup 2019-05-03 08:06:55 -07:00
Matthew Johnson
3f638d3a40 make JaxTuple not subclass tuple, add docstrings 2019-05-01 19:32:48 -07:00
Matthew Johnson
055521fa8e add DeviceTuples for device-persistent tuples 2019-04-30 17:15:10 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Matthew Johnson
acd9276f0d add __bool__ to jaxtuples / abstracttuples 2019-03-02 21:43:40 -08:00
Matthew Johnson
a20e8982fa completed scan (PAIR=hawkinsp@) 2019-03-02 21:27:52 -08:00
Matthew Johnson
45c41d9e58 fix typo in abstract_eval NotImplementedError 2019-02-22 08:13:46 -08:00
Matthew Johnson
4c1fc9cfbd peval.py works again (some paired w/ @dougalm) 2019-02-22 07:53:28 -08:00
Matthew Johnson
a58c315463
Merge pull request #388 from alexalemi/invert
__invert__ doesn't take an argument.
2019-02-15 22:22:58 -08:00
Alex Alemi
d8b3694bfb
__invert__ doesn't take an argument. 2019-02-15 14:09:06 -08:00
Dougal Maclaurin
ce74bc55ce Handle closed-over tracers in while loop cond and body functions 2019-02-06 12:58:32 -05:00
Matthew Johnson
1e84a3a0fb make tuple unpacking cause a full_lower 2019-01-07 16:47:13 -08:00
Matthew Johnson
f971415218 add tie_in and full primitives (constant creation) 2018-12-18 09:16:59 -08:00
Matthew Johnson
bfe653c6b0 Tracer.__len__ should reflect on abstract value
This old implementation, which was meant to be revised but which we
forgot about, caused a surprising slowdown: if x were a traced array of
size 50000, evaluating len(x) would create 50000 traced temporary
objects, which led to a lot of overhead! That came up in our
implementation of jax.random.shuffle, which happened to call len()
instead of x.shape[axis] (even though it should have been using x.size
anyway, according to tjablin@'s code that it's based on).
2018-12-15 20:07:10 -08:00
Peter Hawkins
0d4eb6c1e1 Make JAX flake8-clean.
Fixes #1.
2018-12-13 15:29:39 -05:00
Matthew Johnson
7198e09465 enable skip_checks for merging to master 2018-12-11 13:22:07 -05:00
Dougal Maclaurin
1350db2b79 Added higher-order differentiation checks in lax_test and fixed some bugs. Conv tests currently failing. 2018-12-11 13:22:07 -05:00
Matthew Johnson
2ae9a2bc35 source sync
PiperOrigin-RevId: 222461242
2018-11-21 20:32:16 -08:00
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00