This change factors out the logic in the apply-vector-layout shape-cast rule that inserts a minor dimension, relaxes some of its offset restrictions, and reuses it for the relayout.
PiperOrigin-RevId: 702993092
This CL removes all the shape constraints in matmul for all types.
We only need to mask out sub-elements on the contracting dimension. Instead of unpacking the data and applying masks, we create a VREG-sized i32 "mask" that encodes the sub-element mask and logically AND it with the target vreg. This way, masking the sub-elements of each target vreg takes 1 op (logical_and) instead of 3 (unpack + select + pack).
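A minimal NumPy sketch of the idea (illustrative only, not the Mosaic implementation): two 16-bit sub-elements packed into each 32-bit lane of a tiny "vreg", masked with a single bitwise AND.

```python
import numpy as np

# Two 16-bit sub-elements packed per 32-bit lane of a (tiny) "vreg".
vreg = np.array([0xAAAABBBB, 0xCCCCDDDD], dtype=np.uint32)

# Precomputed i32 mask: keep both sub-elements in lane 0, but zero the
# high sub-element in lane 1 (e.g. it lies past the contracting dim).
mask = np.array([0xFFFFFFFF, 0x0000FFFF], dtype=np.uint32)

# One logical AND per vreg, instead of unpack + select + pack.
masked = vreg & mask
assert masked[1] == 0x0000DDDD
```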
PiperOrigin-RevId: 702480077
This allows users to distinguish Mosaic GPU kernels from other kernels
when using profiling tools such as Nsight Systems.
The new default behavior is to use `mosaic_gpu_<def_name>_kernel` as
the kernel name, where `<def_name>` is the name of the Mosaic GPU
Python kernel function passed to `as_gpu_kernel` or
`as_torch_gpu_kernel`.
We also add a new `kernel_name` optional argument to `as_gpu_kernel`
and `as_torch_gpu_kernel`. If `kernel_name` is not `None`, the
resulting kernel name is `mosaic_gpu_<kernel_name>_kernel`. This is
useful when the Mosaic GPU Python kernel function is constructed
through metaprogramming, so that the final specialized kernel can be
given a meaningful name that depends on the metaparameters.
Previously the kernel name was always `main_kernel`.
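For illustration, here is a sketch of the naming rule described above (the helper name is hypothetical; the rule itself is from this change):

```python
def _mosaic_gpu_kernel_name(def_name: str, kernel_name: str | None) -> str:
    # An explicit `kernel_name` overrides the Python function's name.
    base = def_name if kernel_name is None else kernel_name
    return f"mosaic_gpu_{base}_kernel"

assert _mosaic_gpu_kernel_name("attention", None) == "mosaic_gpu_attention_kernel"
assert _mosaic_gpu_kernel_name("attention", "attention_bf16") == "mosaic_gpu_attention_bf16_kernel"
```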
This optimization avoids unnecessary retiling when storing to an untiled ref, but adds at most one extra store op for the sublane offset (since the sublane offset is limited to < VregSlice[0]).
PiperOrigin-RevId: 698896373
This change:
- Bumps up the version of Mosaic to 4 in `serde.cc`.
- Adds optional `subcore_id` parameter to `tpu.sem_signal` for signalling specific subcores.
- Extends deserialization to correctly parse older versions of Mosaic that lack the new `subcore_id` parameter of `tpu.sem_signal` (see the sketch below).
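A hypothetical sketch of the version-gated parse (the real logic lives in `serde.cc`; the field names here are illustrative):

```python
def read_sem_signal(version: int, fields: dict) -> dict:
    op = {"semaphore": fields["semaphore"], "amount": fields["amount"]}
    if version >= 4:
        # Only version-4+ payloads may carry the optional subcore_id;
        # older payloads must be parsed without looking for it.
        op["subcore_id"] = fields.get("subcore_id")
    return op
```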
PiperOrigin-RevId: 698163836
Adds an optional core type parameter to `tpu.sem_signal` for cross-core signalling.
If the target core type is not provided, it is assumed to be that of the core issuing the signal.
The issuing core type is determined based on the core type annotation of the parent function; if the annotation is not provided, the issuing core type is assumed to be TensorCore.
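A small sketch of the resolution rule described above (the names and string values are illustrative):

```python
def resolve_core_types(target_core_type, parent_func_annotation):
    # The issuing core type comes from the parent function's annotation,
    # defaulting to TensorCore when the annotation is absent.
    issuing = parent_func_annotation or "TensorCore"
    # The target core type defaults to the issuing core type.
    target = target_core_type or issuing
    return issuing, target

assert resolve_core_types(None, None) == ("TensorCore", "TensorCore")
assert resolve_core_types("SparseCore", "TensorCore") == ("TensorCore", "SparseCore")
```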
PiperOrigin-RevId: 698129842
This CL introduces a general store op, `tpu.vector_stores`, which aims to unify `vector::store`, `tpu::strided_store`, and `vector::masked_store`. It should also provide a general lowering interface for both TensorCore and SparseCore.
This CL also adds support for (dynamically) masked stores.
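Illustrative semantics of a masked store, in NumPy terms (a sketch of the behavior, not the op's implementation): masked-off elements leave memory untouched.

```python
import numpy as np

def masked_store(ref: np.ndarray, value: np.ndarray, mask: np.ndarray) -> None:
    # Elements where mask is False keep the previous contents of `ref`;
    # a plain unmasked store is the special case of an all-True mask.
    ref[...] = np.where(mask, value, ref)

ref = np.zeros(8, dtype=np.int32)
masked_store(ref, np.arange(8, dtype=np.int32), np.arange(8) % 2 == 0)
assert list(ref) == [0, 0, 2, 0, 4, 0, 6, 0]
```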
PiperOrigin-RevId: 698067741
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)
We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
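A hypothetical usage sketch (assumes a CUDA build of jaxlib and a discoverable MAGMA installation; only the `jax_gpu_use_magma` flag name comes from this change):

```python
import jax
import jax.numpy as jnp

# Opt in to the MAGMA backend; by default JAX calls LAPACK on the host,
# which is faster for problems below roughly 2048x2048.
jax.config.update("jax_gpu_use_magma", "on")

a = jnp.ones((4096, 4096)) / 4096.0  # large enough for MAGMA to pay off
w, v = jnp.linalg.eig(a)
```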
PiperOrigin-RevId: 697631402
This CL adds parameter names to the optional parameters of `tpu.sem_signal` -- `device_id`, `core_id` -- to remove ambiguity upon deserialization.
Adds LIT tests of signalling on TensorCore with parameter names.
PiperOrigin-RevId: 695875037
Unaligned concat used to be f32-only, but it was implicitly protected by the lack of multi-row-shift support for sub-32-bit types. Once that support was added, we started invoking the unaligned-concat flow with sub-32-bit types, but the masking code, which assumed full rows (unpacked types), was no longer sufficient - these cases need finer-grained masking. This only affects sublanes, since that is where we pack; there are no partial lanes.
As a small side benefit, this CL also adds better error messages to the ops involved in lower_to_llo.cc.
PiperOrigin-RevId: 695796095
As reported in https://github.com/jax-ml/jax/issues/24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.
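A synthetic illustration of the failure mode (the numbers are made up; the actual fix is in how JAX builds the pointer array for cuBLAS):

```python
# cuBLAS batched APIs take an array of per-matrix device pointers.
# Computing the per-batch byte offsets in 32-bit arithmetic wraps around
# once batch_index * matrix_bytes exceeds 2**31 - 1.
matrix_bytes = 128 * 128 * 4          # one f32 128x128 matrix
batch_index = 1_000_000               # large batch, as in issue #24843
correct = batch_index * matrix_bytes  # Python ints don't overflow
wrapped = correct % 2**32             # what 32-bit arithmetic would yield
assert wrapped != correct             # the overflow that corrupted pointers
```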
PiperOrigin-RevId: 695694648
* Generalize any untiled memref to have tiling (packing, 128); see the sketch after this list.
* Support dynamic indices on the 2nd-minor dimension.
* Support dynamic shapes on the 2nd-minor dimension.
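A sketch of the first item's rule, under the assumption that `packing` means the number of elements per 32-bit word (standard Mosaic usage, but stated here as an assumption):

```python
def default_tiling(bitwidth: int) -> tuple[int, int]:
    # An untiled memref of a packed dtype is assigned the tiling
    # (packing, 128), where packing = 32 / bitwidth.
    packing = 32 // bitwidth
    return (packing, 128)

assert default_tiling(32) == (1, 128)
assert default_tiling(16) == (2, 128)
assert default_tiling(8) == (4, 128)
```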
PiperOrigin-RevId: 695516124
This corresponds to what's implemented in `BarrierRef`, and ultimately makes it
easier to allocate barriers at a specific address in dynamic shared memory.
PiperOrigin-RevId: 695308297
This requires that the file providing the bindings have the same name as
the dialect it defines, since the dialect search looks for a module path
of the form `<prefix>.<dialect namespace>`.
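A sketch of the lookup this implies (the function name, prefix, and namespace below are all illustrative):

```python
import importlib

def find_dialect_module(prefix: str, dialect_namespace: str):
    # The dialect search imports `<prefix>.<dialect namespace>`, so the
    # bindings file must be named after the dialect's namespace.
    return importlib.import_module(f"{prefix}.{dialect_namespace}")

# e.g. a dialect with namespace "tpu" under a hypothetical prefix:
# find_dialect_module("my_project.dialects", "tpu")
```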
PiperOrigin-RevId: 693241875