--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:
remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).
This change languished until we could land #11830 / #11950 and friends. But now
we can!
PiperOrigin-RevId: 468373797
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).
This change languished until we could land #11830 / #11950 and friends. But now
we can!
A change that adds jit() decorators on a number of standard library functions was triggering incorrect cache hits for these tests. This is because the payload fields of the MainTrace were not being included in __hash__() and __eq__().
* applied simple find+sed for 'master' -> 'main'
* Rename master->main in JAX API and internals (#4178)
* Started with #4174
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master
Co-authored-by: George Necula <gcnecula@gmail.com>
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 fo more information.