also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.
PiperOrigin-RevId: 735957604
A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations.
PiperOrigin-RevId: 735875835
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:
```
Contributing CTA | 0 | 1 | 0 | 1 |
N local | 0:128 | 0:128 | 128:256 | 128:256 |
N | 0:128 | 256:384 | 128:256 | 384:512 |
```
You can see that the TMEM columns no longer monotonically go over all
columns until N=512, but they include a number of jumps!
We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).
Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.
PiperOrigin-RevId: 735791426
Imported from GitHub PR https://github.com/openxla/xla/pull/22800
Operand shape in long hlo text adds redundant information, which shouldn't be required. Changing the default value to off.
The large constants were also printed earlier by default print options, and it is required for parsability and reproducibility. Turning this on by default. This is still controlled by debug option and the default value of that flag disables the large constants, and that behavior is not changed. Just the default print options change here.
Copybara import of the project:
--
e30dea20489b3fb4d03d373fec0391d69486f4aa by Shraiysh Vaishay <svaishay@nvidia.com>:
Change the default value of print_operand_shape_ to false and print_large_constants_ to true.
Operand shape in long hlo text adds redundant information, which
shouldn't be required. Changing the default value to off.
The large constants were also printed earlier by default print options,
and it is required for parsability and reproducibility. Turning this on by default.
This is still controlled by debug option and the default value of that
flag disables the large constants, and that behavior is not changed. Just the
default print options change here.
--
7008af0dd0ce342ecbe9475f1d0e277319f1705a by Shraiysh Vaishay <svaishay@nvidia.com>:
Handle tests
--
b22d5f95cfb7e15f930a2198279a76c38593cc53 by Shraiysh Vaishay <svaishay@nvidia.com>:
Fix more tests
--
d51579cae7359c6426a87ad4a7ff1b4b0c80f74a by Shraiysh Vaishay <svaishay@nvidia.com>:
Fix more tests
Merging this change closes#22800
PiperOrigin-RevId: 735690598
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.
PiperOrigin-RevId: 735602187
I was testing the RC promotion workflow and found that the version test failed as it does not consider pre-releases. Therefore, this commit modifies the `VERSION_PATTERN` to also consider "rc" wheels.
Fixes https://github.com/jax-ml/jax/actions/runs/13705984545/job/38331236497
PiperOrigin-RevId: 735444828
Also improve dynamic_update_slice sharding error by printing `aval.str_short()` instead of full sharding because it's concise and gives more info than the current error (i.e. it adds shape too to the error message)
Also make some formatting changes in scan lowering to make it easier to debug.
PiperOrigin-RevId: 734542862
The difficulty here is that our register tiling is based on the (64, 8)
shape, while the memory tiling is now (8, swizzle // bytewidth). Before,
we would assume that each register tile fits neatly within a single
memory tile, but now it is obviously not the case. Luckily, it wasn't
too hard to add.
PiperOrigin-RevId: 734517000
Currently the serialized data and its length are being carried in two separate arrays, a fixed-with bytes array (with a hard-coded max size) and a unit32 array respectively.
PiperOrigin-RevId: 734299259