mirror of
https://github.com/ROCm/jax.git
synced 2025-04-22 23:36:05 +00:00

We add `None`s when `ndim > len(sharding.spec)` and only remove `None`s when `ndim < len(sharding.spec)`. If any of the entries to be removed is a sharded axis, we raise an error instead of dropping it.

PiperOrigin-RevId: 748735303
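Here is a minimal Python sketch of the padding and trimming rule. It makes several assumptions. The spec is modelled as a plain tuple rather than `jax.sharding.PartitionSpec`. `None` marks an unsharded dimension and a string names a mesh axis. The surplus entries are the trailing ones past `ndim`. The function name `adjust_spec_to_ndim` is hypothetical and is not part of JAX.

```python
from typing import Optional, Tuple

# Hypothetical stand-in for a PartitionSpec: one entry per array dimension,
# where None means "not sharded" and a string names a mesh axis.
Spec = Tuple[Optional[str], ...]


def adjust_spec_to_ndim(spec: Spec, ndim: int) -> Spec:
    """Pad or trim `spec` so it has exactly `ndim` entries (illustrative sketch)."""
    if ndim >= len(spec):
        # Too few entries: append None (unsharded) for each missing dimension.
        return spec + (None,) * (ndim - len(spec))
    # Too many entries: assume the surplus ones are those past index `ndim`.
    # They may only be dropped if they are all None.
    surplus = spec[ndim:]
    sharded = [axis for axis in surplus if axis is not None]
    if sharded:
        raise ValueError(
            f"cannot drop sharded axes {sharded} to fit spec {spec} to ndim={ndim}")
    return spec[:ndim]
```

For example, `adjust_spec_to_ndim(("x",), 3)` returns `("x", None, None)` and `adjust_spec_to_ndim(("x", None, None), 1)` returns `("x",)`. `adjust_spec_to_ndim(("x", "y"), 1)` raises `ValueError` because dropping `"y"` would silently discard a sharded axis.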