1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-22 15:36:08 +00:00

7 Commits

Author SHA1 Message Date
Matthew Johnson
66a6eb299e add autodiff rules for jax.lax.ragged_all_to_all collective
also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.

PiperOrigin-RevId: 735957604
2025-03-11 18:22:02 -07:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Adam Paszke
26d8e112e3 Add a missing jaxlib version check in ragged_collective_test
PiperOrigin-RevId: 725186144
2025-02-10 06:12:36 -08:00
Gunhyun Park
c4e176328f Move ragged_all_to_all test under appropriate test file
PiperOrigin-RevId: 721947980
2025-01-31 16:44:04 -08:00
Gunhyun Park
a8df383ccf Fix lax.ragged_all_to_all degenerate case
In a singleton group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except update size is not statically known. This operation can't be expressed with standard HLO instructions -- the backend will handle this case separately.

Added small improvement to error messages.

PiperOrigin-RevId: 721473063
2025-01-30 12:05:02 -08:00
Dan Foreman-Mackey
19c17bb28b Skip ragged collective tests on CPU. 2025-01-30 10:03:53 -05:00
Gunhyun Park
809e1133c8 Add support for axis_name and axis_index_groups to lax.ragged_all_to_all
PiperOrigin-RevId: 720738861
2025-01-28 16:02:03 -08:00