* Added clearer error message for tracers in numpy.split
Now we print:
ConcretizationTypeError: Abstract tracer value where concrete value is expected (in
jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid
tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray>
* Fixed tests, slight change to the error message
* Expanded the FAQ entry about abstract tracers for higher-order primitives
* Added clarification for tracers inside jit of grad
* Updated FAQ language in response to reviews
Introduced two new constructors for PartialVal: unknown and known.
These should make it easier to read the code where we construct
PartialVal:
* instead of PartialVal((aval, core.unit) we use PartialVal.unknown(aval)
* instead of PartialVal((None, pval)) we use PartialVal.known(pval)
Also disabled some new tests in random_tests.py on Mac. They segfault,
apparently due to the same issue #432.
The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).
For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
Before, bound_subjaxprs was a tuple (0 or 1 values) of
a pair of a Jaxpr and its constant values. Now we close up all such Jaxprs
such that they do not take constvars and their constant values are part of the
arguments.
We also rename bound_subjaxprs to bound_subjaxpr (an optional Jaxpr)
This is first part of a simplification. In a subsequent PR I will move
the bound_subjaxpr into params, as for most higher-order primitives.
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure.We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.
The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
There was a mysterious failure on an internal test, and that
mysteriousness means I didn't fully understand the attempted fix, so
best to roll back for now.
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.
See https://github.com/google/jax/pull/1749 for more.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>