Previously, division was only supported in certain situations, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to say "symbolic dimension"
instead of "dimension polynomial".)
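For illustration, here is a small end-to-end sketch (assuming jax and
tensorflow are installed; the function and shape spec are made up for this
example) where the output dimension involves floordiv::

    import tensorflow as tf
    from jax.experimental import jax2tf

    def strided(x):
      return x[::2]  # output length is floordiv(b + 1, 2) for input length b

    f_tf = jax2tf.convert(strided, polymorphic_shapes=["(b,)"])
    print(f_tf(tf.ones([7])).shape)  # (4,), since floordiv(7 + 1, 2) == 4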
In the presence of dynamic shapes the ThreeFry2x32Descriptor will contain the
value n=-1, and the actual desired output length will be passed as
an additional operand. If the shape is static then the length will be
passed as part of the descriptor.
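A minimal, self-contained sketch of this convention (all names here are
hypothetical stand-ins; the real descriptor is packed on the jaxlib side)::

    import struct

    def pack_threefry_descriptor(n: int) -> bytes:
      # n >= 0: static output length; n == -1: length arrives as an operand
      return struct.pack("<q", n)

    def descriptor_and_operands(length, operands):
      if isinstance(length, int):          # static shape: length in descriptor
        return pack_threefry_descriptor(length), operands
      # dynamic shape: sentinel in the descriptor, length as an extra operand
      return pack_threefry_descriptor(-1), operands + (length,)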
PiperOrigin-RevId: 497945778
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, and 3) prose, in
order to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc.).
3) One rare situation where the prose says "StableHLO" and "MHLO" in the same
sentence, so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing dynamic shapes with opaque types.
The general strategy is to push the actual selection of the MHLO ops
down into the mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim),
so that there is one place where we pick between the dynamic
and static ops. These routines can also handle opaque types. For an
opaque type this results in a recursive
call to, e.g., mlir.slice_op, but the inner call uses
the physical avals, which should no longer be opaque.
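A toy, runnable model of this dispatch (the names are stand-ins, not JAX's
real internals)::

    def is_constant_shape(shape):
      return all(isinstance(d, int) for d in shape)

    def emit_op(shape, is_opaque, physical_shape_of):
      if is_opaque:
        # recursive call with the physical shape, which is no longer opaque
        return emit_op(physical_shape_of(shape), False, physical_shape_of)
      if is_constant_shape(shape):
        return ("static_op", shape)
      return ("dynamic_op", shape)  # shape operands computed at run time

    # e.g. a threefry key array of logical shape (4,) carries 2 uint32 words:
    print(emit_op((4,), True, lambda s: tuple(s) + (2,)))  # ('static_op', (4, 2))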
While making this change I was confused by the fact that the
custom KeyTyRules in prng.py had lowerings that returned multiple
MHLO ops (see https://github.com/google/jax/pull/11768#issuecomment-1342349102),
so I changed the rules to return a single op.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.
For native serialization we will support two lowering implementations:
* one is using the growing support in JAX for dynamic shapes,
of which shape polymorphism is a special case.
This implementation is enabled with the --jax_dynamic_shapes flag.
At the moment, the JAX dynamic shapes support is still
incomplete and over 300 jax2tf shape polymorphism tests fail.
* a new one (added here) in which we form a Jaxpr using abstract
values that express dimension sizes as dimension polynomials
(as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
This implementation is enabled when --jax_dynamic_shapes is off.
With this implementation only 50 jax2tf tests fail (to be fixed
separately).
The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that
cover both cases of dynamic shapes: Vars or polynomials in shapes.
The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.
Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials with Vars.
The key code pattern used in the lowering rules is::

    if not core.is_constant_shape(shape):  # handles both Vars and polynomials
      shape = mlir.eval_dynamic_shape(ctx, shape)
      return mhlo.DynamicXXX(..., shape)
    else:
      return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
      if config.jax_dynamic_shapes:
        # using Vars
        return ... subst using ctx.axis_size_env ...
      else:
        # using polynomials
        return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values
In order to support the above, some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.
I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
Namely, make it so that `split(key, n)[i]` equals `fold_in(key, i)`
for any key and for `0 <= i < n`.
This change affects the observed random bits for a fixed key (indirectly
through splits and folds), so here we guard it behind
`jax.config.jax_threefry_partitionable`. It's not described very well
by the flag name, but it makes for a simple way to bundle together
several random-bit-altering changes as part of the same upgrade cycle.
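For example, with the flag enabled, the new invariant can be checked
directly (a sketch)::

    import jax

    jax.config.update("jax_threefry_partitionable", True)
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    for i in range(4):
      assert (subkeys[i] == jax.random.fold_in(key, i)).all()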
During the jit lowering path, a fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr.
Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.
PiperOrigin-RevId: 470896270
Only handle host-locally sharded `Array`s for now (like in SDAs under
`pmap`). Leaving global sharding for a follow up.
Also re-enable a previously skipped test as a result.
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 469885160
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`
PiperOrigin-RevId: 469276609
Specifically:
* Introduce a `physical_avals` view as a custom eltype method. This is
analogous to the existing `aval_to_ir_types`, but where the output
is an aval with a non-custom eltype (and hence a direct
correspondence to TF and to lowerings). See the sketch after this list.
* Change jax2tf to continue tracing with logical avals, but to
maintain TF tensors of corresponding physical shape/dtype, and to
translate to TF operations based on physical avals where relevant.
* Fix up various TF impl rules to follow physical avals. To this end,
add a "physical" mode to jax2tf's `_convert_jax_impl` helper, which
carries out the conversion using physical rather than logical avals.
* Write TF impl rules for `random_{seed,split,fold_in,bits}`
primitives. To this end, factor out the part of these primitives'
impl rules that operates on the base array, and pass that part
through `_convert_jax_impl` in physical mode.
* Teach the jax2tf test harness how to unwrap key-array-typed outputs
into physical `uint32` arrays that it can use in comparison tests.
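To illustrate the `physical_avals` view mentioned above, here is a
hypothetical stand-in for the threefry case, where each key carries two
uint32 words::

    import numpy as np

    def physical_aval(logical_shape, key_words=2):
      # a key array of logical shape S has physical uint32 shape S + (key_words,)
      return tuple(logical_shape) + (key_words,), np.dtype("uint32")

    print(physical_aval((3, 4)))  # ((3, 4, 2), dtype('uint32'))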
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.
We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays with such an element
type. This leans heavily on our recently-introduced extended element
type capabilities.
As a consequence, `vmap`, `scan`, etc. now work.
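For example (a sketch, assuming `config.jax_enable_custom_prng` is on)::

    import jax

    jax.config.update("jax_enable_custom_prng", True)
    keys = jax.random.split(jax.random.PRNGKey(0), 4)  # key array of shape (4,)
    samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
    print(samples.shape)  # (4, 3)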
A sample of changes made to introduce key-element-type arrays:
* Introduce a new element type (`prng.KeyTy`), with the requisite IR
type mapping and device result handlers, as well as lowering rules
for dtype-polymorphic primitive operations.
* Introduce primitives for basic RNG operations: `random_seed`,
`random_bits`, `random_split`, `random_fold_in`. These primitives
essentially delegate to the underlying PRNG implementation (directly
so in their impl rules, and by translating their staged-out form in
lowering rules). See the jaxpr sketch after this list.
* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
conversion from/to the base `uint32` array. We need this for backwards
compatibility, and it's useful for tests.
* Introduce some `vmap`-based helpers to adapt PRNG impls (which
define basic `random_bits`, `split`, etc. on scalars) to the above
batch-polymorphic primitives. Most of the primitives are vectorized,
but `random_fold_in` is a broadcasting binary op.
* Update the `gamma` primitive rules to account for key-element-type
abstract arrays (nice simplification here).
* Give PRNG implementation short string names ("tags") for IR
pretty-printing.
* Update `lax.stop_gradient` to handle opaque dtypes.
* Fix up loop MLIR lowering, which assumed that shaped arrays of all
dtypes have the same physical shape.
* Add new tests (exercising staging, jaxprs, lowerings, ...)
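As a sketch of the staged-out form (assumes the custom-prng flag; the exact
printed jaxpr differs across versions)::

    import jax

    jax.config.update("jax_enable_custom_prng", True)
    print(jax.make_jaxpr(lambda s: jax.random.split(jax.random.PRNGKey(s), 2))(0))
    # stages out random_seed and random_split on key-element-type arrays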
A sample of changes made to rework Python-level PRNG key arrays:
* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
tracers that carry them (see the example after this list).
* Patch (only a subset of) standard device array attributes onto PRNG
key arrays.
* Implement various conversion handlers (sharding, constant-creation,
`device_put`).
* Accept PRNG key arrays as input to `lax_numpy.transpose`.
* Update tests and rename some internals.
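For example, the mimicry keeps `isinstance` checks working under tracing (a
sketch, assuming `KeyArray` is exported as `jax.random.KeyArray`)::

    import jax

    jax.config.update("jax_enable_custom_prng", True)

    @jax.jit
    def f(key):
      # key is a tracer here, but still answers True to the KeyArray check
      assert isinstance(key, jax.random.KeyArray)
      return jax.random.normal(key, (2,))

    print(f(jax.random.PRNGKey(0)))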
A sample of extra changes along the way:
* Disallow AD on key-typed arrays in the main API (see the example after
this list).
* Hoist `random_bits`'s named-shape-handling logic, which used to only
take place in the threefry PRNG's `random_bits` implementation, up
to the new `random_bits` traceable, so that we apply it consistently
across PRNG implementations.
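For instance, differentiating with respect to a key now raises an error
rather than silently misbehaving (a sketch; the exact error type and message
may vary)::

    import jax

    jax.config.update("jax_enable_custom_prng", True)
    key = jax.random.PRNGKey(0)
    try:
      jax.grad(lambda k: jax.random.normal(k, ()))(key)
    except Exception as e:  # the exact error type may vary
      print("AD on key arrays is disallowed:", e)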
This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.
Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).