The goal of this change is to catch PRs that introduce new warnings sooner.
To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
Add code to suppress some new warnings uncovered in CI.
PiperOrigin-RevId: 678352286
The registers within a fragmented array always use signless types; the signedness
is instead tracked on the fragmented array itself (i.e. in Python).
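A minimal sketch of the idea (this is not the actual FragmentedArray API; the class and
field names are made up): the stored values carry no sign, and the Python-side wrapper
decides how to interpret them.

```python
import dataclasses

@dataclasses.dataclass
class SignlessRegs:  # hypothetical stand-in, for illustration only
  bits: list[int]    # raw 32-bit patterns; the element type says nothing about sign
  is_signed: bool    # signedness is a property of the array, not of the registers

  def as_python_ints(self) -> list[int]:
    # Interpretation happens only here, driven by the array-level flag.
    if self.is_signed:
      return [b - (1 << 32) if b & (1 << 31) else b for b in self.bits]
    return list(self.bits)
```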
PiperOrigin-RevId: 677776009
Previously the code was awkwardly split between the `jax.experimental.mosaic.gpu`
and `jax.experimental.mosaic.gpu.dsl` namespaces. I've now merged both so that
all user-visible APIs are accessible from `jax.experimental.mosaic.gpu`.
PiperOrigin-RevId: 676857257
We already had most of the relevant pieces and only needed
to connect them together. The most sensitive change is perhaps that
I needed to expose one more symbol from the XLA GPU plugin, but I don't
think it should be a problem.
I mistakenly checked for `amount + 1` instead of `amount * 2`. It initially
seemed right because both expressions evaluate to 2 when `amount` is 1 :)
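A quick check shows the two expressions only coincide at that smallest case:

```python
# The two expressions agree only at amount == 1, which is why the bug slipped by.
for amount in (1, 2, 3, 4):
    print(amount, amount + 1, amount * 2)   # 1: 2 2   2: 3 4   3: 4 6   4: 5 8
```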
PiperOrigin-RevId: 670527107
Regular conversion instructions have a ridiculously low throughput on Hopper,
so replacing them with some bit tricks yields a much faster implementation.
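As a rough illustration of the kind of bit trick involved (not necessarily the exact
sequence used here): small unsigned integers can be turned into f32 without a cvt
instruction by OR-ing them into the mantissa of 2^23 and subtracting the bias.

```python
import numpy as np

# Illustrative only: for 0 <= x < 2**23, bit-casting (0x4B000000 | x) to f32
# yields 2**23 + x, so subtracting 2**23 recovers float(x) with no cvt.
x = np.arange(16, dtype=np.uint32)
f = (np.uint32(0x4B000000) | x).view(np.float32) - np.float32(2**23)
assert np.array_equal(f, x.astype(np.float32))
```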
Co-authored-by: Benjamin Chetioui <bchetioui@google.com>
PiperOrigin-RevId: 665893696
While CUDA technically does not guarantee anything about the order in
which blocks will be executed, in practice they are generally scheduled
in column-major order within the grid. We can use this property to launch
the blocks in a tiled way, which can lead to an improved rate of L2 hits
and a significant performance boost.
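A rough sketch of the idea (the function name, tile shape, and exact remapping are made
up, not the actual launch code): remap the linear block index so that consecutively
scheduled blocks fall into a small stripe of the grid, which keeps their working sets
overlapping in L2.

```python
def tiled_block_coords(linear_id: int, grid_rows: int, grid_cols: int,
                       tile_rows: int) -> tuple[int, int]:
  """Map a linearly scheduled block id to (row, col) so that blocks scheduled
  close together land in the same tile_rows-tall stripe of the grid."""
  blocks_per_stripe = tile_rows * grid_cols
  stripe = linear_id // blocks_per_stripe
  within = linear_id % blocks_per_stripe
  row = stripe * tile_rows + within % tile_rows
  col = within // tile_rows
  return row, col
```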
PiperOrigin-RevId: 662834982
We've been generating thousands of test cases and that's just not
scalable. Hypothesis should let us efficiently explore a large
number of configurations.
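For example, instead of enumerating every combination, something along these lines
(the strategies and test body are placeholders, not the actual tests) lets Hypothesis
pick interesting configurations on its own:

```python
from hypothesis import given, strategies as st

@given(
    st.sampled_from([(64, 64), (128, 64), (128, 128)]),  # placeholder block shapes
    st.integers(min_value=1, max_value=4),               # placeholder stage count
)
def test_matmul_config(block_shape, num_stages):
  # Placeholder body: build and run the kernel for this configuration
  # and compare against a reference implementation.
  ...
```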
PiperOrigin-RevId: 662447113
Each TMA only writes to a contiguous subset of SMEM, so skipping a major
dimension while splitting results in incorrect code. To work around the
loss of flexibility, we now allow splitting multiple leading dimensions
to handle larger clusters and tiled references.
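To see why only leading dimensions are safe to split, note that for a row-major buffer
a slice along the leading dimensions stays contiguous, while a slice that skips the
major dimension does not; roughly (using numpy contiguity as a stand-in for SMEM):

```python
import numpy as np

smem = np.zeros((4, 8, 32), dtype=np.float32)   # row-major staging buffer
leading_split = smem[0:2]                       # split along the major dimension
skipping_split = smem[:, 0:4]                   # split that skips the major dimension
print(leading_split.flags["C_CONTIGUOUS"])      # True: one contiguous chunk
print(skipping_split.flags["C_CONTIGUOUS"])     # False: strided, not a single chunk
```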
PiperOrigin-RevId: 655700486
In particular, test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices, and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.
PiperOrigin-RevId: 655686102
This is slightly less convenient than our previous approach but it has two main upsides:
1. It lets us automatically emit necessary fences and barriers for use with block clusters
2. It lets us share the same block/cluster barrier for all initializations of mbarriers
This change also moves away from the nvgpu dialect for barriers and allocates them in
dynamic SMEM instead of relying on static SMEM. This should give us more control over
SMEM layouts and alignments, and simplifies the lowering process.
PiperOrigin-RevId: 655493451
Memory barriers are necessary to prevent excessive run ahead in a collective
pipeline, but the implementation can be tricky (both in terms of calculating
the right arrival count and dividing the signalling responsibility between
threads). I largely tried to follow the practices that CUTLASS established,
although I still do not understand why it swizzles the cluster for signalling.
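As a back-of-the-envelope sanity check (assumptions only, not the actual Mosaic GPU
logic), the expected arrival count depends on which threads signal:

```python
def expected_arrivals(threads_per_block: int, blocks_in_slice: int,
                      leader_only: bool) -> int:
  """If every thread in every participating block arrives on the barrier, it must
  expect threads * blocks arrivals; if a single leader thread signals on behalf of
  each block, it only expects one arrival per block."""
  return blocks_in_slice if leader_only else threads_per_block * blocks_in_slice
```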
PiperOrigin-RevId: 655098234
Instead of asking the user to compute the transfer size, manually slice up the
transfer, and compute and specify the multicast mask, we fold all of that functionality
into the `async_copy` function. The copy should be called by all blocks in a given
cluster slice along the specified dimension, and will collectively load all the
requested data into all blocks in that slice.
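As a rough illustration of what is now folded away (the block-rank layout and function
are assumptions, not the actual implementation): the multicast mask is a bitmask over
blocks in the cluster, with one bit set per block in the participating slice.

```python
def multicast_mask(cluster_shape: tuple[int, int, int],
                   block_idx: tuple[int, int, int],
                   collective_dim: int) -> int:
  """Set one bit per block that shares all coordinates with block_idx except
  along collective_dim (assuming x-major block ranks within the cluster)."""
  mask = 0
  for i in range(cluster_shape[collective_dim]):
    idx = list(block_idx)
    idx[collective_dim] = i
    rank = idx[0] + cluster_shape[0] * (idx[1] + cluster_shape[1] * idx[2])
    mask |= 1 << rank
  return mask
```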
PiperOrigin-RevId: 655077439
It seems like the nvgpu dialect bakes in a bunch of overly restrictive checks in its
verifiers and doesn't really buy us much in this case. nvvm works just fine.
PiperOrigin-RevId: 647653684
With this change we reach state-of-the-art performance (as far as I can tell)
of 50%+ TC utilization for head_dim 128 and 256.
I also added a little tuning harness to try out different block sizes.
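The harness is essentially a brute-force sweep; conceptually it looks something like
this (the function names and block-size ranges are placeholders, not the actual
harness):

```python
import itertools, time

def sweep(build_kernel, run_kernel, block_qs=(64, 128), block_ks=(64, 128, 256)):
  """Try each block-size combination, time it, and report the fastest."""
  best = None
  for block_q, block_k in itertools.product(block_qs, block_ks):
    kernel = build_kernel(block_q=block_q, block_k=block_k)
    run_kernel(kernel)                      # warm-up / compilation
    start = time.perf_counter()
    run_kernel(kernel)
    elapsed = time.perf_counter() - start
    if best is None or elapsed < best[0]:
      best = (elapsed, block_q, block_k)
  return best
```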
PiperOrigin-RevId: 644927079
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But simply adding the registrations in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only responsible for handling the GPU dialect, making the lowering similar
to the one prior to the extra registrations.
PiperOrigin-RevId: 641874183