49 Commits

Author SHA1 Message Date
jax authors
f3b7c5cb9e Integrate LLVM at llvm/llvm-project@0230d63b4a
Updates LLVM usage to match
[0230d63b4a8b](https://github.com/llvm/llvm-project/commit/0230d63b4a8b)

PiperOrigin-RevId: 738222096
2025-03-18 19:23:20 -07:00
Adam Paszke
8da93249d2 [Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
This allows us to significantly simplify the generated PTX/SASS,
which is currently cluttered with LLVM trying to align slices to
start at bit 0 and failing to CSE the right shifts.

PiperOrigin-RevId: 737967890
2025-03-18 05:38:49 -07:00
Chris Jones
38d52a19ef [mosaic_gpu] Force flush all cupti activity, then unsubscribe.
With default flushing, it is possible for events to be missed. We should only unsubscribe after we are finished with cupti.

PiperOrigin-RevId: 737939327
2025-03-18 03:35:03 -07:00
H. Vetinari
91cae595e4 fix member access to packed CUDA struct
2025-02-24 08:03:07 +11:00
jax authors
725087e13f Integrate LLVM at llvm/llvm-project@9d24f94379
Updates LLVM usage to match
[9d24f9437944](https://github.com/llvm/llvm-project/commit/9d24f9437944)

PiperOrigin-RevId: 728265165
2025-02-18 10:30:48 -08:00
jax authors
e78a469b42 Integrate LLVM at llvm/llvm-project@912b154f3a
Updates LLVM usage to match
[912b154f3a3f](https://github.com/llvm/llvm-project/commit/912b154f3a3f)

PiperOrigin-RevId: 727895384
2025-02-17 10:08:37 -08:00
Dimitar (Mitko) Asenov
a3a285dddc [Mosaic GPU] Handle the swizzle attribute in the lowering of async_store and async_load
PiperOrigin-RevId: 720129408
2025-01-27 05:18:16 -08:00
Sergei Lebedev
9ee7123c39 [mosaic_gpu] Fixed mosaic_gpu-serde pass registration
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.

PiperOrigin-RevId: 719280601
2025-01-24 06:35:54 -08:00
Adam Paszke
7043b852ec [Mosaic GPU] Add basic support for TMA with sub-byte types
PiperOrigin-RevId: 719240287
2025-01-24 03:54:12 -08:00
Peter Hawkins
034e967e11 Remove CUDA rpaths from jaxlib build.
These are also set in the TSL build rules as part of the CUDA stub libraries, which these libraries depend on, so these copies of the rpath settings are redundant.

PiperOrigin-RevId: 716844265
2025-01-17 17:09:30 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
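The versioning scheme described in d34c40f6b6 can be pictured with a minimal sketch. Everything here is hypothetical (the dict-based module, the `migrate` helper, the rule tables and the version 2 rule); the actual pass operates on MLIR and, per the commit message, currently has no rules at all.

```python
from typing import Callable, Dict

# Hypothetical rule tables: upgrades[v] takes a module at version v to v + 1,
# downgrades[v] takes a module at version v to v - 1.
Rule = Callable[[dict], dict]

def migrate(module: dict, target: int,
            upgrades: Dict[int, Rule], downgrades: Dict[int, Rule]) -> dict:
    """Applies migration rules one version step at a time until `target`."""
    module = dict(module)  # do not mutate the caller's module
    version = module["version"]
    while version < target:
        module = upgrades[version](module)
        version += 1
    while version > target:
        module = downgrades[version](module)
        version -= 1
    module["version"] = version
    return module

# Illustrative only: a made-up version 2 that renames an op.
upgrades = {1: lambda m: {**m, "ops": [op.replace("old_op", "new_op") for op in m["ops"]]}}
downgrades = {2: lambda m: {**m, "ops": [op.replace("new_op", "old_op") for op in m["ops"]]}}

v1 = {"version": 1, "ops": ["old_op", "add"]}
v2 = migrate(v1, 2, upgrades, downgrades)
roundtrip = migrate(v2, 1, upgrades, downgrades)
```

With an empty rule table and a target equal to the stored version, `migrate` returns the module unchanged, which matches the "no rules are necessary" state the commit describes.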
Peter Hawkins
90d8f37863 Rename pybind_extension to nanobind_extension.
We have no remaining uses of pybind11 outside a GPU custom call example.

PiperOrigin-RevId: 712608834
2025-01-06 11:53:44 -08:00
Benjamin Chetioui
36b12d58f4 [Mosaic GPU] Add end-to-end lowering example for a pointwise kernel using the dialect and layout inference.
Also implement a lowering rule for `arith.AddFOp`.

PiperOrigin-RevId: 707131747
2024-12-17 09:28:05 -08:00
Sergei Lebedev
a14e6968bf [mosaic] Migrated the serialization pass from codegen to pass_boilerplate.h
This prepares the generalization of the serialization pass to handle both
Mosaic TPU and GPU.

PiperOrigin-RevId: 705628923
2024-12-12 14:19:36 -08:00
jax authors
0d7eaeb5d8 Merge pull request #24805 from andportnoy:aportnoy/mosaic-gpu-cupti-profiler
PiperOrigin-RevId: 705071782
2024-12-11 05:29:10 -08:00
Andrey Portnoy
cc22334c21 [Mosaic GPU] Add CUPTI profiler alongside events-based implementation
2024-12-09 14:31:20 -05:00
Andrey Portnoy
7bd81dbe0d [Mosaic GPU] Improve default kernel name and add option to customize
This allows users to distinguish Mosaic GPU kernels from other kernels
when using profiling programs such as Nsight Systems.

The new default behavior is to use `mosaic_gpu_<def_name>_kernel` as
the kernel name, where `<def_name>` is the name of the Mosaic GPU
Python kernel function passed to `as_gpu_kernel` or
`as_torch_gpu_kernel`.

We also add a new `kernel_name` optional argument to `as_gpu_kernel`
and `as_torch_gpu_kernel`. If `kernel_name` is not `None`, the
resulting kernel name is `mosaic_gpu_<kernel_name>_kernel`. This is
useful when the Mosaic GPU Python kernel function is constructed
through metaprogramming so that the final specialized kernel can have
different meaningful names depending on the metaparameters.

Previously the kernel name was always `main_kernel`.
2024-12-02 22:22:11 -05:00
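The naming rule from 7bd81dbe0d can be summarized in a few lines. This is a sketch of the rule as the commit message states it, not the code in the commit; `mosaic_kernel_name` and the `layer_norm` stand-in are hypothetical names.

```python
from typing import Callable, Optional

def mosaic_kernel_name(body: Callable, kernel_name: Optional[str] = None) -> str:
    """Hypothetical helper: builds the kernel name from the Python function
    name unless an explicit `kernel_name` is given."""
    base = kernel_name if kernel_name is not None else body.__name__
    return f"mosaic_gpu_{base}_kernel"

def layer_norm(*refs):
    """Stand-in for a Mosaic GPU Python kernel function."""

default_name = mosaic_kernel_name(layer_norm)
# A metaprogrammed kernel can encode its metaparameters in the name.
custom_name = mosaic_kernel_name(layer_norm, kernel_name="layer_norm_bf16_128")
```

Here `default_name` is `mosaic_gpu_layer_norm_kernel` and `custom_name` is `mosaic_gpu_layer_norm_bf16_128_kernel`, so the two specializations show up separately in a profiler.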
Sergei Lebedev
d304025a41 [mosaic_gpu] The profiler now uses FFI calls for creating events and computing elapsed time
PiperOrigin-RevId: 695798787
2024-11-12 11:01:59 -08:00
Adam Paszke
ce3826d098 [Mosaic GPU] Make sure to free the cloned MLIR module when debugging
We only recently started using this in tests and it has caused ASAN
to report a bunch of leaks.

PiperOrigin-RevId: 694510867
2024-11-08 08:35:10 -08:00
Sergei Lebedev
34b4787e2e [mosaic_gpu] Check the return code of gpuEventCreate and gpuEventDestroy
PiperOrigin-RevId: 693260326
2024-11-05 01:59:58 -08:00
Adam Paszke
36c56fa19b [Pallas:MGPU] Fix flaky debug_print tests
Turns out that waiting for the kernel to finish is not enough, since the
prints also need to be processed by the CUDA runtime. Using a test-only
function that synchronizes all the devices seems to suffice.

PiperOrigin-RevId: 690624999
2024-10-28 08:42:02 -07:00
Sergei Lebedev
04bdd07f66 [mosaic_gpu] mgpu.FragmentedArray now supports //
This is needed to compute grid index from the iteration step counter in `emit_pipeline`.

PiperOrigin-RevId: 690608581
2024-10-28 07:52:22 -07:00
Adam Paszke
6634f5a348 [Mosaic GPU] Use absl::StrCat instead of std::string::operator+
Repeated string addition is apparently a bit of an anti-pattern. Not that it matters
much in this place, but why not do it properly.

PiperOrigin-RevId: 689416587
2024-10-24 09:49:51 -07:00
Andrey Portnoy
14e0f0e7fa [Mosaic GPU] Query SM and PTX ISA dynamically using driver and LLVM
Originally proposed in #24021. Slightly rewritten to make testing with internal LLVM toolchains better.

Use CUDA driver API to query major and minor compute capabilities, thus arriving at a "base" SM string (e.g. `sm_90`).
Then use LLVM to see if we can "upgrade" the base SM string to one that enables architecture-specific capabilities (e.g. `sm_90a`).
Then use LLVM to map the SM string to a PTX ISA version that supports the SM.

Co-authored-by: Andrey Portnoy <aportnoy@nvidia.com>
PiperOrigin-RevId: 689286774
2024-10-24 01:46:29 -07:00
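The three steps in 14e0f0e7fa (base SM string, optional upgrade, PTX ISA lookup) can be sketched as follows. The real implementation queries the CUDA driver and LLVM; here both are replaced by assumptions: the compute capability is passed in as plain integers, and the `PTX_ISA_FOR_SM` table is an illustrative stand-in for what LLVM reports.

```python
from typing import Callable

# Illustrative stand-in for LLVM's knowledge of targets (assumed values).
PTX_ISA_FOR_SM = {"sm_80": "7.0", "sm_90": "7.8", "sm_90a": "8.0"}

def base_sm(major: int, minor: int) -> str:
    """Builds the "base" SM string from the compute capability."""
    return f"sm_{major}{minor}"

def select_sm(major: int, minor: int, is_supported: Callable[[str], bool]) -> str:
    """Prefers the architecture-specific variant (e.g. sm_90a) when the
    toolchain supports it, otherwise falls back to the base SM string."""
    base = base_sm(major, minor)
    upgraded = base + "a"
    return upgraded if is_supported(upgraded) else base

def ptx_isa_for(sm: str) -> str:
    """Maps the chosen SM string to a PTX ISA version that supports it."""
    return PTX_ISA_FOR_SM[sm]

supported = lambda sm: sm in PTX_ISA_FOR_SM
hopper = select_sm(9, 0, supported)
ampere = select_sm(8, 0, supported)
```

With this table, `hopper` is `sm_90a` and `ampere` stays at `sm_80` because no `sm_80a` entry exists.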
Adam Paszke
611ad63060 Add basic PyTorch integration for Mosaic GPU
We have already had most of the relevant pieces and we only needed
to connect them together. The most sensitive change is perhaps that
I needed to expose one more symbol from the XLA GPU plugin, but I don't
think it should be a problem.
2024-09-18 12:55:23 +00:00
Adam Paszke
8feab68209 [Mosaic GPU] Remove the unnecessary scratch space operand
And clean up the C++ dispatch code. We don't use HBM scratch anymore
since we pass TMA descriptors as kernel arguments.

PiperOrigin-RevId: 671327420
2024-09-05 04:57:52 -07:00
Sergei Lebedev
7dd9adba05 Fixed stack-use-after-scope in Mosaic GPU
PiperOrigin-RevId: 668958750
2024-08-29 09:07:58 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Adam Paszke
9c3f2dcefc [Mosaic GPU] Make CUDA context part of the hash key + replace kernel id with a SHA256 digest
XLA runtime creates a context per device, so we need to make sure that a kernel is loaded
separately on each device.

PiperOrigin-RevId: 666353098
2024-08-22 08:06:37 -07:00
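The caching idea in 9c3f2dcefc can be sketched as a dictionary keyed by both a content digest and the context. This is an assumed shape, not the C++ code from the commit; `get_or_load`, the integer `context_id` and the `load` callback are hypothetical.

```python
import hashlib
from typing import Callable, Dict, Tuple

_kernel_cache: Dict[Tuple[str, int], object] = {}

def kernel_digest(module_bytes: bytes) -> str:
    """Identifies a kernel by the SHA256 digest of its serialized module."""
    return hashlib.sha256(module_bytes).hexdigest()

def get_or_load(module_bytes: bytes, context_id: int,
                load: Callable[[bytes, int], object]) -> object:
    """Loads the kernel once per (digest, context) pair.

    Including the context in the key means the same kernel is loaded
    separately for every device context.
    """
    key = (kernel_digest(module_bytes), context_id)
    if key not in _kernel_cache:
        _kernel_cache[key] = load(module_bytes, context_id)
    return _kernel_cache[key]

loads = []
def fake_load(module_bytes: bytes, context_id: int) -> object:
    loads.append(context_id)  # record which context triggered a load
    return object()

k0 = get_or_load(b"module", 0, fake_load)
k0_again = get_or_load(b"module", 0, fake_load)
k1 = get_or_load(b"module", 1, fake_load)
```

The second call for context 0 hits the cache, while context 1 triggers its own load even though the module bytes are identical.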
Sergei Lebedev
28ca734d9b Added another boxDim check to mosaic_gpu_init_tma_desc
PiperOrigin-RevId: 660314586
2024-08-07 03:16:54 -07:00
Sergei Lebedev
5e418f5ab2 Added argument validation to mosaic_gpu_init_tma_desc
This should help with understanding cuTensorMapEncodeTiled failures, since
CUDA doesn't provide any details beyond the error return code.

Note that this change also ensures that TMA descriptors are 64-byte aligned.

PiperOrigin-RevId: 656062820
2024-07-25 13:16:34 -07:00
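The kind of up-front validation added in 5e418f5ab2 might look like the sketch below. The specific limits (rank at most 5, box dimensions in 1..256, 16-byte aligned global address) are assumptions taken from the CUDA driver documentation for `cuTensorMapEncodeTiled`, not from the commit; only the 64-byte descriptor alignment is mentioned in the message. The function name and signature are hypothetical.

```python
from typing import List, Sequence

MAX_TMA_RANK = 5        # assumption: maximum tensor rank accepted by TMA
MAX_BOX_DIM = 256       # assumption: upper bound on each boxDim entry
DESC_ALIGNMENT = 64     # from the commit message: descriptors are 64-byte aligned
GLOBAL_ALIGNMENT = 16   # assumption: required alignment of the global address

def validate_tma_args(desc_addr: int, global_addr: int,
                      box_dims: Sequence[int]) -> List[str]:
    """Returns readable error messages instead of a bare error code."""
    errors = []
    if desc_addr % DESC_ALIGNMENT != 0:
        errors.append(f"descriptor address {desc_addr:#x} is not {DESC_ALIGNMENT}-byte aligned")
    if global_addr % GLOBAL_ALIGNMENT != 0:
        errors.append(f"global address {global_addr:#x} is not {GLOBAL_ALIGNMENT}-byte aligned")
    if not 1 <= len(box_dims) <= MAX_TMA_RANK:
        errors.append(f"rank {len(box_dims)} is outside 1..{MAX_TMA_RANK}")
    for i, dim in enumerate(box_dims):
        if not 1 <= dim <= MAX_BOX_DIM:
            errors.append(f"boxDim[{i}] = {dim} is outside 1..{MAX_BOX_DIM}")
    return errors

ok = validate_tma_args(0x1000, 0x2000, [64, 128])
bad = validate_tma_args(0x1010, 0x2000, [64, 512])
```

Collecting every violation before calling into the driver gives a message that names the offending argument, which the driver's return code alone does not.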
Adam Paszke
dbe8f56353 [Mosaic GPU] Strengthen cluster-related tests by covering more cluster shapes
In particular test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.

PiperOrigin-RevId: 655686102
2024-07-24 13:43:52 -07:00
Adam Paszke
a2b2fbf513 [Mosaic GPU] Add early support for block clusters and multicast TMA
PiperOrigin-RevId: 655057490
2024-07-23 00:50:20 -07:00
Adam Paszke
265a54da31 [Mosaic GPU] Pass in TMA descriptors through kernel parameters
As we've established (sigh) we can't pass in TMA descriptors through global memory.
The current workaround was to use constant memory instead, but this raises a number of
potential concurrency issues. So, instead, we use the freshly added support for grid_constant
parameters in upstream LLVM to pass the descriptors as kernel arguments. This seems to work
fine and should in fact have lower overheads than both previous methods.

PiperOrigin-RevId: 648744363
2024-07-02 09:30:52 -07:00
Christos Perivolaropoulos
ea49194926 [mosaic_gpu] Control dumping mlir with MOSAIC_GPU_DUMP_MLIR_PASSES
PiperOrigin-RevId: 647341364
2024-06-27 09:17:52 -07:00
Chris Jones
de8fd3b00d [mosaic:gpu] Fix MLIR canonicalization pass region-simplify option.
`region-simplify` now has `normal` and `aggressive` modes (using `normal` for now).

PiperOrigin-RevId: 644724434
2024-06-19 06:02:11 -07:00
Adam Paszke
4ea73bf787 Use constant memory to pass in TMA descriptors to the kernel
To work around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can be in global memory, the C++
programming guide heavily discourages this, because it can lead to
incorrect results. Which is also what we've sometimes observed as
a cache coherency issue unless a TMA fence is explicitly inserted at the
beginning of the kernel.

Note that this approach has a big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently so I didn't insert any extra synchronization for now, but
we should seriously consider it. My hope at the moment is that we'll
be able to start passing in TMA descs as kernel args soon (pending
upstreaming LLVM changes...) and we won't have to deal with this again.

For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays

PiperOrigin-RevId: 643972675
2024-06-17 05:31:26 -07:00
jax authors
d20b9e324f Integrate LLVM at llvm/llvm-project@8c5d9c79b9
Updates LLVM usage to match
[8c5d9c79b96e](https://github.com/llvm/llvm-project/commit/8c5d9c79b96e)

PiperOrigin-RevId: 642352474
2024-06-11 12:24:43 -07:00
Adam Paszke
1256ceb266 [Mosaic GPU] Rearrange the pass pipeline (again)
PiperOrigin-RevId: 642256145
2024-06-11 06:59:50 -07:00
Adam Paszke
0739d520b1 [Mosaic GPU] Don't always run with llvm::DebugFlag enabled
This slipped past during code review.

PiperOrigin-RevId: 641899993
2024-06-10 07:50:26 -07:00
Adam Paszke
3b4039c850 [Mosaic GPU] Load LLVM lowering interfaces for all dialects
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only responsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.

PiperOrigin-RevId: 641874183
2024-06-10 05:55:01 -07:00
Adam Paszke
d01496a09a [Mosaic GPU] Restore the PTX/PTXAS/SASS dump flags
They're very useful while prototyping the kernels.

PiperOrigin-RevId: 639027506
2024-05-31 07:27:36 -07:00
Sergei Lebedev
d2a39bc61b Updated the layer norm implementation in Mosaic GPU tests
jnp.var now needs lax.gt_p, which we don't currently support.

PiperOrigin-RevId: 639011383
2024-05-31 06:11:48 -07:00
Sergei Lebedev
8729952d82 Added a missing return to MosaicGPUCustomCall
PiperOrigin-RevId: 638627696
2024-05-30 06:13:01 -07:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling yet another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Adam Paszke
32cb7c3f94 [Mosaic GPU] Stop using the MLIR CUDA runtime
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls to determine
maximum SMEM bounds that the MLIR runtime does at every kernel launch.

PiperOrigin-RevId: 629069842
2024-04-29 08:04:51 -07:00
Adam Paszke
9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00
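The packing step in 9b0319512a can be illustrated with a host-side sketch. It assumes each descriptor is an opaque 128-byte blob (the size of `CUtensorMap` in the CUDA driver API); `pack_descriptors` is a hypothetical helper, and the actual device copy is left out.

```python
from typing import List, Sequence, Tuple

TMA_DESC_BYTES = 128  # assumption: size of one CUtensorMap

def pack_descriptors(descs: Sequence[bytes]) -> Tuple[bytes, List[int]]:
    """Concatenates descriptors into one host buffer and returns the byte
    offset of each, so a single copy can transfer all of them at once."""
    buffer = bytearray()
    offsets = []
    for desc in descs:
        if len(desc) != TMA_DESC_BYTES:
            raise ValueError(f"expected {TMA_DESC_BYTES} bytes, got {len(desc)}")
        offsets.append(len(buffer))
        buffer += desc
    return bytes(buffer), offsets

# Three dummy descriptors filled with distinct byte values.
descs = [bytes([i]) * TMA_DESC_BYTES for i in range(3)]
packed, offsets = pack_descriptors(descs)
```

Because every descriptor is 128 bytes, each offset is a multiple of 64, so a suitably aligned base address keeps every packed descriptor 64-byte aligned.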
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
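The split described in 5a2d7a2df4 can be modeled in plain Python. This is a conceptual sketch, not the MLIR pass: `FakeDriver`, `init_fn` and `launch_fn` are hypothetical stand-ins for the GPU runtime and the two generated functions.

```python
class FakeDriver:
    """Stand-in for the GPU runtime that counts expensive operations."""
    def __init__(self) -> None:
        self.loads = 0
        self.launches = 0

    def load_module(self, binary: bytes) -> tuple:
        self.loads += 1  # expensive and forces a synchronization point
        return ("handle", binary)

    def launch(self, handle: tuple, args: tuple) -> None:
        self.launches += 1

def init_fn(driver: FakeDriver, binary: bytes) -> tuple:
    """First half of the split: preloads the kernel and returns a handle."""
    return driver.load_module(binary)

def launch_fn(driver: FakeDriver, handle: tuple, args: tuple) -> None:
    """Second half of the split: only consumes the preloaded handle."""
    driver.launch(handle, args)

driver = FakeDriver()
handle = init_fn(driver, b"kernel-binary")   # called once, at compile time
for step in range(3):
    launch_fn(driver, handle, (step,))       # called at run time
```

After three launches the driver has loaded the kernel once, whereas the stock lowering described in the commit would load and unload it on every launch.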
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00