This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:
```
Contributing CTA |   0   |    1    |    0    |    1    |
N local          | 0:128 |  0:128  | 128:256 | 128:256 |
N                | 0:128 | 256:384 | 128:256 | 384:512 |
```
You can see that the TMEM columns no longer walk monotonically through N
from 0 to 512: the order includes a number of jumps!
We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).
Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.
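A minimal sketch of the column permutation implied by the table above, in plain
Python. All names and constants here (`N_TOTAL`, `tmem_to_global_n`, ...) are
hypothetical and only mirror the numbers in the table; this is not the actual
Mosaic GPU layout code.
```python
N_TOTAL = 512                     # full N extent of the tiled dimension
NUM_CTAS = 2                      # CTAs in the collective group
MMA_N = 256                       # collective N handled by one MMA
CHUNK = MMA_N // NUM_CTAS         # 128 columns contributed per CTA per MMA
N_PER_CTA = N_TOTAL // NUM_CTAS   # 256 columns of N owned by each CTA


def tmem_to_global_n():
    """For each TMEM column (in TMEM order), return the global N index it holds."""
    mapping = []
    for mma in range(N_TOTAL // MMA_N):   # each MMA fills a contiguous TMEM range
        for cta in range(NUM_CTAS):       # within one MMA, CTA 0's chunk comes first
            start = cta * N_PER_CTA + mma * CHUNK
            mapping.extend(range(start, start + CHUNK))
    return mapping


def inverse_permutation(mapping):
    """Return, for each global N index, the TMEM column that holds it."""
    inverse = [0] * len(mapping)
    for tmem_col, n in enumerate(mapping):
        inverse[n] = tmem_col
    return inverse


mapping = tmem_to_global_n()
chunk_starts = mapping[::CHUNK]   # global N at the start of each 128-column chunk
print(chunk_starts)
```
Reading TMEM through `inverse_permutation(mapping)` restores the natural N
order, which is the kind of rearrangement the register load has to perform.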
PiperOrigin-RevId: 735791426
This change enables testing the wheels produced by the build rules in the presubmit with a single `bazel test` command.
There are three options for running the tests:
1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.
PiperOrigin-RevId: 735765819
Imported from GitHub PR https://github.com/openxla/xla/pull/22800
Operand shapes in long HLO text add redundant information that shouldn't be required, so this changes the default value to off.
Large constants were also printed earlier by the default print options, and printing them is required for parsability and reproducibility, so this turns that on by default. Printing is still controlled by a debug option, and the default value of that flag disables large constants; that behavior is not changed. Only the default print options change here.
Copybara import of the project:
--
e30dea20489b3fb4d03d373fec0391d69486f4aa by Shraiysh Vaishay <svaishay@nvidia.com>:
Change the default value of print_operand_shape_ to false and print_large_constants_ to true.
Operand shape in long hlo text adds redundant information, which
shouldn't be required. Changing the default value to off.
The large constants were also printed earlier by default print options,
and it is required for parsability and reproducibility. Turning this on by default.
This is still controlled by debug option and the default value of that
flag disables the large constants, and that behavior is not changed. Just the
default print options change here.
--
7008af0dd0ce342ecbe9475f1d0e277319f1705a by Shraiysh Vaishay <svaishay@nvidia.com>:
Handle tests
--
b22d5f95cfb7e15f930a2198279a76c38593cc53 by Shraiysh Vaishay <svaishay@nvidia.com>:
Fix more tests
--
d51579cae7359c6426a87ad4a7ff1b4b0c80f74a by Shraiysh Vaishay <svaishay@nvidia.com>:
Fix more tests
Merging this change closes #22800
PiperOrigin-RevId: 735690598
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.
PiperOrigin-RevId: 735602187
I was testing the RC promotion workflow and found that the version test failed as it does not consider pre-releases. Therefore, this commit modifies the `VERSION_PATTERN` to also consider "rc" wheels.
Fixes https://github.com/jax-ml/jax/actions/runs/13705984545/job/38331236497
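As an illustration only, here is a sketch of a version check that also accepts
"rc" pre-releases. The pattern below is an assumption written for this example;
it is not the actual `VERSION_PATTERN` from the test, and `is_valid_version` is
a hypothetical helper.
```python
import re

# Hypothetical pattern: MAJOR.MINOR.PATCH with an optional "rcN" suffix.
VERSION_PATTERN = re.compile(r"\d+\.\d+\.\d+(rc\d+)?")


def is_valid_version(version: str) -> bool:
    """Return True if the whole string is a release or release-candidate version."""
    return VERSION_PATTERN.fullmatch(version) is not None
```
Using `fullmatch` rather than `match` rejects strings with trailing text, so a
suffix other than `rcN` does not slip through.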
PiperOrigin-RevId: 735444828
My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX.
It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future.
This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that.
PiperOrigin-RevId: 735381736
Surprisingly, the bug was tracked down to #26111 aka cl/730939406, specifically
the new implementation of reset_name_stack in source_info_util.py.
To repro, use the before-this-commit implementation of reset_name_stack (left
commented-out in the file), and run
```
JAX_USE_DIRECT_LINEARIZE=1 python tests/name_stack_test.py NameStackTransformationTest.test_nested_jit_stack
```