71 Commits

Author SHA1 Message Date
Sergei Lebedev
a373e37be2 Fixed mgpu.FragmentedArray.reduce_sum for integer types
The implementation previously assumed the type is floating and used addf.

PiperOrigin-RevId: 678718871
2024-09-25 08:50:24 -07:00
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Sergei Lebedev
8196c8bf36 Added support for % and select to mgpu.FragmentedArray
PiperOrigin-RevId: 678200940
2024-09-24 05:19:25 -07:00
Sergei Lebedev
1256e18fd4 Added comparison operators to mgpu.FragmentedArray
PiperOrigin-RevId: 677788023
2024-09-23 07:37:53 -07:00
Sergei Lebedev
f311e81c02 Added is_signed to mgpu.FragmentedArray
The registers within a fragmented array always use signless types, and instead
the signedness is tracked on the fragmented arrays itself (i.e. in Python).

PiperOrigin-RevId: 677776009
2024-09-23 06:59:41 -07:00
Adam Paszke
81b8b4b7b4 [Mosaic GPU] Clean up the module structure
Previously the code was awkwardly split between the `jax.experimental.mosaic.gpu`
and `jax.experimental.mosaic.gpu.dsl` namespaces. I've now merged both so that
all user-visible APIs are accessible from `jax.experimental.mosaic.gpu`.

PiperOrigin-RevId: 676857257
2024-09-20 08:42:13 -07:00
jax authors
4e6f690724 Merge pull request #23653 from apaszke:torchsaic
PiperOrigin-RevId: 675967844
2024-09-18 06:35:15 -07:00
Adam Paszke
611ad63060 Add basic PyTorch integration for Mosaic GPU
We have already had most of the relevant pieces and we only needed
to connect them together. The most sensitive change is perhaps that
I needed to expose one more symbol from the XLA GPU plugin, but I don't
think it should be a problem.
2024-09-18 12:55:23 +00:00
Sergei Lebedev
427a490d2b Ported a few changes to FragmentArray by cperivol@
* It now supports unary negation
* and pointwise operations between scalars and FragmentedArrays

PiperOrigin-RevId: 674244294
2024-09-13 04:32:37 -07:00
Sergei Lebedev
8159d3352c Updated :gpu_test configuration
PiperOrigin-RevId: 674242448
2024-09-13 04:24:09 -07:00
Sergei Lebedev
ea68f4569c Internal change
PiperOrigin-RevId: 673409076
2024-09-11 08:47:58 -07:00
Adam Paszke
4c3111bf26 [Mosaic GPU] Unbreak tests
I mistakenly checked for `amount + 1` instead of `amount * 2`. It initially
seemed right because both expressions evalute to 2 for 1 :)

PiperOrigin-RevId: 670527107
2024-09-03 06:07:54 -07:00
Peter Hawkins
cd20404159 Disable mosaic gpu tests that are failing at head.
PiperOrigin-RevId: 669390680
2024-08-30 11:31:09 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Adam Paszke
be59f6ec47 [Mosaic GPU] Support tiled stores of arrays with fewer columns than swizzling
PiperOrigin-RevId: 666798285
2024-08-23 08:06:25 -07:00
Adam Paszke
f54e220430 [Mosaic GPU] Add support for short n dimension in WGMMA
PiperOrigin-RevId: 666766079
2024-08-23 06:08:37 -07:00
Adam Paszke
9c3f2dcefc [Mosaic GPU] Make CUDA context part of the hash key + replace kernel id with a SHA256 digest
XLA runtime creates a context per device, so we need to make sure that a kernel is loaded
separately on each device.

PiperOrigin-RevId: 666353098
2024-08-22 08:06:37 -07:00
Adam Paszke
0b4f64e002 [Mosaic GPU] Allow tile sizes to exceed dimension size
Otherwise, the dimension size still needs to be a multiple of tiling.

PiperOrigin-RevId: 666298624
2024-08-22 04:59:11 -07:00
Adam Paszke
ce3ea109a4 [Mosaic GPU] Add a fast type conversion from s8 vectors to bf16 vectors
Regular conversion instructions have a ridiculously low throughput on Hopper,
so replacing them with some bit tricks yields a much faster implementation.

Co-authored-by: Benjamin Chetioui <bchetioui@google.com>
PiperOrigin-RevId: 665893696
2024-08-21 08:39:24 -07:00
Adam Paszke
2ab7558425 [Mosaic GPU] Add support for grid tiling to improve L2 cache utilization
While CUDA technically does not guarantee anything about the order in
which blocks will be executed, in practice they are generally scheduled
in column-major order within the grid. We can use this property to launch
the blocks in a tiled way, which can lead to an improved rate of L2 hits
and a significant performance boost.

PiperOrigin-RevId: 662834982
2024-08-14 02:17:55 -07:00
Adam Paszke
f384497f68 [Mosaic GPU] Add support for cluster collective loads and barriers over multiple dimensions
This will be useful for an upcoming change to the matmul kernel that splits the N blocks
over two cluster dimensions.

PiperOrigin-RevId: 662825455
2024-08-14 01:47:12 -07:00
Adam Paszke
f4c0b1feb0 [Mosaic GPU] Add control over the output format in the matmul example
PiperOrigin-RevId: 662478648
2024-08-13 05:33:12 -07:00
Adam Paszke
5cf89b3f61 [Mosaic GPU] Add support for various swizzles in the matmul example
PiperOrigin-RevId: 662459766
2024-08-13 04:12:43 -07:00
Adam Paszke
ca6be2573b [Mosaic GPU] Move matmul tests to Hypothesis
We've been generating thousands of test cases and that's just not
scalable. Hypothesis should let us efficiently explore a large
number of configurations.

PiperOrigin-RevId: 662447113
2024-08-13 03:21:51 -07:00
Christos Perivolaropoulos
cd4e91b2b0 [mosaic_gpu] Store untiled splat layout
PiperOrigin-RevId: 662077826
2024-08-12 07:34:07 -07:00
Sergei Lebedev
28ca734d9b Added another boxDim check to mosaic_gpu_init_tma_desc
PiperOrigin-RevId: 660314586
2024-08-07 03:16:54 -07:00
Adam Paszke
d862f78dcc [Mosaic GPU] Skip matmul tests with large clusters
I'm still investigating but they sometimes hang for an unclear reason.

PiperOrigin-RevId: 656426326
2024-07-26 09:21:13 -07:00
Adam Paszke
2e6da35e97 [Mosaic GPU] Add support for clusters in the matmul example
With the collective async_copy API, the changes are quite minimal!

PiperOrigin-RevId: 655937185
2024-07-25 06:46:51 -07:00
Adam Paszke
e59303cf3e [Mosaic GPU] Simplify the matmul example
Remove a bunch of WGMMAImpl classes. This is meant to be a simple forkable example,
not a complete kernel.

PiperOrigin-RevId: 655923069
2024-07-25 05:43:57 -07:00
Adam Paszke
4f19af911c [Mosaic GPU] Only split collective TMAs only (multiple) major dimensions
Each TMA only writes to a contiguous subset of SMEM, so skipping a major
dimension while splitting results in incorrect code. To work around the
loss of flexibility, we now allow splitting multiple leading dimensions
to handle larger clusters and tiled references.

PiperOrigin-RevId: 655700486
2024-07-24 14:26:07 -07:00
Adam Paszke
dbe8f56353 [Mosaic GPU] Strengthen cluster-related tests by covering more cluster shapes
In particular test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.

PiperOrigin-RevId: 655686102
2024-07-24 13:43:52 -07:00
Adam Paszke
e52dc7ed15 [Mosaic GPU] Move barrier allocation to SMEM scratch specs
This is slightly less convenient than our previous approach but it has two main upsides:
1. It lets us automatically emit necessary fences and barriers for use with block clusters
2. It lets us share the same block/cluster barrier for all initializations of mbarriers

This change also moves away from the nvgpu dialect for barriers and allocates them in
dynamic SMEM instead of relying on static SMEM. This should give us more control over
SMEM layouts and alignments, and simplifies the lowering process.

PiperOrigin-RevId: 655493451
2024-07-24 02:56:52 -07:00
Adam Paszke
6bc7929376 [Mosaic GPU] Add sin/cos + unify support for approximate transcendental functions
PiperOrigin-RevId: 655469213
2024-07-24 01:15:57 -07:00
Adam Paszke
f0792b2d77 [Mosaic GPU] Add a collective mbarrier interface
Memory barriers are necessary to prevent excessive run ahead in a collective
pipeline, but the implementation can be tricky (both in terms of calculating
the right arrival count and dividing the signalling responsibility between
threads). I largely tried to follow the practices that CUTLASS established,
although I still do not understand why it swizzles the cluster for signalling.

PiperOrigin-RevId: 655098234
2024-07-23 03:19:29 -07:00
Adam Paszke
51732c5caf [Mosaic GPU] Replace multicast_mask by a nicer collective async copy interface
Instead of asking the user to compute the transfer size, manually slice up the
transfer and compute and specify the multicast mask, we fold all that functionality
into the `async_copy` function. The copy should be called by all blocks in a given
cluster slice along the specified dimension, and will collectively load all the
requested data into all blocks in that slice.

PiperOrigin-RevId: 655077439
2024-07-23 01:55:14 -07:00
Adam Paszke
a2b2fbf513 [Mosaic GPU] Add early support for block clusters and multicast TMA
PiperOrigin-RevId: 655057490
2024-07-23 00:50:20 -07:00
Adam Paszke
d8f435094d [Mosaic GPU] Add support for tiled loads and stores with swizzles other than 128
We have correctness tests in CI, but I additionally ran them under ncu to verify that
we never cause bank conflicts.

PiperOrigin-RevId: 653930174
2024-07-19 02:06:52 -07:00
Adam Paszke
a07b9adcb2 [Mosaic GPU] Add support for WGMMA lhs in registers for swizzles other than 128
PiperOrigin-RevId: 653626991
2024-07-18 08:23:16 -07:00
Adam Paszke
ade76f09b1 [Mosaic GPU] Support narrower swizzles in copy and TMA tests
PiperOrigin-RevId: 649045134
2024-07-03 05:58:54 -07:00
Adam Paszke
b19ad5b315 [Mosaic GPU] Add support for non-128B swizzles in WGMMA
PiperOrigin-RevId: 647667550
2024-06-28 07:12:10 -07:00
Adam Paszke
3ebebdfb76 [Mosaic GPU] Stop using nvgpu for TMA
It seems like nvgpu dialect bakes in a bunch of overly restrictive checks in its verifiers
and doesn't really buy us much in this case. nvvm works just fine.

PiperOrigin-RevId: 647653684
2024-06-28 06:08:36 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
Adam Paszke
f976f1f224 [Mosaic GPU] Use explicit WGMMA/ALU scheduling in the flash attention kernel
With this change we reach state of the art performance (as far as I can tell)
of 50%+ TC util for head_dim 128 and 256.

I also added a little tuning harness to try out different block sizes.

PiperOrigin-RevId: 644927079
2024-06-20 00:56:44 -07:00
Benjamin Chetioui
25a47649d2 [Mosaic GPU] Change FlashAttention implementation to support Grouped Query Attention.
Also add tests in `flash_attention_test.py`.

PiperOrigin-RevId: 642626612
2024-06-12 08:46:06 -07:00
Adam Paszke
3b4039c850 [Mosaic GPU] Load LLVM lowering interfaces for all dialects
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only repsponsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.

PiperOrigin-RevId: 641874183
2024-06-10 05:55:01 -07:00
Jake VanderPlas
a2c31f4d15 pallas/mosaic test: avoid leaking global config state 2024-06-06 16:00:02 -07:00
Adam Paszke
a7e35c6b9a [Mosaic GPU] Move the matmul example runner away from the test harness
This just makes more sense. It really shouldn't be a jax_test beacause it doesn't
even import test_util.

PiperOrigin-RevId: 639872888
2024-06-03 12:23:31 -07:00