* Canonicalize the shape in the wrapper functions in random.py.
This lets users be more relaxed about passing shapes as numpy arrays or statically
known DeviceArrays and still hit the jit cache. When the shape is not statically
known, the error message is improved (see the sketch after this list).
* Fix some errors.
* No need for the Poly workaround.
* Bypass canonicalization for None shapes in random.py.
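A minimal sketch of the intended flexibility in the first item (the exact set of accepted shape forms is an assumption, not a guarantee):

    import numpy as np
    import jax

    key = jax.random.PRNGKey(0)
    # Both shape spellings are expected to canonicalize to the same tuple of
    # Python ints inside the wrapper, and so hit the same jit cache entry.
    x = jax.random.normal(key, shape=(3, 4))
    y = jax.random.normal(key, shape=np.array([3, 4]))
    print(x.shape, y.shape)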
Freevars played a very small role, and they can be folded with
the invars. This simplifies the Jaxpr data structure. We remove
the `freevars` field from Jaxpr and from the bound_subjaxprs.
The only non-trivial change is for xla_pmap, where we need
to carry one extra parameter `mapped_invars` with a bitmap
to encode which invars are mapped and which are broadcast.
Previously, the freevars were broadcast.
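A rough illustration of the distinction that `mapped_invars` encodes (a sketch only; the bitmap itself is an internal parameter of xla_pmap, not user-facing): with `in_axes=(0, None)`, the first argument is mapped across devices while the second is broadcast to every device, which is how values that previously travelled as freevars are now handled.

    import jax
    import jax.numpy as jnp

    def scale(x, w):
      return x * w

    n = jax.device_count()
    xs = jnp.arange(n * 3.0).reshape(n, 3)  # mapped along its leading axis
    w = jnp.ones(3)                         # broadcast, unmapped
    print(jax.pmap(scale, in_axes=(0, None))(xs, w))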
* Shapecheck of jit, device_put, and broadcast_in_dim; better errors for unsupported ops; parse multi-digit integer literals
* WIP shapecheck np.pad
* Implement shapecheck of gather, pad
* Fix shapecheck of pad
* Implement polymorphic shape rule for (strided/dilated) convolution, refactor
* Cleanup
* Fix
* Remove all polymorphic shape rules, reuse shape rules instead.
* Register shape_rule for all standard_primitives
* Remove ShapeExpr, canonicalize_poly, renames
* Complete shapecheck(binop) implementation, remove special cases for polymorphic shapes
* Allow Poly of form d*poly + k to be divided by d
* Fix bug, inline poly_without_zeros.
Before this commit, this computation would avoid materializing the iota
array at trace time:
    @jit
    def f(x):
      m, n = x.shape
      return x + np.arange(n)
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
    @jit
    def f(x):
      m, n = x.shape
      return x + np.arange(m)[:, None]
The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
Incidentally fixes #1431
See https://github.com/google/jax/pull/1668 for more.
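One way to inspect how the second example is staged (a sketch; the exact printed form depends on the JAX version): printing the jaxpr shows whether the broadcast iota is carried lazily or embedded as a materialized constant.

    import jax
    import jax.numpy as jnp

    def f(x):
      m, n = x.shape
      return x + jnp.arange(m)[:, None]

    # Inspect the staged computation rather than running it.
    print(jax.make_jaxpr(f)(jnp.ones((3, 4))))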
* Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.dot and lax.dot_general.
Fix dtype rules for `lax._reduce_sum` and `lax._reduce_prod` to check for number inputs.
Improve error messages for type mismatches to correctly describe scalar type categories (e.g. 'floating') rather than what `onp.dtype(...).name` returns (e.g., 'float64').
Remove redundant `bfloat16` type in `lax._float`, which has been redundant since `dtypes.issubdtype` was taught about `bfloat16` support.
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
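A small check of the bool_ support described above (a sketch; the point is only that boolean inputs are accepted without a dtype error):

    import jax.numpy as jnp

    a = jnp.array([True, False, True])
    b = jnp.array([True, True, False])
    # Both ops now admit boolean operands instead of rejecting them.
    print(jnp.add(a, b))
    print(jnp.multiply(a, b))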
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.
See https://github.com/google/jax/pull/1749 for more.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
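A minimal usage sketch, assuming the decorator lands as `jax.remat` (the name used in the linked PR): wrapping `g` asks reverse-mode autodiff to recompute g's intermediates during the backward pass instead of saving them.

    import jax
    import jax.numpy as jnp

    @jax.remat
    def g(x):
      return jnp.sin(jnp.sin(x))

    def loss(x):
      return jnp.sum(g(x))

    # Gradients are unchanged; only the memory/recompute trade-off differs.
    print(jax.grad(loss)(jnp.ones(3)))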
bfloat16 support is still immature, but this PR adds some initial support.
Fixes #76, at least enough that we can declare it fixed and open specific issues for specific bfloat16 problems.
The main awkwardness that this change deals with is that classic NumPy doesn't understand bfloat16 promotion rules, so we must:
  * implement our own type promotion operators that understand bfloat16 types, and
  * wrap a number of the reference implementations in tests to temporarily cast to float32 for computation.
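A small illustration of the new dtype (a sketch): arithmetic between bfloat16 arrays stays in bfloat16 under the promotion rules described above.

    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)
    # The sum keeps the bfloat16 dtype rather than silently widening.
    print(x.dtype, (x + x).dtype)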
* Change scalar promotion rules to prefer array types over scalar types.
Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64: the `1.` is promoted to a default type (here np.float64), and promoting np.float64 with np.float32 gives np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular preferring the type of an array over the type of a scalar.
This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.
In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
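A sketch of the resulting behavior for the example above: with weak_type-aware promotion, the Python scalar defers to the array's dtype, so the result stays float32 instead of widening to float64.

    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0], dtype=jnp.float32)
    # The weakly-typed scalar 1. does not pull the result up to float64.
    print((1. + x).dtype)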
* Move internal type-related functions into a new (internal) jax.types module.
Use the wrappers in jax.types instead of calling onp type functions directly. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.
Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.
* Rename jax.types to jax.dtypes.
* s/types/dtypes/ in tests.
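A sketch of using the relocated helper, assuming `canonicalize_dtype` is exposed under the renamed jax.dtypes module as the text suggests:

    import numpy as onp
    import jax.dtypes

    # With 64-bit mode disabled (the default), 64-bit NumPy dtypes are
    # canonicalized to their 32-bit counterparts.
    print(jax.dtypes.canonicalize_dtype(onp.float64))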