
This version emits a StableHLO custom call. The test outputs the following MLIR module:

```
module @jit_ragged_all_to_all {
  func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
    %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
}
```

For now, the API assumes that the `split_axis` and `concat_axis` of `all_to_all` are the outermost (ragged) dim, and `axis_index_groups` defaults to all replicas (i.e. there is a single group covering all axis indices, an iota, as in the example above).

The current API is inspired by https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html, which also performs what is essentially a ragged all-to-all.

PiperOrigin-RevId: 704550890
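To make the MPI_Alltoallv analogy concrete, here is a minimal single-process sketch of MPI_Alltoallv-style semantics in plain Python. It is an illustration only: the function name `alltoallv_reference` and its argument layout are hypothetical, and it models the MPI convention (each rank supplies its own receive displacements), which is not necessarily how the six operands of the custom call above are interpreted. The shapes mirror the example module: 3 ranks, 6-element data buffers, 3-element count and offset vectors.

```python
def alltoallv_reference(send_bufs, send_counts, send_displs,
                        recv_bufs, recv_counts, recv_displs):
    """Simulate MPI_Alltoallv-style exchange across len(send_bufs) ranks.

    Rank `src` sends send_counts[src][dst] elements, starting at
    send_displs[src][dst], to rank `dst`, which stores them starting at
    recv_displs[dst][src]. All names here are hypothetical.
    """
    n = len(send_bufs)
    # Copy the receive buffers so the inputs are left untouched.
    outputs = [list(buf) for buf in recv_bufs]
    for src in range(n):
        for dst in range(n):
            count = send_counts[src][dst]
            if count != recv_counts[dst][src]:
                raise ValueError("send/recv counts disagree for %d -> %d" % (src, dst))
            s = send_displs[src][dst]
            r = recv_displs[dst][src]
            if s + count > len(send_bufs[src]) or r + count > len(outputs[dst]):
                raise ValueError("chunk %d -> %d is out of bounds" % (src, dst))
            # Copy one ragged chunk from the sender into the receiver.
            outputs[dst][r:r + count] = send_bufs[src][s:s + count]
    return outputs


# Three ranks; rank r holds the values 10*r + 0 .. 10*r + 5.
send_bufs = [[float(10 * r + k) for k in range(6)] for r in range(3)]
send_counts = [[1, 2, 3], [2, 2, 2], [3, 2, 1]]
send_displs = [[0, 1, 3], [0, 2, 4], [0, 3, 5]]
# Receive counts are the transpose of the send counts.
recv_counts = [[send_counts[i][j] for i in range(3)] for j in range(3)]
recv_displs = [[0, 1, 3], [0, 2, 4], [0, 3, 5]]
recv_bufs = [[0.0] * 6 for _ in range(3)]

result = alltoallv_reference(send_bufs, send_counts, send_displs,
                             recv_bufs, recv_counts, recv_displs)
```

In this example every rank sends a different number of elements to each peer, which is what distinguishes the ragged variant from a regular `all_to_all` with equally sized chunks.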