- Make internal scratch size configurable.
- Pass the number of max sublanes allowed in scratch to apply-vector-layout pass.
- Create a helper function to fetch internal scratch VMEM address.
PiperOrigin-RevId: 644184896
a multi_dim_reduction with f32 source/accumulator/output, where the source
and accumulator are extended and the result is truncated. This addressed 'only
32-bit reductions supported' error.
PiperOrigin-RevId: 644062786
To work around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can be in global memory, the C++
programming guide heavily discurages this, because it can lead to
incorrrect results. Which is also what we've sometimes observed as
a cache coherency issue unless a TMA fence is explicitly inserted at the
beginning of the kernel.
Note that this approach has a big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently so I didn't insert any extra synchronization for now, but
we should seriously consider it. My hope at the moment is that we'll
be able to start passing in TMA descs as kernel args soon (pending
upstreaming LLVM changes...) and we won't have to deal with this again.
For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays
PiperOrigin-RevId: 643972675
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 642954763
contraction with f32 accumulator and output, where the accumulator is
extended and the output truncated. For targets that do not support bf16
matmul, the lhs and rhs are extended to f32.
PiperOrigin-RevId: 642051952
- We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block.
- Refactor the logic of assume layouts for block arguments to a helper function.
- Add tests for nested fori loop and while loop.
PiperOrigin-RevId: 641973011
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only repsponsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.
PiperOrigin-RevId: 641874183
- Support non-zero minor offsets without having to relayout (they're still a no-op).
- Remove restriction on tiling which now allows 1D packed types to work.
PiperOrigin-RevId: 640967375
- On infer-vector-layout remove some restrictions related to batch dimensions. Reshaping them doesn't matter as long as they don't combine with tiled dimensions.
- On apply-vector-layout, simplify handling of cases where the implicit tiled don't change while removing some unnecessary restrictions.
- Don't require native tiling or natural topology for this.
PiperOrigin-RevId: 640837740
This re-enables the tests removed in https://github.com/google/jax/pull/21563
and adds support for exposing the XLA FFI headers in the
`jax.extend.ffi.include_dir` directory during a bazel build. While it's
unlikely that these will be useful for most bazel users, it is good to provide
a consistent interface with the wheel build and to be able to test this feature.
PiperOrigin-RevId: 640194961
Previously the rule would complain if the layouts were unsupported, but that's not
the right way to handle that situation. With this change, we simply pick a supported
configuration instead (and expect relayout to handle it).
PiperOrigin-RevId: 640190248
Pallas hasn't been using it to lower since cl/604849222, which is before we had serde, so we won't create a serde rule for this. That CL also tried to remove the support, but it had to be restored in cl/605436599 because it was breaking serialized kernels.
PiperOrigin-RevId: 639896981
We should infer layout for each terminator inside its own region and find a compatible layout for a final result if the result is based on terminators from multiple regions like scf::ifOp, scf::whileOp, scf::forOp. If no compatible layout is found, we will fall back to a normalized layout. Finally we also need to ensure the layouts in input, terminator and output are consistent across loops.
PiperOrigin-RevId: 639122434
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.
PiperOrigin-RevId: 638569750
MacOS users frequently encounter a run-time error related to unsupported
AVX instructions (#21491, most recently). This is typically caused by
running an x86 Python build on ARM hardware, and it requires
re-installation of Python. This PR adds this suggestion to the error
message when encountered on macOS.
JAX has stopped generating code that uses directly
the DUCC FFT custom calls.
The 6 months backwards compatibility window has also expired.
PiperOrigin-RevId: 638132572
- Enable it for minor or second-minor implicit dims for the non-no-op case.
- Don't allow output offsets for broadcasted dimensions to be non-replicated. Make sure to assign them as replicated in infer-vector-layout for all cases.
- Don't fail when both tiled dimensions are logically broadcasted but only one of them requires actual broadcasting (before, it would hit the unimplemented sublane + lane broadcast case).
PiperOrigin-RevId: 637772134
The implementation uses the new tpu.log operation in the Mosaic TPU dialect.
Note that
* the logging only happens if --xla_tpu_enable_log_recorder is set;
* only scalars can be printed;
* placeholders only accept i32 arguments at the moment.
PiperOrigin-RevId: 636585852