VectorLayout offsets are now allowed to fall anywhere within the vreg slice. This way, tiling is still applied after offsets, and offsets are still applied after implicit dimensions.
Note that an offset outside of the vreg slice would mean a vreg consisting entirely of padding, which is why we disallow such offsets.
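A minimal sketch of the constraint, with hypothetical names (the real check lives in Mosaic's C++ VectorLayout):
```
#include <array>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for Mosaic's VectorLayout; only the offset check is
// sketched. With an (8, 128) vreg slice, any offset in [0, 8) x [0, 128) is valid.
struct VectorLayoutSketch {
  std::array<std::optional<int64_t>, 2> offsets;  // nullopt == replicated
  std::array<int64_t, 2> vreg_slice;              // e.g. {8, 128}

  bool OffsetsWithinVregSlice() const {
    for (int i = 0; i < 2; ++i) {
      // An offset at or past the vreg slice boundary would describe a vreg
      // made up entirely of padding, so it is rejected.
      if (offsets[i].has_value() &&
          (*offsets[i] < 0 || *offsets[i] >= vreg_slice[i])) {
        return false;
      }
    }
    return true;
  }
};
```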
PiperOrigin-RevId: 650408597
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 650212574
We support any dynamic index on the second minor dim in either of these cases:
1. The minormost dim size of an unsliced memref matches the VREG lane count.
2. The load/store accesses a single row on the second minormost dim, which triggers an implicit strided load/store.
Note: for the default cases that cannot skip the alignment check, we still use the dynamic-slice + static-load/store solution to reduce scalar core work. We should find a way to optimize this in all cases.
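A hedged sketch of the case analysis; the constant and names below are illustrative, not the actual pass code:
```
#include <cstdint>

constexpr int64_t kVregLanes = 128;  // TPU vreg lane count

// Returns true if a dynamic index on the second minor dim is supported.
bool SupportsDynamicSecondMinorIndex(int64_t minormost_dim_size,
                                     int64_t rows_accessed,
                                     bool memref_is_sliced) {
  // Case 1: the minormost dim of an unsliced memref matches the lane count,
  // so every row starts lane-aligned and the alignment check can be skipped.
  if (!memref_is_sliced && minormost_dim_size == kVregLanes) return true;
  // Case 2: a single row on the second minormost dim, which lowers to an
  // implicit strided load/store.
  if (rows_accessed == 1) return true;
  return false;
}
```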
PiperOrigin-RevId: 648771794
As we've established (sigh), we can't pass TMA descriptors through global memory.
The previous workaround was to use constant memory instead, but this raises a number of
potential concurrency issues. So, instead, we use the freshly added support for grid_constant
parameters in upstream LLVM to pass the descriptors as kernel arguments. This seems to work
fine and should in fact have lower overhead than both previous methods.
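For illustration, a minimal CUDA sketch of the new scheme (the kernel body is elided; `__grid_constant__` requires a recent toolchain and a const-qualified by-value parameter):
```
#include <cuda.h>

// The TMA descriptor is passed by value in kernel parameter space, so no
// global or constant memory is involved and no extra fences are needed.
__global__ void tma_kernel(const __grid_constant__ CUtensorMap desc) {
  // ... issue TMA (cp.async.bulk.tensor) copies that reference &desc ...
}
```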
PiperOrigin-RevId: 648744363
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```
It seems that with the FFI registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is
no longer possible to register a call target twice.
The fix here is to roll back the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.
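For reference, a hedged sketch of the handler-definition side of the mechanism, following the documented XLA FFI style (the handler body and the registration call are illustrative only):
```
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Illustrative handler; the real cu_threefry2x32_ffi implementation differs.
static ffi::Error ThreefryImpl(ffi::Buffer<ffi::U32> keys,
                               ffi::ResultBuffer<ffi::U32> out) {
  // ... launch the CUDA kernel ...
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(kThreefry, ThreefryImpl,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::Buffer<ffi::U32>>()    // keys
                                  .Ret<ffi::Buffer<ffi::U32>>());  // out

// Registering the same (name, platform) pair a second time is what produces
// the "Duplicate FFI handler registration" error quoted above:
// XLA_FFI_REGISTER_HANDLER(api, "cu_threefry2x32_ffi", "CUDA", kThreefry);
```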
PiperOrigin-RevId: 647993991
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.
For the next 3 weeks, we use the new custom call implementation only if
we are not in "export" mode and a new enough jaxlib is in use.
PiperOrigin-RevId: 647657084
When an FFI kernel is executed, there isn't any global try/except block (I think!), so it's probably a good idea to avoid throwing.
Instead, it should be safer to map failures to ffi::Error manually.
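A minimal sketch of the pattern, assuming the `xla::ffi` C++ API (the validation itself is hypothetical):
```
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

static ffi::Error GuardedImpl(ffi::Buffer<ffi::F32> x,
                              ffi::ResultBuffer<ffi::F32> y) {
  // Return an ffi::Error instead of throwing: nothing above the handler is
  // guaranteed to catch a C++ exception.
  if (x.element_count() == 0) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "input must be non-empty");
  }
  // ... the actual work, with every failure path returning an ffi::Error ...
  return ffi::Error::Success();
}
```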
PiperOrigin-RevId: 647348889
The XLA FFI interface provides metadata about buffer dimensions, so quantities
like batch dimensions can be evaluated on the backend instead of being passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.
PiperOrigin-RevId: 647343656
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` is installed from the `nvidia-cuda-nvcc-cu12` package.
PiperOrigin-RevId: 647328834
In some situations, this also meant changing unrelated files to include tsl/platform/statusor.h directly to get the definitions of TF_ASSIGN_OR_RETURN, etc., which they had previously been picking up transitively for free.
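For example (the helper function is hypothetical; the direct include and the macro are the point):
```
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tsl/platform/statusor.h"  // now included directly, not transitively

absl::StatusOr<int> Parse(const std::string& s);  // hypothetical helper

absl::Status Use(const std::string& s) {
  // TF_ASSIGN_OR_RETURN comes from tsl/platform/statusor.h.
  TF_ASSIGN_OR_RETURN(int value, Parse(s));
  (void)value;
  return absl::OkStatus();
}
```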
PiperOrigin-RevId: 645169743
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 644845277
We will choose the best solution based on the size of the internal scratch memory:
- Sol 1: Convert the dynamic roll to log(N) static ops (sketched below).
- Sol 2: Static store + dynamic load via the internal scratch memory.
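A scalar illustration of Sol 1, assuming a power-of-two length (on TPU the static steps become lane/sublane rotates and the branch becomes a vector select):
```
#include <algorithm>
#include <cstdint>
#include <vector>

// Roll left by a dynamic `shift` using log2(N) static rolls: the roll by
// each power of two is applied iff the corresponding bit of `shift` is set.
std::vector<int> DynamicRollViaStaticSteps(std::vector<int> v,
                                           uint32_t shift) {
  const uint32_t n = static_cast<uint32_t>(v.size());  // power of two
  for (uint32_t step = 1; step < n; step <<= 1) {
    if (shift & step) {
      std::rotate(v.begin(), v.begin() + step, v.end());  // static roll
    }
  }
  return v;
}
```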
PiperOrigin-RevId: 644509328
- Make internal scratch size configurable.
- Pass the maximum number of sublanes allowed in scratch to the apply-vector-layout pass.
- Create a helper function to fetch internal scratch VMEM address.
PiperOrigin-RevId: 644184896
a multi_dim_reduction with f32 source/accumulator/output, where the source
and accumulator are extended and the result is truncated. This addresses the 'only
32-bit reductions supported' error.
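A scalar stand-in for the rewrite, assuming bf16 input (the real change operates on vector.multi_dim_reduction in MLIR):
```
#include <cstdint>
#include <cstring>
#include <vector>

// bf16 is the high 16 bits of an f32, so extension is a shift and the
// truncation back simply drops the low mantissa bits.
static float Bf16ToF32(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

static uint16_t F32ToBf16(float f) {  // truncating round, for brevity
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);
}

// Extend source and accumulator to f32, reduce in f32, truncate the result.
uint16_t ReduceBf16ViaF32(const std::vector<uint16_t>& src) {
  float acc = 0.0f;
  for (uint16_t v : src) acc += Bf16ToF32(v);
  return F32ToBf16(acc);
}
```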
PiperOrigin-RevId: 644062786
To work around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can be in global memory, the C++
programming guide heavily discourages this, because it can lead to
incorrect results. That matches what we've sometimes observed: a cache
coherency issue, unless a TMA fence is explicitly inserted at the
beginning of the kernel.
Note that this approach has the big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently, so I didn't insert any extra synchronization for now, but
we should seriously consider it. My hope at the moment is that we'll
be able to start passing in TMA descs as kernel args soon (pending
LLVM changes being upstreamed...) and we won't have to deal with this again.
For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays
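A minimal CUDA sketch of the constant-memory placement described above (the host-side upload is shown in a comment; the symbol name is illustrative):
```
#include <cuda.h>

// One descriptor shared by all launches of this kernel. Concurrent
// dispatches would race on this symbol, which is the downside noted above.
__constant__ CUtensorMap tma_desc;

// Host side, before each launch:
//   cudaMemcpyToSymbol(tma_desc, &host_desc, sizeof(CUtensorMap));
```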
PiperOrigin-RevId: 643972675
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package and use it in jax's setup.py. This allows us to specify CUDA extras in one place.
* Make a few small doc improvements.
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 642954763