Currently block_shape is tuple[int | None, …]. We propose generalizing block_shape to take in more types in the tuple to more generally support:
* Squeeze dimension (currently None, could be pl.Squeezed())
* Unblocked: currently the entire index_map needs to be Unblocked or not. This will allow individual indices to be Blocked/Unblocked, e.g. pl.BlockSpec((pl.Unblocked(...), 512), …)
* Ragged sizes: the index_map will return a pl.ds with a dynamic size (bounded by some something). For example: pl.BlockSpec((pl.DynamicSizedSlice(512), 1024), lambda i, j: (pl.ds(...), j).
This will make BlockSpecs a lot more flexible and will enable things like doing arbitrary slicing in things like pipeline emitter.
PiperOrigin-RevId: 748881960
Fix select's batching rule where `explicit_mesh_axis` that we capture in `axis_data` was not propagated properly to the `broadcast` happening in `bdim_at_front`.
PiperOrigin-RevId: 748867490
The `device_type` can be inferred from the `tpu.core_type` on the kernel.
`kernel_type`, on the other hand, can also be used to define specialized
lowering rules for scalar/vector subcores.
PiperOrigin-RevId: 748794989
We add `None`'s when ndim > len(sharding.spec) and only remove `None`s when `ndim < len(sharding.spec)`. If sharded axes exist, then we error out when removing specs.
PiperOrigin-RevId: 748735303
When we run the program with "--xla_jf_bounds_check=true", we can selectively disable bounds checks for pallas kernels now.
PiperOrigin-RevId: 748193719
Generally, we want to maintain that key data backing a `PRNGKeyArray` is a `jax.Array`. This change converts NumPy arrays on construction.
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 748077900
The `vectorized` argument to `pure_callback` and `ffi_call` was deprecated in JAX v0.4.34 (released Oct 4 2024), then added to the CHANGELOG in v0.4.35 (doh! released Oct 22). The JAX compatibility policy requires 3 months of compatible releases before a deprecation is finalized, so it is time to remove this parameter from the public API. The `vmap_method` parameter can be used instead, and the docs for [`pure_callback`](https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html) provide more details.
This change has one other (non-obvious!) affect on the user facing APIs. (Note that this change in behavior has also been protected by a deprecation warning since the `vectorized` parameter was deprecated.) The default behavior of `pure_callback` and `ffi_call` under `vmap` is to now raise an exception, rather than silently producing a loop. To opt in to the previous default behavior, use `vmap_method="sequential"`.
PiperOrigin-RevId: 747413383
The previous approach was to report, for several elements
of the cache key, the closest mismatch. Some parts of
the cache key were ignored, which led to "explanation unavailable".
The same happened when we had two keys close to the current
one, each differring in a different part of the key.
No explanation was produced because for each part of the key,
there was a matching key already in the cache, even though
the key taken as a whole did not match.
Now, we scan *all* parts of they key and compute the differences.
We keep track of the "size" of the differences, and we explain
the differences to those keys that are closest (possibly more
than one key if equidistant).
For example, for shape differences we'll report the
closest matching shape. If a type differs in both the dtype
and some parts of the shape, or sharding, it is considered
farther away.
We add new tests and explanations for different
static argnums and argnames.
There are still cases when we do not produce an explanation, but
now the "explanation unavailable" includes a description
of which component of the key is different, and what the
difference is. This may still be hard to understand by the
user but at least they can file a clearer bug.
Refactored the tests, and added a few new ones.
We no longer have many different implicit types conforming to `Lowering`, only `pxla.MeshComputation` and `pxla.PmapComputation`. Both are `XlaLowering` subtypes. So define just one common base class, call it `Lowering`, and inherit from just that in both concrete internal computation/lowering subtypes.
PiperOrigin-RevId: 746735857