Mirror of https://github.com/ROCm/jax.git (synced 2025-04-24 05:56:05 +00:00)

This CL only supports lowering a module in which every sharding uses the exact same mesh, and loading it with either that same mesh or different meshes.

Note that we will be introducing some restrictions under Shardy for JAX export:

- You can only lower/save the module with meshes that all have the same shape, though they may use different axis names. (This PR currently requires the same axis names as well; this will be relaxed in a follow-up.)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, as with the restriction in the previous point, all shardings must use meshes of the same shape, though they may use different axis names (again, this will be relaxed in a follow-up).

We may remove the restriction of having to use the exact same mesh shape at export saving time, and the exact same mesh shape at export loading time, in the future. For now, we will keep these restrictions, since no one is using Shardy with JAX export yet.

PiperOrigin-RevId: 732878916
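A minimal, hypothetical sketch of the mesh-shape restriction described above, written in plain Python rather than against JAX's real export internals. It assumes each mesh is modeled as an ordered mapping from axis name to axis size (similar in form to what `jax.sharding.Mesh.shape` returns); `axis_sizes` and `check_export_meshes` are illustrative names, not JAX APIs. The same check would be applied separately to the meshes used at save time and to the meshes used at load time, since a loaded module may use a mesh that differs from the exported one.

```python
from typing import Mapping, Sequence, Tuple

# Assumption: a mesh is modeled as an ordered mapping of axis name -> axis
# size, e.g. {"x": 2, "y": 4}. This is an illustrative stand-in, not JAX code.
MeshShape = Mapping[str, int]


def axis_sizes(mesh: MeshShape) -> Tuple[int, ...]:
    """Return the mesh shape with the axis names dropped, e.g. (2, 4)."""
    return tuple(mesh.values())


def check_export_meshes(
    meshes: Sequence[MeshShape], allow_different_axis_names: bool = False
) -> None:
    """Hypothetical check mirroring the restriction in this CL.

    Every mesh must have the same shape (axis sizes, in order) as the first.
    Axis names must also match unless allow_different_axis_names is True,
    which models the relaxation planned for a follow-up.
    """
    if not meshes:
        return
    ref = meshes[0]
    for mesh in meshes[1:]:
        if axis_sizes(mesh) != axis_sizes(ref):
            raise ValueError(
                f"mesh shapes differ: {axis_sizes(mesh)} vs {axis_sizes(ref)}"
            )
        if not allow_different_axis_names and tuple(mesh) != tuple(ref):
            raise ValueError(f"axis names differ: {tuple(mesh)} vs {tuple(ref)}")


# Same shape and same axis names: accepted by the current, stricter check.
check_export_meshes([{"x": 2, "y": 4}, {"x": 2, "y": 4}])
# Same shape, different axis names: accepted only once names are relaxed.
check_export_meshes(
    [{"x": 2, "y": 4}, {"a": 2, "b": 4}], allow_different_axis_names=True
)
```

Comparing only the tuple of axis sizes is what lets a 2x4 mesh named `("x", "y")` be treated as compatible with a 2x4 mesh named `("a", "b")`, while a 4x2 mesh is still rejected regardless of its names.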