From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.
Within jaxprs, we also need to handle transformations with these types.
In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
* allowing ShapedArray to have any object for its 'dtype' attribute
* added core.custom_eltype set
* extended existing handlers for ShapedArray to call the corresponding custom
element type handlers
* mlir lowerings of some fully-element-type-polymorphic primitives
* tests
In this PR, we only actually use these new extension points in tests.
The applications to come that we have in mind are:
* arrays of prngkeys (and even custom prngs, as well as reuse error checking)
* arrays of bounded int type for dynamic shapes (and especially raggedness)
* float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.
Jargon-wise, we may want to distinguish:
* 'eltype' meaning jaxpr element types
* 'dtype' meaning numpy dtypes (an existing convention)
* 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.
We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
* [x] broadcast
* [ ] reshape
* [x] transpose
* [ ] pad
* [x] slice, dynamic_slice, dynamic_update_slice
* [ ] concatenate
* [ ] all_to_all, gather, scatter, all_gather, collective_permute
* [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.
We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.
Co-authored-by: Roy Frostig <frostig@google.com>
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.