This is because `convert_element_type` returns an output on all devices of the mesh due to the surrounding `use_mesh` context.
PiperOrigin-RevId: 735909962
A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations.
PiperOrigin-RevId: 735875835
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:
```
Contributing CTA | 0 | 1 | 0 | 1 |
N local | 0:128 | 0:128 | 128:256 | 128:256 |
N | 0:128 | 256:384 | 128:256 | 384:512 |
```
You can see that the TMEM columns no longer run monotonically through all
N columns up to N=512. Instead, they include a number of jumps!
We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).
Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.
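
As a rough illustration of the layout in the table above, the following
Python sketch computes which slice of N each 128-column TMEM chunk holds and
the order in which chunks would have to be read to recover monotonic N. The
function names and parameters (`tmem_chunk_layout`, `tmem_to_logical_order`,
`n_per_cta`, `chunk`) are hypothetical and are not taken from the Mosaic GPU
code; this only models the index arithmetic, not the actual TMEM loads.
```python
def tmem_chunk_layout(num_ctas=2, n_per_cta=256, chunk=128):
    """Return (cta, local_start, global_start) per TMEM chunk, in TMEM order.

    Assumes each collective MMA writes one `chunk`-wide slice per CTA, and
    that the slices of one MMA land next to each other in TMEM.
    """
    mmas = n_per_cta // chunk  # number of collective MMAs issued
    layout = []
    for mma in range(mmas):
        for cta in range(num_ctas):
            local_start = mma * chunk
            global_start = cta * n_per_cta + local_start
            layout.append((cta, local_start, global_start))
    return layout


def tmem_to_logical_order(layout):
    """Indices of TMEM chunks sorted by the global N offset they hold."""
    return sorted(range(len(layout)), key=lambda i: layout[i][2])


layout = tmem_chunk_layout()
order = tmem_to_logical_order(layout)
```
With the defaults, `layout` reproduces the three rows of the table and
`order` is `[0, 2, 1, 3]`, i.e. the two middle chunks have to be swapped
when the tiles are loaded into registers.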
PiperOrigin-RevId: 735791426
This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.
There are three options for running the tests:
1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.
PiperOrigin-RevId: 735765819
Imported from GitHub PR https://github.com/openxla/xla/pull/22800
Operand shapes in long HLO text add redundant information that shouldn't be required, so the default value is changed to off.
Large constants were previously printed by the default print options, and printing them is required for parsability and reproducibility, so this is turned on by default. Printing is still controlled by a debug option whose default value disables large constants, and that behavior is not changed. Only the default print options change here.
Copybara import of the project:
--
e30dea20489b3fb4d03d373fec0391d69486f4aa by Shraiysh Vaishay <svaishay@nvidia.com>:
Change the default value of print_operand_shape_ to false and print_large_constants_ to true.
Operand shape in long hlo text adds redundant information, which
shouldn't be required. Changing the default value to off.
The large constants were also printed earlier by default print options,
and it is required for parsability and reproducibility. Turning this on by default.
This is still controlled by debug option and the default value of that
flag disables the large constants, and that behavior is not changed. Just the
default print options change here.
--
7008af0dd0ce342ecbe9475f1d0e277319f1705a by Shraiysh Vaishay <svaishay@nvidia.com>:
Handle tests
--
b22d5f95cfb7e15f930a2198279a76c38593cc53 by Shraiysh Vaishay <svaishay@nvidia.com>:
Fix more tests
--
d51579cae7359c6426a87ad4a7ff1b4b0c80f74a by Shraiysh Vaishay <svaishay@nvidia.com>:
Fix more tests
Merging this change closes #22800
PiperOrigin-RevId: 735690598
This would also make it easier to deprecate the `with mesh: pjit` path from user code in the future, since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.
PiperOrigin-RevId: 735602187
I was testing the RC promotion workflow and found that the version test failed as it does not consider pre-releases. Therefore, this commit modifies the `VERSION_PATTERN` to also consider "rc" wheels.
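
For illustration only, a version check that accepts release candidates might
look like the sketch below. The pattern and the names `VERSION_PATTERN_SKETCH`
and `is_valid_version` are assumptions made for this example and are not the
`VERSION_PATTERN` used in the JAX test. The real pattern may also need to
accept other suffixes, such as dev versions.
```python
import re

# Hypothetical pattern: MAJOR.MINOR.PATCH with an optional "rcN" suffix.
VERSION_PATTERN_SKETCH = re.compile(r"\d+\.\d+\.\d+(rc\d+)?")


def is_valid_version(version: str) -> bool:
    """Return True for strings such as "0.5.3" or "0.5.3rc1"."""
    return VERSION_PATTERN_SKETCH.fullmatch(version) is not None
```
Using `fullmatch` rather than `match` keeps trailing text such as `0.5.3beta1`
from being accepted by accident.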
Fixes https://github.com/jax-ml/jax/actions/runs/13705984545/job/38331236497
PiperOrigin-RevId: 735444828