This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 fo more information.
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
* Added clearer error message for tracers in numpy.split
Now we print:
ConcretizationTypeError: Abstract tracer value where concrete value is expected (in
jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid
tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray>
* Fixed tests, slight change to the error message
* Expanded the FAQ entry about abstract tracers for higher-order primitives
* Added clarification for tracers inside jit of grad
* Updated FAQ language in response to reviews
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.