The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison, and squeeze) for a simple
dynamic-batch-shape gather to work.
I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's just convenient given I'm looking at indexing again.
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`
PiperOrigin-RevId: 469276609
Specifically:
* Introduce a `physical_avals` view as a custom eltype method. This is
analogous to the existing `aval_to_ir_types`, but where the output
is an aval with a non-custom eltype (and hence a direct
correspondence to TF and to lowerings).
* Change jax2tf to continue tracing with logical avals, but to
maintain TF tensors of corresponding physical shape/dtype, and to
translate to TF operations based on physical avals where relevant.
* Fix up various TF impl rules to follow physical avals. To this end,
add a "physical" mode to jax2tf's `_convert_jax_impl` helper, which
carries out the conversion using physical rather than logical avals.
* Write TF impl rules for `random_{seed,split,fold_in,bits}`
primitives. To this end, factor out the part of these primitives'
impl rules that operates on the base array and convert that, pass it
through `_convert_jax_impl` in physical mode.
* Teach the jax2tf test harness how to unwrap key-array-typed outputs
into physical `uint32` arrays that it can use in comparison tests.
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.
We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays such element type.
This leans heavily on our recently-introduced extended element type
capabilities.
As a consequence, `vmap`, `scan`, etc. now work.
A sample of changes made to introduce key-element-type arrays:
* Introduce a new element type (`prng.KeyTy`), with the requisite IR
type mapping and device result handlers, as well as lowering rules
for dtype-polymorphic primitive operations.
* Introduce primitives for basic RNG operations: `random_seed`,
`random_bits`, `random_split`, `random_fold_in`. These primitives
essentially delegate to the underlying PRNG implementation (directly
so in their impl rules, and by translating their staged-out form in
lowering rules).
* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
conversion from/to the base `uint32` array. We need this backwards
compatibility, and it's useful for tests.
* Introduce some `vmap`-based helpers to adapt PRNG impls (which
define basic `random_bits`, `split`, etc. on scalars) to the above
batch-polymorphic primitives. Most of the primitives are vectorized,
but `random_fold_in` is a broadcasting binary op.
* Update the `gamma` primitive rules to account for key-element-type
abstract arrays (nice simplification here).
* Give PRNG implementation short string names ("tags") for IR
pretty-printing.
* Update `lax.stop_gradient` to handle opaque dtypes.
* Fix up loop MLIR lowering, which assumed that shaped arrays of all
dtypes have the same physical shape.
* Add new tests (exercising staging, jaxprs, lowerings, ...)
A sample of changes made to rework Python-level PRNG key arrays:
* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
tracers that carry them.
* Patch (only a subset of) standard device array attributes onto PRNG
key arrays.
* Implement various conversion handlers (sharding, constant-creation,
`device_put`).
* Accept PRNG key arrays as input to `lax_numpy.transpose`.
* Update tests and rename some internals.
A sample of extra changes along the way:
* Disallow AD on key-typed arrays in the main API.
* Hoist `random_bits`'s named-shape-handling logic, which used to only
take place in the threefry PRNG's `random_bits` implementation, up
to the new `random_bits` traceable, so that we apply it consistently
across PRNG implementations.
This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.
Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:
remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).
This change languished until we could land #11830 / #11950 and friends. But now
we can!
PiperOrigin-RevId: 468373797
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).
This change languished until we could land #11830 / #11950 and friends. But now
we can!
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.
Within jaxprs, we also need to handle transformations with these types.
In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
* allowing ShapedArray to have any object for its 'dtype' attribute
* added core.custom_eltype set
* extended existing handlers for ShapedArray to call the corresponding custom
element type handlers
* mlir lowerings of some fully-element-type-polymorphic primitives
* tests
In this PR, we only actually use these new extension points in tests.
The applications to come that we have in mind are:
* arrays of prngkeys (and even custom prngs, as well as reuse error checking)
* arrays of bounded int type for dynamic shapes (and especially raggedness)
* float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.
Jargon-wise, we may want to distinguish:
* 'eltype' meaning jaxpr element types
* 'dtype' meaning numpy dtypes (an existing convention)
* 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.
We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
* [x] broadcast
* [ ] reshape
* [x] transpose
* [ ] pad
* [x] slice, dynamic_slice, dynamic_update_slice
* [ ] concatenate
* [ ] all_to_all, gather, scatter, all_gather, collective_permute
* [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.
We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.
Co-authored-by: Roy Frostig <frostig@google.com>
--
a001c52f878824cd1c0a67c73d9d318ed30286c9 by Matthew Johnson <mattjj@google.com>:
[dynamic-shapes] basic jvp working, including with broadcast
PiperOrigin-RevId: 456822732
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451320477
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451268007