mirror of
https://github.com/ROCm/jax.git
synced 2025-04-22 17:16:06 +00:00

If a `PartitionSpec` is passed, the mesh is read from the context. The primitives, however, accept only `NamedSharding`, so the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if the `PartitionSpec` contains mesh axis names of type Auto or Collective for the above functions.

PiperOrigin-RevId: 713352542
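
A minimal sketch of the conversion described above, not the code in this change. The helper names `spec_axis_names` and `canonicalize_sharding` and the `disallowed_axes` argument are hypothetical. In particular, `disallowed_axes` stands in for the real check on mesh axis types, and the mesh is passed explicitly here instead of being read from the context.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def spec_axis_names(spec):
    """Flatten a PartitionSpec into the mesh axis names it mentions."""
    names = []
    for entry in spec:
        if isinstance(entry, str):
            names.append(entry)
        elif isinstance(entry, (tuple, list)):
            names.extend(entry)
        # None (replicated dimension) and other markers name no axis.
    return names


def canonicalize_sharding(sharding, mesh, disallowed_axes=()):
    """Hypothetical helper: return a NamedSharding for the given input.

    `disallowed_axes` is an assumption standing in for "axes whose type
    is Auto or Collective"; the real check inspects the mesh itself.
    """
    if isinstance(sharding, NamedSharding):
        return sharding  # already in the form the primitive expects
    if not isinstance(sharding, P):
        raise TypeError(f"expected PartitionSpec or NamedSharding, got {type(sharding)}")
    bad = [name for name in spec_axis_names(sharding) if name in disallowed_axes]
    if bad:
        raise ValueError(f"PartitionSpec mentions disallowed mesh axes: {bad}")
    # This conversion is what would run before calling `.bind`.
    return NamedSharding(mesh, sharding)


# A one-device mesh with a single axis named "x".
mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
named = canonicalize_sharding(P("x"), mesh)
```

Doing the conversion in one place keeps the primitive's parameters uniform: everything below `.bind` sees a `NamedSharding` and never needs to consult the ambient mesh.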