Fall back to `lower_xla_callable` when a `pmap` appears in the jaxpr during the jit lowering path.
Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.
PiperOrigin-RevId: 470896270
Only handle host-locally sharded `Array`s for now (as with SDAs under
`pmap`), leaving global sharding for a follow-up.
Also re-enable a previously skipped test as a result.
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 469885160
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`
PiperOrigin-RevId: 469276609
Specifically:
* Introduce a `physical_avals` view as a custom eltype method. This is
analogous to the existing `aval_to_ir_types`, but its output is an
aval with a non-custom eltype (and hence has a direct correspondence
to TF and to lowerings). See the sketch after this list.
* Change jax2tf to continue tracing with logical avals, but to
maintain TF tensors of corresponding physical shape/dtype, and to
translate to TF operations based on physical avals where relevant.
* Fix up various TF impl rules to follow physical avals. To this end,
add a "physical" mode to jax2tf's `_convert_jax_impl` helper, which
carries out the conversion using physical rather than logical avals.
* Write TF impl rules for the `random_{seed,split,fold_in,bits}`
primitives. To this end, factor out the part of these primitives'
impl rules that operates on the base array, and convert it by
passing it through `_convert_jax_impl` in physical mode.
* Teach the jax2tf test harness how to unwrap key-array-typed outputs
into physical `uint32` arrays that it can use in comparison tests.
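To make the physical-aval notion concrete, here is a minimal sketch,
assuming a threefry-style key type whose physical carrier is a pair of
`uint32`s per key; the helper and predicate names are illustrative, not
the exact code:
```
from jax import core
import jax.numpy as jnp

def is_key_dtype(dtype):
  # Hypothetical stand-in for the real custom-eltype check.
  return getattr(dtype, 'name', '') == 'key<fry>'

def physical_aval(logical_aval):
  # A key-typed aval such as key<fry>[4] corresponds to a physical aval
  # of a plain dtype, here uint32[4, 2], which is what TF tensors and
  # lowerings actually carry.
  if is_key_dtype(logical_aval.dtype):
    return core.ShapedArray((*logical_aval.shape, 2), jnp.uint32)
  return logical_aval
```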
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.
We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays with such an element
type.
This leans heavily on our recently-introduced extended element type
capabilities.
As a consequence, `vmap`, `scan`, etc. now work.
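For example, batching over a batch of keys now works directly:
```
import jax
from jax import random

keys = random.split(random.PRNGKey(0), 4)   # a batch of 4 keys
samples = jax.vmap(random.normal)(keys)     # maps over the key axis; shape (4,)
```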
A sample of changes made to introduce key-element-type arrays:
* Introduce a new element type (`prng.KeyTy`), with the requisite IR
type mapping and device result handlers, as well as lowering rules
for dtype-polymorphic primitive operations.
* Introduce primitives for basic RNG operations: `random_seed`,
`random_bits`, `random_split`, `random_fold_in`. These primitives
essentially delegate to the underlying PRNG implementation (directly
so in their impl rules, and by translating their staged-out form in
lowering rules).
* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
conversion from/to the base `uint32` array. We need this for
backwards compatibility, and it's useful for tests (see the example
after this list).
* Introduce some `vmap`-based helpers to adapt PRNG impls (which
define basic `random_bits`, `split`, etc. on scalars) to the above
batch-polymorphic primitives. Most of the primitives are vectorized,
but `random_fold_in` is a broadcasting binary op.
* Update the `gamma` primitive rules to account for key-element-type
abstract arrays (nice simplification here).
* Give PRNG implementations short string names ("tags") for IR
pretty-printing.
* Update `lax.stop_gradient` to handle opaque dtypes.
* Fix up loop MLIR lowering, which assumed that shaped arrays of all
dtypes have the same physical shape.
* Add new tests (exercising staging, jaxprs, lowerings, ...).
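As a rough illustration of these primitives in use (the internal module
path `jax._src.prng` is an assumption and may move):
```
import jax
from jax import random
from jax._src import prng   # internal module; path is an assumption

jax.config.update('jax_enable_custom_prng', True)

key = random.PRNGKey(0)          # a key-dtype array rather than u32[2]
k1, k2 = random.split(key)       # stages out as the random_split primitive
raw = prng.random_unwrap(key)    # "unsafe" view of the base uint32 array
back = prng.random_wrap(raw, impl=prng.threefry_prng_impl)
```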
A sample of changes made to rework Python-level PRNG key arrays:
* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
tracers that carry them.
* Patch (only a subset of) standard device array attributes onto PRNG
key arrays.
* Implement various conversion handlers (sharding, constant-creation,
`device_put`).
* Accept PRNG key arrays as input to `lax_numpy.transpose` (see the
snippet after this list).
* Update tests and rename some internals.
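For instance, with `config.jax_enable_custom_prng` on, enough of the
device-array surface is patched on for code like this to work
(illustrative; only a subset of attributes is supported):
```
import jax.numpy as jnp
from jax import random

keys = random.split(random.PRNGKey(0), 6).reshape(2, 3)
keys.shape           # (2, 3): standard attributes on a key array
jnp.transpose(keys)  # lax_numpy.transpose now accepts key arrays
```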
A sample of extra changes along the way:
* Disallow AD on key-typed arrays in the main API (see the example
after this list).
* Hoist `random_bits`'s named-shape-handling logic, which used to only
take place in the threefry PRNG's `random_bits` implementation, up
to the new `random_bits` traceable, so that we apply it consistently
across PRNG implementations.
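For example, per the AD restriction above, differentiating with respect
to a key now fails loudly instead of silently treating the key as data:
```
import jax
from jax import random

def f(key):
  return random.uniform(key)

# Raises rather than differentiating the key's uint32 carrier:
# jax.grad(f)(random.PRNGKey(0))
```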
This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.
Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
The Python code in jaxlib to build AMD HIP (ROCm) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.
This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.
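The sharing pattern looks roughly like this; the module and attribute
names here are assumptions, not the exact jaxlib layout:
```
from jaxlib import xla_client

try:
  from jaxlib import _cuda_prng   # present in NVIDIA CUDA builds
except ImportError:
  _cuda_prng = None

try:
  from jaxlib import _hip_prng    # present in AMD HIP (ROCm) builds
except ImportError:
  _hip_prng = None

# One registration loop serves both platforms.
for _mod, _platform in [(_cuda_prng, "CUDA"), (_hip_prng, "ROCM")]:
  if _mod is not None:
    for _name, _fn in _mod.registrations().items():
      xla_client.register_custom_call_target(_name, _fn, platform=_platform)
```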
PiperOrigin-RevId: 446761784
Revert "Prefer jnp.tile over concatenate."
jnp.tile generates a jaxpr like the following:
```
{ lambda ; a:i32[720192]. let
b:i32[1,720192] = reshape[dimensions=None new_sizes=(1, 720192)] a
c:i32[720192] = squeeze[dimensions=(0,)] b
d:i32[2,720192] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(2, 720192)
] c
e:i32[1440384] = reshape[dimensions=None new_sizes=(1440384,)] d
in (e,) }
```
whereas lax.concatenate generates the following jaxpr:
```
{ lambda ; a:i32[720192]. let
b:i32[1440384] = concatenate[dimension=0] a a
in (b,) }
```
It seems the TPU compiler isn't doing as good a job with laying out memory for the formulation with `jnp.tile`. `reshape` in particular can be difficult for it to handle well, and it's best to avoid it when possible.
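The two jaxprs above can be reproduced with a snippet like this:
```
import jax.numpy as jnp
from jax import lax, make_jaxpr

x = jnp.zeros((720192,), jnp.int32)
print(make_jaxpr(lambda a: jnp.tile(a, 2))(x))              # reshape/broadcast/reshape
print(make_jaxpr(lambda a: lax.concatenate([a, a], 0))(x))  # one concatenate
```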
Since the benefit was marginal (a simpler jaxpr... but is it? Really?) and the cost is real (a user's model broke), we should revert this change.
PiperOrigin-RevId: 444287005
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
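Illustrative usage of the fixes above, with custom PRNG keys enabled:
```
import jax.numpy as jnp
from jax import random

a = random.split(random.PRNGKey(0), 3)
b = random.split(random.PRNGKey(1), 3)

jnp.concatenate([a, b])   # now works for PRNGKeyArrays; shape (6,)
jnp.append(a, b)          # likewise fixed
a.ndim                    # 1, via the new ndim property
```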
This is a more descriptive name and a better location (next to other facilities for building XLA IR).
Quite a few users of the former `xla_bridge.constant()` didn't need anything other than uncanonicalized array constants. Change these users to use `xla_client.ops.Constant` instead; no need for the fancy utility in these cases.
PiperOrigin-RevId: 404270649
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new `xla.TranslationRule` signature:
`rule(ctx, avals_in, avals_out, *args, **params)`
where `ctx` contains the parts of the other signatures that were typically not specific to a particular equation.
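A rule in the new style looks roughly like this (the `sin` rule and the
registration call shown are a simplified illustration, not the exact
code):
```
from jax._src.lib import xla_client
xops = xla_client.ops

def _sin_translation_rule(ctx, avals_in, avals_out, x):
  # ctx bundles the builder, platform, axis env, etc. that the old
  # signatures passed piecemeal; x is the XlaOp for the sole operand.
  return [xops.Sin(x)]

# Registered with something like: xla.register_translation(lax.sin_p, rule)
```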
Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.
In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.
PiperOrigin-RevId: 403607667
1. Factor out `rbg_prng_impl` and `unsafe_rbg_prng_impl`. The former
uses threefry2x32 for `split` and `fold_in`, while the latter uses
untested heuristics based on calling `rng_bit_generator` itself as a
kind of hash function.
2. For `unsafe_rbg_prng_impl`'s `split` and `fold_in`, generate longer
sequences from `rng_bit_generator` (10x iterations), which may be
useful on some backends.
3. For `unsafe_rbg_prng_impl`, actually apply `rng_bit_generator` as
our "hash function" in `fold_in`.
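Illustrative opt-in usage, via the flag that selects the default PRNG
implementation:
```
import jax
from jax import random

jax.config.update('jax_default_prng_impl', 'unsafe_rbg')   # or 'rbg'
key = random.PRNGKey(0)
k1, k2 = random.split(key)   # threefry-based under 'rbg'; RBG-based under 'unsafe_rbg'
```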
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
Not only is this an experimental API, but the RngBitGenerator itself is
not guaranteed to be stable across compiler versions (or backends or
shardings), so let's state that this JAX PRNG implementation may not be
stable across JAX versions.
Even without that kind of stability, this PRNG is still useful: unlike
effectful RNG primitives such as `lax.rng_uniform`, this RBG PRNG works
correctly with `lax.scan` and `jax.checkpoint`, while potentially being
more performant on some platforms than JAX's standard PRNG.