5 Commits

Author SHA1 Message Date
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
13a17286df
stop_gradient_p -> ad_util.py, re-enable some mypy (#2806) 2020-04-23 13:12:24 -07:00
Matthew Johnson
903010b7b9 disable mypy checks causing new errors 2020-04-23 09:28:14 -07:00
Matthew Johnson
f99720b70a add type annotations to core.py tracing machinery
also add .copy() method to core.trace_state global trace state
2020-03-28 14:58:35 -07:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00