The `JaxprTracerTuple` class appears to no longer exist, and the function could return `None` if `pv` was a type other than `AbstractValue`, instead of raising an error.
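A hedged sketch of the behavior change described above; the function and names here are hypothetical stand-ins, not the actual partial_eval code:

```python
from jax.core import AbstractValue  # import path assumed

# Hypothetical sketch: raise an error instead of silently returning None
# when `pv` is not an AbstractValue (JaxprTracerTuple handling removed).
def check_pv(pv):
  if isinstance(pv, AbstractValue):
    return pv
  raise TypeError(f"Expected an AbstractValue, got {type(pv)}")
```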
* Add a precision argument to jax.numpy functions that use lax.dot_general (usage sketch after this list)
* Test precision argument
* Check default precision
* Test with jaxprs
* Document precision
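A minimal usage sketch of the precision argument, assuming it is exposed on `jnp.dot` and takes values from the `lax.Precision` enum:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

# Default precision is backend-dependent; HIGHEST requests full-precision
# matmuls (relevant on TPU, where the default uses reduced precision).
z_default = jnp.dot(x, y)
z_precise = jnp.dot(x, y, precision=lax.Precision.HIGHEST)
```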
bfloat16 support is still immature, but this PR adds some initial support.
Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.
The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must (as sketched below):
* implement our own type promotion operators that understand bfloat16 types, and
* wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
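A hedged illustration of both points; `np_reference_cos` is a hypothetical example of the test pattern, not code from this PR:

```python
import numpy as np
import jax.numpy as jnp

# JAX's own promotion rules understand bfloat16.
print(jnp.promote_types(jnp.bfloat16, jnp.float32))  # float32

# Classic NumPy does not, so test reference implementations temporarily
# cast bfloat16 inputs to float32 before computing.
def np_reference_cos(a):
  if a.dtype == jnp.bfloat16:
    a = a.astype(jnp.float32)
  return np.cos(np.asarray(a))
```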
When #1667 inlined a function into its caller, it conflated two distinct
values that were both referred to as `nrep` in the two functions:
num_global_replicas vs. num_local_replicas. This caused errors on
multi-host setups.
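For illustration only (a hedged sketch, not code from the patch), the two quantities correspond to JAX's global vs. per-host device counts:

```python
import jax

# On a single host these agree, but on multi-host setups they differ,
# which is why conflating them under one name `nrep` caused errors.
num_global_replicas = jax.device_count()        # devices across all hosts
num_local_replicas = jax.local_device_count()   # devices on this host only
```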
Co-authored-by: Jonathan Heek <jheek@google.com>
* Add support for scipy.ndimage.map_coordinates with order=1 (see the usage sketch after this list)
Higher dimensional interpolation will be a bit trickier, but this should
already be useful.
* Move docstring around
* Dtype fixes, more tests
* Fix up float32 tests
* Handle order=0
* Tests for errors from map_coordinates
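A minimal usage sketch, assuming the function is exposed as jax.scipy.ndimage.map_coordinates with the SciPy-style signature:

```python
import jax.numpy as jnp
from jax.scipy import ndimage

image = jnp.arange(12.0).reshape(3, 4)

# One coordinate array per input dimension; order=1 is linear interpolation,
# order=0 is nearest-neighbor.
coords = [jnp.array([0.5, 1.5]),   # row positions
          jnp.array([1.0, 2.5])]   # column positions
values = ndimage.map_coordinates(image, coords, order=1)
```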
As one step in tracing user code to a jaxpr using the machinery in
partial_eval.py, we construct a bipartite graph made of JaxprTracer
nodes, corresponding to values in the user code, and recipe nodes
, particularly those corresponding to jaxpr equations, representing
primitive operations. (This representation was put in place in #1224;
before that, when primitives only had single outputs, we could identify
each primitive operation with the JaxprTracer value it produced.) This graph
had reference cycles because each equation recipe points to both its
input and output tracers (as a jaxpr eqn has both input and output vars)
and a tracer must be able to point to the equation recipe that produced
it (for us to toposort the graph from in_tracers to out_tracers in
tracers_to_jaxpr).
Those cycles caused memory leaks. This commit removes the strong
reference cycles by using weakrefs. In particular, equation recipes now
hold only weak references to their output tracers.
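A minimal sketch of the weakref idea, with hypothetical class names rather than the actual partial_eval structures:

```python
import weakref

class Recipe:
  """Keeps strong refs to its input tracers, weak refs to its outputs."""
  def __init__(self, in_tracers, out_tracers):
    self.in_tracers = in_tracers
    self.out_tracer_refs = [weakref.ref(t) for t in out_tracers]

class Tracer:
  def __init__(self):
    self.recipe = None  # strong ref to the recipe that produced this tracer

# tracer -> recipe -> (weak) tracer: no strong reference cycle remains.
```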
Before this change, we used the core.JaxprEqn struct both to represent
equations in jaxprs (where invars and outvars are instances of the
core.Var class) and to represent equation recipes (where invars and
outvars are instances of the partial_eval.JaxprTracer class). That was a
bit lazy. This commit distinguishes the two as separate JaxprEqn and
JaxprEqnRecipe structs.
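A rough sketch of the distinction; the field names below are illustrative, not the exact struct definitions:

```python
from collections import namedtuple

# In jaxprs: invars and outvars are core.Var instances.
JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive', 'params'])

# During tracing: invars are JaxprTracers, and outputs are held weakly,
# per the weakref change described above.
JaxprEqnRecipe = namedtuple(
    'JaxprEqnRecipe', ['invars', 'out_tracer_refs', 'primitive', 'params'])
```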
Bug find and test code from @trevorcai. Thanks!
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and promoting np.float64 with np.float32 yields np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
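A hedged illustration of the intended behavior once weak types are in effect:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0], dtype=jnp.float32)

# The Python scalar 1. has a weak type, so the array's float32 dtype wins:
print((1.0 + x).dtype)          # float32
print(jnp.result_type(1.0, x))  # float32
```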