4315 Commits

Author SHA1 Message Date
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
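For 194884d311 above, a minimal runnable sketch of the pattern in question. The function bodies, the `pick` helper and the explicit `Callable` annotation are illustrative additions, not part of the commit. Whether the annotation changes what mypy 1.15 reports has not been checked here.

```python
from typing import Callable


def f(x: int) -> str:
    return f"f({x})"


def g(x: int) -> str:
    return f"g({x})"


def pick(use_f: bool) -> Callable[[int], str]:
    # Annotating the target states the intended type explicitly instead of
    # relying on whatever the checker infers for the ternary expression.
    callback: Callable[[int], str] = f if use_f else g
    return callback
```

At runtime both branches behave the same way; the annotation only matters to the type checker.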
Benjamin Chetioui
837418c652 [Mosaic GPU] Remove old jaxlib version guards.
PiperOrigin-RevId: 726071956
2025-02-12 08:49:40 -08:00
Benjamin Chetioui
c7199fe8a5 [Pallas/Mosaic GPU] Enable progressive lowering for integer addition.
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified, such that a fragmented array's
signedness no longer appears in its IR representation.

This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.

PiperOrigin-RevId: 726030853
2025-02-12 06:29:25 -08:00
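For c7199fe8a5 above, a small sketch of the idea that signedness describes how bits are used rather than the bits themselves. `SignlessValue`, `to_ir` and `from_ir` are hypothetical stand-ins written for this illustration; they are not the helpers in `dialect_lowering.py`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SignlessValue:
    """Hypothetical stand-in for an IR value: only bits and a bit width."""
    bits: int
    width: int


def to_ir(value: int, width: int) -> SignlessValue:
    # Two's-complement wrap: the sign is not recorded in the result.
    return SignlessValue(value & ((1 << width) - 1), width)


def from_ir(v: SignlessValue, is_signed: bool) -> int:
    # The caller supplies the signedness when reloading the value.
    if is_signed and v.bits >= 1 << (v.width - 1):
        return v.bits - (1 << v.width)
    return v.bits
```

In this sketch `to_ir(-1, 8)` and `to_ir(255, 8)` produce the same value, and only the `is_signed` argument to `from_ir` decides which integer comes back.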
jax authors
1e2a5770c9 Merge pull request #26455 from gnecula:debug_info_jaxpr_8
PiperOrigin-RevId: 726023315
2025-02-12 06:03:32 -08:00
George Necula
faa0ad6f33 [better_errors] Continue adding debug info to Jaxprs (step 8)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
2025-02-12 14:23:52 +01:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions
2025-02-11 16:00:03 -08:00
Adam Paszke
e987ce2b77 [Mosaic GPU] Add autotuning to the Blackwell matmul kernel
PiperOrigin-RevId: 725726073
2025-02-11 12:17:46 -08:00
Dimitar (Mitko) Asenov
6fc1c61520 [Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).
Tile and Transpose transforms to follow.

PiperOrigin-RevId: 725716812
2025-02-11 11:51:25 -08:00
Adam Paszke
c2bd1576da [Mosaic GPU] Add support for tiling the M grid dimension in Blackwell matmul
This lets us make better use of L2.

PiperOrigin-RevId: 725715193
2025-02-11 11:47:08 -08:00
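For c2bd1576da above, a generic sketch of what tiling one grid dimension can look like: output blocks are visited in groups of `m_tile` rows so that neighbouring blocks share operands while those are likely still cached. The function name and the visiting order are assumptions made for illustration; the commit does not spell out the kernel's actual schedule.

```python
from typing import List, Tuple


def tiled_block_order(m_blocks: int, n_blocks: int, m_tile: int) -> List[Tuple[int, int]]:
    """Return (m, n) block indices, grouped into tiles along M (illustrative only)."""
    order = []
    for m0 in range(0, m_blocks, m_tile):
        for n in range(n_blocks):
            # Within a tile, consecutive blocks reuse the same column block n.
            for m in range(m0, min(m0 + m_tile, m_blocks)):
                order.append((m, n))
    return order
```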
Adam Paszke
70007471c7 [Mosaic GPU] Use union to avoid excessive SMEM usage in the Blackwell matmul
PiperOrigin-RevId: 725667871
2025-02-11 09:48:23 -08:00
Adam Paszke
74e86bab26 [Mosaic GPU] Add support for collective MMA in the Blackwell matmul example
PiperOrigin-RevId: 725630722
2025-02-11 07:57:53 -08:00
Adam Paszke
21598d02e5 [Mosaic GPU] Add support for non-multicast .cta_group::2 async_copies
This instruction is particularly useful for collective MMA, since it lets us
easily report on the progress of async copies from both blocks in the single
block that will be performing the MMA.

PiperOrigin-RevId: 725618793
2025-02-11 07:13:35 -08:00
Adam Paszke
849ea268a1 [Mosaic GPU] Add support for 2-CTA MMA on Blackwell
PiperOrigin-RevId: 725600651
2025-02-11 06:06:47 -08:00
Adam Paszke
0209eee185 [Mosaic GPU] Handle TMEM allocation in the compiler
The code for allocation is uninteresting and it's the only set of primitives
that is executed by a single warp (other TMA APIs have single-thread or
warpgroup issue granularity).

PiperOrigin-RevId: 725583720
2025-02-11 05:01:25 -08:00
George Necula
550d1aa187 [better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
2025-02-11 11:28:58 +01:00
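For 550d1aa187 above, a plain-Python sketch of why a wrapper object can carry debug info where a bare callable cannot. `DebugInfoSketch`, `WrappedFunSketch` and `describe` are hypothetical names for this illustration; they are not JAX's `lu.WrappedFun` or its debug-info type.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class DebugInfoSketch:
    traced_for: str  # which transformation the function is traced for
    func_name: str   # user-facing name for error messages


@dataclass(frozen=True)
class WrappedFunSketch:
    f: Callable[..., Any]
    debug_info: DebugInfoSketch

    def call_wrapped(self, *args: Any) -> Any:
        return self.f(*args)


def describe(fun: Any) -> str:
    # A bare callable has nothing to report; the wrapper does.
    info = getattr(fun, "debug_info", None)
    return f"{info.traced_for}:{info.func_name}" if info else "<no debug info>"


square_jvp = WrappedFunSketch(
    lambda x, t: (x * x, 2 * x * t),
    DebugInfoSketch("custom_jvp", "square_jvp"),
)
```

Passing the wrapper instead of a thunk means downstream code can still reach the metadata when it needs to build an error message.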
Peter Hawkins
1e447c8ad2 Fix Python version test against Python 3.12.
PiperOrigin-RevId: 725413928
2025-02-10 17:54:12 -08:00
jax authors
b7d012281e Merge pull request #26423 from gnecula:debug_info_jaxpr_7
PiperOrigin-RevId: 725317552
2025-02-10 12:58:26 -08:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
Adam Paszke
3b5e91b8a8 [Mosaic GPU] Add tests for various tcgen05.mma configurations
It would be good to add smaller tests that verify reads and writes to TMEM,
since we depend on it here, but that will come later.

PiperOrigin-RevId: 724328602
2025-02-07 06:50:11 -08:00
Adam Paszke
6524e67fd4 [Mosaic GPU] Add more SMEM buffers to avoid blocking for memory traffic
PiperOrigin-RevId: 723969131
2025-02-06 09:40:44 -08:00
Adam Paszke
a61d8002e5 [Mosaic GPU] Relax TMEM stride constraints on dimensions of size 1
Strides along those dimensions don't affect anything.

PiperOrigin-RevId: 723968994
2025-02-06 09:38:54 -08:00
Adam Paszke
026b6c9704 [Mosaic GPU] Take TMEM as a TMEMRef in tcgen05.mma, not as a raw address
PiperOrigin-RevId: 723936021
2025-02-06 07:59:58 -08:00
jax authors
5d647ccfa1 Merge pull request #26348 from gnecula:debug_info_jaxpr_3
PiperOrigin-RevId: 723920031
2025-02-06 06:59:18 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
George Necula
904b74860c [better_errors] Continue adding debug info to Jaxprs (step 3)
This follows #26078 and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.

As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases it was really `lu.WrappedFun.call_wrapped`.
2025-02-06 16:26:49 +02:00
jax authors
c46b0215b0 Merge pull request #26313 from gnecula:debug_info_vjp
PiperOrigin-RevId: 723575296
2025-02-05 10:58:10 -08:00
Adam Paszke
6d6c8c2e6c [Mosaic GPU] Write a higher-level tcgen05.mma helper reusing WGMMA implementation
Hopper and Blackwell MMA instructions can share a lot of the same logic, which is
why I ended up splitting out a large fraction of WGMMA implementation into a common
utility. This should be an NFC for WGMMA, but it allows us to concisely implement
unrolling of MMAs of different sizes into a number of tcgen05.mma instructions.

PiperOrigin-RevId: 723544349
2025-02-05 09:38:29 -08:00
George Necula
abcaec7081 [better_errors] Add debug info to the Jaxprs formed for AD
Following #26078, we add debug info to more calls of lu.wrap_init.
2025-02-05 19:21:02 +02:00
Adam Paszke
f4dab0cf72 [Mosaic GPU] Add helpers for dealing with TMEM references + implement optimized loads
The previous example implementation loaded TMEM in a layout that was very hard to
efficiently store into SMEM or GMEM. With the new TMEMRef abstraction, we can implement
loads that yield a FragmentedArray with a new tiled layout that allows for efficient
swizzled stores to SMEM.

The new layout is very similar to the one we've been using for WGMMA on Hopper, only the
initial row tiling is increased to 128 (making each warp hold 32 rows, not 16 as previously).

PiperOrigin-RevId: 723506876
2025-02-05 07:41:40 -08:00
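For f4dab0cf72 above, the row counts follow from dividing the row tile among the four warps of a warpgroup. The helper below is a hypothetical illustration; the 64-row figure for the earlier layout is inferred from the "16 as previously" remark rather than stated in the commit.

```python
WARPS_PER_WARPGROUP = 4  # a warpgroup consists of 4 warps (128 threads)


def rows_per_warp(row_tile: int) -> int:
    """Rows each warp holds if the row tile is split evenly across the warpgroup."""
    if row_tile % WARPS_PER_WARPGROUP:
        raise ValueError("row tile must be divisible by the number of warps")
    return row_tile // WARPS_PER_WARPGROUP
```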
Adam Paszke
b79ab01ee7 [Mosaic GPU] Refactor the Blackwell matmul example and make it runnable
The previous implementation depends on LLVM intrinsics that have not been submitted
yet. This replaces them with inline PTX (as far as I can tell there's no downside to
that) that's encapsulated into convenience functions.

PiperOrigin-RevId: 723498248
2025-02-05 07:11:03 -08:00
Adam Paszke
1fbc4a15dd [Mosaic GPU] Infer whether A/B are row- or column-major from strides
There's no need to require extra arguments. This makes our calling convention
saner since the logical dimension order stays the same (e.g. for B it's always
k before n in the shape), only the in-memory representation changes.

Other than the API change, this is an NFC.

PiperOrigin-RevId: 723449720
2025-02-05 04:01:04 -08:00
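For 1fbc4a15dd above, a sketch of inferring the memory order of a 2D operand from its strides while the logical shape stays fixed. `infer_major` is a hypothetical helper and strides are counted in elements; the checks in Mosaic GPU itself may differ.

```python
from typing import Tuple


def infer_major(shape: Tuple[int, int], strides: Tuple[int, int]) -> str:
    """Classify a dense 2D layout as row- or column-major from element strides."""
    rows, cols = shape
    if strides == (cols, 1):
        return "row"      # last dimension is contiguous
    if strides == (1, rows):
        return "column"   # first dimension is contiguous
    raise ValueError(f"unsupported strides {strides} for shape {shape}")
```

For an operand B with logical shape `(k, n)`, strides `(n, 1)` and `(1, k)` describe the same logical array stored in two different orders, so no extra argument is needed to tell them apart.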
jax authors
d6be2351d4 Merge pull request #26159 from andportnoy:aportnoy/mosaic-gpu-blackwell-simple-matmul
PiperOrigin-RevId: 723428689
2025-02-05 02:33:57 -08:00
Yash Katariya
307006e194 Set the mesh as manual during partial_eval_custom in shard_map so that _add_reshapes happens under the correct mesh.
PiperOrigin-RevId: 723268798
2025-02-04 16:36:08 -08:00
Sharad Vikram
02f4531310 [Pallas TPU] Add helpers for writing collectives
PiperOrigin-RevId: 723250661
2025-02-04 15:39:10 -08:00
Andrey Portnoy
aff2cba898 [Mosaic GPU] Add simple Blackwell matmul example
2025-02-04 11:29:59 -05:00
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
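For bc1a706688 above, the listed conditions can be read as a small decision function. This is a sketch under simplifying assumptions: meshes are reduced to an axis-type string or `None` for "unset", and `needs_mesh_cast` with its parameters is a hypothetical name, not the code in the commit.

```python
from typing import Optional


def needs_mesh_cast(
    primitive: str,
    current_mesh: Optional[str],
    aval_mesh: Optional[str],
    final_style: bool = False,
) -> bool:
    """Decide whether a mesh_cast would be inserted before bind (illustrative)."""
    # Final-style primitives handle this in their own bind; mesh_cast itself
    # is skipped to avoid recursing.
    if final_style or primitive == "mesh_cast":
        return False
    if current_mesh == "Manual" and aval_mesh == "Auto":
        return True
    if current_mesh is not None and aval_mesh is None:
        return True
    return False
```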
Jacques Pienaar
60d3836fdf Propagate source ranges in location.
Previously only the line info was propagated. Given the new source range location support, propagate source range.

PiperOrigin-RevId: 722860932
2025-02-03 17:32:59 -08:00
jax authors
40d35b4219 Merge pull request #26277 from justinjfu:sourcemap_windows_fix
PiperOrigin-RevId: 722736034
2025-02-03 11:37:27 -08:00
Justin Fu
6d7b03572c Format sourcemap directory names to work on windows
2025-02-03 09:39:56 -08:00
Christos Perivolaropoulos
bf9671731c [mgpu] Correct instruction for conversion of unsigned int types.
PiperOrigin-RevId: 721793849
2025-01-31 09:06:40 -08:00
Adam Paszke
cadfcc7a1b [Mosaic GPU] Allow uneven partitioning of dimensions into tiles in TileTransform
PiperOrigin-RevId: 721705218
2025-01-31 03:05:44 -08:00
Adam Paszke
10ac6b7e12 [Mosaic GPU] Add support for tiled swizzle=16 (i.e. no swizzle) loads and stores
The tiling still makes it possible to do it without bank conflicts.

PiperOrigin-RevId: 721701635
2025-01-31 02:49:59 -08:00
Peter Hawkins
0705ec2ca4 Pass filter=data to tar extractall to avoid a warning under Python 3.12+
PiperOrigin-RevId: 721571944
2025-01-30 17:27:24 -08:00
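For 0705ec2ca4 above, a self-contained sketch of passing `filter="data"` to `TarFile.extractall`. The helper names and the demo archive are made up for illustration; the `hasattr` guard is an added precaution for interpreters that predate extraction filters and is not taken from the commit.

```python
import io
import os
import tarfile
import tempfile


def make_demo_archive(directory: str) -> str:
    """Write a one-file tar archive and return its path."""
    path = os.path.join(directory, "demo.tar")
    payload = b"hello\n"
    info = tarfile.TarInfo(name="hello.txt")
    info.size = len(payload)
    with tarfile.open(path, "w") as tar:
        tar.addfile(info, io.BytesIO(payload))
    return path


def extract_all(archive: str, dest: str) -> None:
    with tarfile.open(archive) as tar:
        if hasattr(tarfile, "data_filter"):
            # Naming the filter explicitly avoids relying on the default.
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)


def demo() -> bytes:
    with tempfile.TemporaryDirectory() as tmp:
        archive = make_demo_archive(tmp)
        out = os.path.join(tmp, "out")
        os.mkdir(out)
        extract_all(archive, out)
        with open(os.path.join(out, "hello.txt"), "rb") as fh:
            return fh.read()
```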
Yash Katariya
9107ee4a22 Do automatic casting from auto -> manual when the context mesh is manual and avals are in auto mode. This happens when values are being closed over in a shard_map. The casting is happening at lax level but we can move this to a different place later on.
PiperOrigin-RevId: 721495804
2025-01-30 13:14:04 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00
Benjamin Chetioui
46512e684b [Mosaic GPU][NFC] Fix wrong type annotations, and do some NFC cleanups.
PiperOrigin-RevId: 721350296
2025-01-30 05:13:58 -08:00
Yash Katariya
d223dfc3f7 Allow multiple meshes for avals but in that case, just use empty_abstract_mesh instead of enabling computation follows data only for **Auto mode**.
PiperOrigin-RevId: 721224349
2025-01-29 20:47:34 -08:00
Justin Fu
b01111d96c Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental.
PiperOrigin-RevId: 721119935
2025-01-29 15:01:43 -08:00