29 Commits

Author SHA1 Message Date
Benjamin Chetioui
d028354abb [Mosaic GPU] Introduce an initial transform inference pass.
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.

The pass isn't called anywhere yet.

PiperOrigin-RevId: 736758754
2025-03-13 23:22:59 -07:00
Adam Paszke
6f7ce9d048 Skip ASAN tests for the big Mosaic GPU tests
They are timing out.

PiperOrigin-RevId: 735804647
2025-03-11 10:30:04 -07:00
Adam Paszke
bb96226dd8 [Mosaic GPU] Add support for small RHS tile sizes in WGMMA
This is useful for more fine-grained autotuning and can help avoid
wave quantization effects.

PiperOrigin-RevId: 732105219
2025-02-28 05:41:30 -08:00
Adam Paszke
c171fc6061 Disable XLA autotuning to speed up tests
PiperOrigin-RevId: 725559980
2025-02-11 03:25:10 -08:00
Peter Hawkins
f122f17b27 Rename test configs to include GPU variants more consistently.
* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration.
* Rename "_2gpu" test variants to "x2" variants, since this is more succinct.

This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run.

PiperOrigin-RevId: 715468944
2025-01-14 11:55:45 -08:00
Benjamin Chetioui
4ef7706abb [Mosaic GPU] Split layout inference and dialect lowering files and tests.
PiperOrigin-RevId: 705100503
2024-12-11 07:31:34 -08:00
jax authors
0d7eaeb5d8 Merge pull request #24805 from andportnoy:aportnoy/mosaic-gpu-cupti-profiler
PiperOrigin-RevId: 705071782
2024-12-11 05:29:10 -08:00
Andrey Portnoy
cc22334c21 [Mosaic GPU] Add CUPTI profiler alongside events-based implementation
2024-12-09 14:31:20 -05:00
Adam Paszke
6a124ac554 [Mosaic GPU] Implement tiled and swizzled transfers for tiled layouts
PiperOrigin-RevId: 694449664
2024-11-08 04:38:20 -08:00
Benjamin Chetioui
c708a04c6e [Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect.
Also start moving the existing C++ tests to Python.

PiperOrigin-RevId: 691729887
2024-10-31 02:47:30 -07:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.
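
The three-step relationship above can be read as plain set operations. The sketch below is only an illustration in Python, not the Bazel macro itself: the function `select_configs`, the `ALL_CONFIGS` mapping, and the config names in it are hypothetical.

```python
# Hypothetical mapping from test configuration name to its backend.
ALL_CONFIGS = {
    "cpu": "cpu",
    "gpu_v100": "gpu",
    "gpu_h100": "gpu",
    "tpu_v4": "tpu",
}

def select_configs(all_configs, enable_backends, disable_configs=(), enable_configs=()):
    """Model of the selection order described in the commit message (assumed semantics)."""
    # Step 1: enable_backends picks the initial set, based on backend only.
    selected = {name for name, backend in all_configs.items()
                if backend in enable_backends}
    # Step 2: disable_configs prunes that set.
    selected -= set(disable_configs)
    # Step 3: enable_configs adds further configurations back in.
    selected |= set(enable_configs)
    return selected

print(sorted(select_configs(ALL_CONFIGS, ["gpu"],
                            disable_configs=["gpu_v100"],
                            enable_configs=["cpu"])))
```

Because `enable_configs` is applied last, a configuration named there ends up selected even if its backend was not listed in `enable_backends`.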

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Sergei Lebedev
8159d3352c Updated :gpu_test configuration
PiperOrigin-RevId: 674242448
2024-09-13 04:24:09 -07:00
Sergei Lebedev
ea68f4569c Internal change
PiperOrigin-RevId: 673409076
2024-09-11 08:47:58 -07:00
Adam Paszke
4c3111bf26 [Mosaic GPU] Unbreak tests
I mistakenly checked for `amount + 1` instead of `amount * 2`. It initially
seemed right because both expressions evaluate to 2 for 1 :)
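
A quick illustration of why the mistake was easy to miss. The helper `expressions_agree` is hypothetical and only compares the two expressions named in the message; it is not the code that was fixed.

```python
def expressions_agree(amount):
    # True when the mistaken check (amount + 1) and the intended one
    # (amount * 2) give the same value.
    return amount + 1 == amount * 2

for amount in (1, 2, 3, 4):
    print(amount, amount + 1, amount * 2, expressions_agree(amount))
```

The two expressions coincide only at `amount == 1`, so a test that exercises just that value cannot tell them apart.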

PiperOrigin-RevId: 670527107
2024-09-03 06:07:54 -07:00
Peter Hawkins
cd20404159 Disable mosaic gpu tests that are failing at head.
PiperOrigin-RevId: 669390680
2024-08-30 11:31:09 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Adam Paszke
9c3f2dcefc [Mosaic GPU] Make CUDA context part of the hash key + replace kernel id with a SHA256 digest
XLA runtime creates a context per device, so we need to make sure that a kernel is loaded
separately on each device.
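
A minimal Python sketch of the idea, assuming a cache keyed on the pair (context, SHA256 digest of the kernel). The actual implementation is not shown in this log; `kernel_key`, `get_or_load`, the `load` callback, and the integer context ids are all hypothetical names used for illustration.

```python
import hashlib

_kernel_cache = {}

def kernel_key(context_id, kernel_bytes):
    # The digest identifies the kernel contents; the context id keeps
    # entries for different devices apart.
    return (context_id, hashlib.sha256(kernel_bytes).hexdigest())

def get_or_load(context_id, kernel_bytes, load):
    """Return the cached handle, calling load() once per (context, kernel) pair."""
    key = kernel_key(context_id, kernel_bytes)
    if key not in _kernel_cache:
        _kernel_cache[key] = load(context_id, kernel_bytes)
    return _kernel_cache[key]

# Usage: the same kernel is loaded once per context, not once globally.
loads = []
def fake_load(ctx, code):
    loads.append(ctx)
    return f"handle-{ctx}"

get_or_load(0, b"kernel", fake_load)
get_or_load(0, b"kernel", fake_load)  # cache hit, no second load
get_or_load(1, b"kernel", fake_load)  # different context, loaded again
print(loads)
```

Keying on the digest alone would hand the second device a handle that was loaded in the first device's context, which is the situation the commit message describes avoiding.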

PiperOrigin-RevId: 666353098
2024-08-22 08:06:37 -07:00
Adam Paszke
ca6be2573b [Mosaic GPU] Move matmul tests to Hypothesis
We've been generating thousands of test cases and that's just not
scalable. Hypothesis should let us efficiently explore a large
number of configurations.

PiperOrigin-RevId: 662447113
2024-08-13 03:21:51 -07:00
Benjamin Chetioui
25a47649d2 [Mosaic GPU] Change FlashAttention implementation to support Grouped Query Attention.
Also add tests in `flash_attention_test.py`.

PiperOrigin-RevId: 642626612
2024-06-12 08:46:06 -07:00
Adam Paszke
a7e35c6b9a [Mosaic GPU] Move the matmul example runner away from the test harness
This just makes more sense. It really shouldn't be a jax_test because it doesn't
even import test_util.

PiperOrigin-RevId: 639872888
2024-06-03 12:23:31 -07:00
Benjamin Chetioui
5aec259dc7 [Mosaic GPU] Implement basic support for aliasing shared memory.
When the user constructs the relevant shapes that live in `smem` throughout the
program, they now have the possibility of using a `mosaic_gpu.Union` of PyTrees
instead of a single `PyTree`.

`mosaic_gpu.Union` allows declaring several sets of buffers where within the
set, the buffers are alive concurrently, but between two distinct sets, the
buffers are alive for non-intersecting time intervals.
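
One way to picture the consequence of this lifetime rule: since distinct sets are never alive at the same time, they can share the same shared-memory region, so the region only needs to fit the largest set. The sketch below is a plain-Python model under that assumption and does not use the `mosaic_gpu.Union` API; `union_smem_bytes` and the byte sizes are hypothetical.

```python
def union_smem_bytes(buffer_sets):
    """Bytes needed when each set's buffers coexist but the sets alias each other.

    buffer_sets: list of lists of buffer sizes in bytes (hypothetical input).
    """
    # Within a set, buffers are alive concurrently, so their sizes add up.
    # Across sets, lifetimes do not intersect, so only the largest set matters.
    return max((sum(sizes) for sizes in buffer_sets), default=0)

sets = [[1024, 2048], [4096]]
print(union_smem_bytes(sets))           # with aliasing
print(sum(sum(s) for s in sets))        # without aliasing, for comparison
```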

PiperOrigin-RevId: 636533045
2024-05-23 06:40:34 -07:00
Adam Paszke
53ec2cd26f Add notap tag to Mosaic tests
PiperOrigin-RevId: 635379982
2024-05-20 01:35:56 -07:00
Adam Paszke
a527b71970 [Mosaic GPU] Prepare for writing warp-specialized kernels
PiperOrigin-RevId: 632854287
2024-05-11 17:09:08 -07:00
Adam Paszke
8fd9c2f160 [Mosaic GPU] Add the flash attention example
PiperOrigin-RevId: 629092401
2024-04-29 09:31:30 -07:00
Adam Paszke
4c3d4323dd [Mosaic GPU] Disable matmul tests in internal CI
PiperOrigin-RevId: 628379779
2024-04-26 05:50:35 -07:00
Adam Paszke
c6ca1ef204 [Mosaic GPU] Add the first example: pipelined matmul
PiperOrigin-RevId: 628156068
2024-04-25 12:27:25 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00