This change prepares for upcoming changes in which we run tests in parallel using threads, partly to test free threading and partly to speed up TPU tests via thread parallelism.
If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or setUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.
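For illustration only (not code from this change), work that used to happen in setUpClass can instead be done lazily at module level and picked up from setUp():

    import threading
    import unittest

    _resource = None
    _resource_lock = threading.Lock()

    def _get_resource():
      # Lazy, thread-safe global initialization replaces setUpModule/setUpClass.
      global _resource
      with _resource_lock:
        if _resource is None:
          _resource = object()  # stand-in for expensive shared setup work
        return _resource

    class ExampleTest(unittest.TestCase):

      def setUp(self):
        super().setUp()
        self.resource = _get_resource()  # same work, now scoped to setUp()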
PiperOrigin-RevId: 713296722
It would also be nice to optionally insert runtime assertions for dynamic
indices, but we don't have a way of doing that just yet.
PiperOrigin-RevId: 707532787
For now, the lowering only works for the strided fragmented layout. This is
mostly an exercise in plugging in lowering rules using `FragmentedArray`, and
will be expanded shortly.
PiperOrigin-RevId: 707031770
Layouts are added as annotations on MLIR ops, using the `in_layouts` and
`out_layouts` attributes.
At this point, layout inference is done in two passes: one "backwards" pass
(root-to-parameters), and one "forward" pass (parameters-to-root).
Each pass goes through all the ops in the specified order, and infers a
possible layout from the layout information that is available. We expect to
need two passes because partial layout annotations may be provided on
intermediate nodes (e.g. `wgmma`), and a single pass from the root to the
parameters is therefore insufficient to properly annotate all the operations.
We do not check whether the inferred layouts can actually be lowered correctly,
meaning that the produced IR may still fail to lower later.
Layouts are only inferred for ops involving at least one operand or result of
type `VectorType`/`RankedTensorType`.
When layouts can't be inferred for an op that should have them, we default to
annotating it with strided fragmented layouts.
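Roughly, the scheme looks like the following toy model (illustrative only; the real pass works on MLIR ops and records results in the `in_layouts`/`out_layouts` attributes):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class Op:
      name: str
      operands: list["Op"] = field(default_factory=list)
      layout: Optional[str] = None  # a known (possibly partial) annotation

    def infer(ops: list["Op"]) -> None:
      # Backwards pass (root-to-parameters): push known layouts onto operands.
      for op in reversed(ops):
        if op.layout is not None:
          for operand in op.operands:
            if operand.layout is None:
              operand.layout = op.layout
      # Forward pass (parameters-to-root): pull layouts from annotated operands.
      for op in ops:
        if op.layout is None:
          known = [o.layout for o in op.operands if o.layout is not None]
          op.layout = known[0] if known else None
      # Anything still unannotated defaults to a strided fragmented layout.
      for op in ops:
        if op.layout is None:
          op.layout = "strided_fragmented"

    a = Op("param")
    b = Op("wgmma", [a], layout="wgmma_fragmented")  # partial annotation
    c = Op("add", [b])
    infer([a, b, c])
    print([(op.name, op.layout) for op in (a, b, c)])
    # All three ops end up with the wgmma layout; disconnected ops would get
    # the strided default.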
PiperOrigin-RevId: 705092403
The barrier is a no-op at runtime, but appears to LLVM as a side-effecting op,
which prevents LLVM from moving computations that involve the supplied arrays
(even pure ones) past the barrier.
PiperOrigin-RevId: 702709125
Instead of only allowing a fixed set of layouts that we've hand verified as
bank-conflict free, we now simulate the transactions performed within each
warp and verify that no bank conflicts happen. If we detect that the simple
schedule does not work out, we attempt to partition the threads in a warp
into two groups and stagger the transfers in a way that lets us avoid conflicts.
This allows us to match the hand-designed transfer schedule I wrote for 32-bit
types, and even generalizes it to more cases automatically (e.g. swizzle=32).
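A simplified sketch of the check, assuming the usual configuration of 32 banks that are 4 bytes wide (the actual simulation works on the transfers generated for each layout/swizzle combination):

    import numpy as np

    def max_bank_conflicts(byte_offsets: np.ndarray) -> int:
      # Map each lane's byte offset to a bank and count collisions;
      # a result of 1 means the transaction is conflict-free.
      banks = (byte_offsets // 4) % 32
      _, counts = np.unique(banks, return_counts=True)
      return int(counts.max())

    def conflict_free_with_stagger(byte_offsets: np.ndarray) -> bool:
      # Fallback idea: split the warp into two groups of 16 lanes, issue their
      # transfers in two steps, and check each half separately.
      return all(
          max_bank_conflicts(half) == 1
          for half in (byte_offsets[:16], byte_offsets[16:])
      )

    # 32 lanes each accessing a distinct 4-byte word: conflict-free.
    print(max_bank_conflicts(np.arange(32) * 4))  # -> 1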
PiperOrigin-RevId: 701919158
This change removes the need to flatten the batch dimension into sequence dimensions
in the flash attention kernel. The critical thing here is the observation that we can
in fact collapse all squeezed dimensions into a single one in the TMA descriptor, letting
us reduce its rank when necessary.
Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU
lowering, which I've fixed.
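A toy illustration of the collapsing observation above (the real change operates on the TMA descriptor; the helper name and the divisibility assumption are mine):

    def collapse_squeezed(indices, strides, merged_stride):
      # Squeezed dims each contribute only an offset of index * stride, so a
      # group of them can be replaced by a single dimension: fold the combined
      # offset into one index, assuming it divides by the chosen stride.
      combined_offset = sum(i * s for i, s in zip(indices, strides))
      assert combined_offset % merged_stride == 0
      return combined_offset // merged_stride

    # Two squeezed dims with indices (3, 5) and strides (4096, 128) collapse
    # into a single dim with stride 128 and index 3 * 32 + 5 = 101.
    print(collapse_squeezed((3, 5), (4096, 128), 128))  # -> 101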
PiperOrigin-RevId: 701035277
This corresponds to what's implemented in `BarrierRef`, and ultimately makes it
easier to allocate barriers at a specific address in dynamic shared memory.
PiperOrigin-RevId: 695308297
So far, all of our layouts have been tailored to the limited set of use
cases we've encountered, but they're still not general enough to
handle all of the register layouts needed for WGMMA or mixed-precision
matmuls (incl. intermediate steps during conversions). Instead of adding
more special cases, I decided to adopt XLA tiled layouts, and they do seem
to work quite well!
This change only lays the groundwork for the new layout system. Future
changes will build on it to add new features and eventually replace
`WGMMA_LAYOUT` altogether.
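For intuition, here is a minimal numpy model of what a single level of XLA-style tiling does to a 2D array (the new layout classes generalize this to nested tiles and to the mapping onto registers):

    import numpy as np

    def tile(x: np.ndarray, t0: int, t1: int) -> np.ndarray:
      # A (t0, t1) tile turns shape (d0, d1) into (d0//t0, d1//t1, t0, t1):
      # outer tile coordinates followed by coordinates within the tile.
      # Assumes evenly divisible shapes for simplicity.
      d0, d1 = x.shape
      assert d0 % t0 == 0 and d1 % t1 == 0
      return x.reshape(d0 // t0, t0, d1 // t1, t1).transpose(0, 2, 1, 3)

    x = np.arange(64 * 64).reshape(64, 64)
    print(tile(x, 8, 32).shape)  # -> (8, 2, 8, 32)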
PiperOrigin-RevId: 694105514
Previously, we did not fully discharge the squeezing of indexed dims
before applying other GMEM transforms, which could lead to failures
because those transforms did not anticipate the increased rank.
PiperOrigin-RevId: 694098739
This requires that the file providing the bindings have the same name as the
dialect it defines, since the dialect search looks for a module path of the form
`<prefix>.<dialect namespace>`.
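For example (hypothetical names), a dialect with namespace "mosaic_gpu" registered under the search prefix "my_project.dialects" must be provided by a file named mosaic_gpu.py:

    # The dialect search imports `<prefix>.<dialect namespace>` verbatim, so
    # the module (file) name must match the dialect namespace.
    prefix = "my_project.dialects"
    dialect_namespace = "mosaic_gpu"
    print(f"{prefix}.{dialect_namespace}")  # -> my_project.dialects.mosaic_gpu
    # i.e. the bindings must live in my_project/dialects/mosaic_gpu.py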
PiperOrigin-RevId: 693241875
A reshape function that folds/unfolds by touching the minimal number of
dimensions, to potentially circumvent issues with strided memrefs.
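A rough sketch of the idea in plain Python (the real function operates on memrefs; the helper name and the prefix/suffix formulation are mine):

    def dims_to_touch(src_shape, dst_shape):
      # Keep the longest common prefix and suffix of the two shapes untouched
      # and only fold/unfold the dimensions in between, so that as few dims as
      # possible need to be contiguous in a strided memref.
      n = min(len(src_shape), len(dst_shape))
      prefix = 0
      while prefix < n and src_shape[prefix] == dst_shape[prefix]:
        prefix += 1
      suffix = 0
      while (suffix < n - prefix
             and src_shape[-1 - suffix] == dst_shape[-1 - suffix]):
        suffix += 1
      return (src_shape[prefix:len(src_shape) - suffix],
              dst_shape[prefix:len(dst_shape) - suffix])

    # Reshaping (4, 3, 2, 5) -> (4, 6, 5) only needs to fold (3, 2) into (6,).
    print(dims_to_touch((4, 3, 2, 5), (4, 6, 5)))  # -> ((3, 2), (6,))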
PiperOrigin-RevId: 683663541
Most users of disable_backends were actually using it to enable only a single backend, so things are simpler if we negate the sense of the option to say that. Change disable_backends to enable_backends, with a default `None` value meaning "everything is enabled".
We change the relationship between enable_backends, disable_configs, and enable_configs to be the following (see the sketch after the list):
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.
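A sketch of the resulting semantics (names are illustrative; the real logic lives in the test build rules):

    def selected_configs(configs_by_backend, enable_backends,
                         disable_configs, enable_configs):
      # 1) enable_backends picks the initial set, based on backend only.
      configs = {c for b in enable_backends for c in configs_by_backend[b]}
      # 2) disable_configs prunes individual configurations from that set.
      configs -= set(disable_configs or ())
      # 3) enable_configs adds extra configurations on top.
      configs |= set(enable_configs or ())
      return configs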
Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.
PiperOrigin-RevId: 679563155