The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).
For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.
We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)
This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.
The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
There was a mysterious failure on an internal test, and that
mysteriousness means I didn't fully understand the attempted fix, so
best to roll back for now.
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.
See https://github.com/google/jax/pull/1749 for more.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
As one step in tracing user code to a jaxpr using the machinery in
partial_eval.py, we construct a bipartite graph made of JaxprTracer
nodes, corresponding to values in the user code, and recipe nodes
,particularly those corresponding to jaxpr equations, representing
primitive operations. (This representation was put in place in #1224,
since when primitives only had single outputs we could identify each
primitive operation with the JaxprTracer value it produced.) This graph
had reference cycles because each equation recipe points to both its
input and output tracers (as a jaxpr eqn has both input and output vars)
and a tracer must be able to point to the equation recipe that produced
it (for us to toposort the graph from in_tracers to out_tracers in
tracers_to_jaxpr).
Those cycles caused memory leaks. This commit removes the strong
reference cycle using weakrefs. In particular, equation recipes only
hold weak references to their output tracers.
Before this change, we used the core.JaxprEqn struct both to represent
equations in jaxprs (where invars and outvars are instances of the
core.Var class) and to represent equation recipes (where invars and
outvars are instances of the partial_eval.JaxprTracer class). That was a
bit lazy. This commit distinguishes the two as separate JaxprEqn and
JaxprEqnRecipe structs.
Bug find and test code from @trevorcai. Thanks!
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
Using threading within a traced context still won't work, but that is perhaps less important than the ability to call JIT-ted computations from separate threads.
(Revives https://github.com/google/jax/pull/734.)
This fixes a bug where scalar ndarray literals with different dtypes
could hash to the same value. It also makes scalar DeviceArray literals
hashable after #884.