Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 04:16:07 +00:00)

Previously the error was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
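In the new message, `8@x` denotes a dimension of size 8 sharded over mesh axis `x`. Multiplying `i64[8@x,2]` by `i64[2,8@x]` gives a result whose rows inherit `x` from the left operand and whose columns inherit `x` from the right operand, so one mesh axis would map to two dimensions, which is exactly what the old `PartitionSpec('x', 'x')` message complained about. The following is a minimal pure-Python sketch of that check, assuming a simplified rule for a 2-D matmul only. It is not JAX's implementation; the names `Spec`, `describe`, and `matmul_result` are hypothetical, and sharding of the contracting dimension is ignored.

```python
from typing import Optional, Sequence, Tuple

# Hypothetical: one mesh-axis name (or None for replicated) per array
# dimension, loosely mirroring a PartitionSpec such as P('x', None).
Spec = Tuple[Optional[str], ...]


def describe(dtype: str, shape: Sequence[int], spec: Spec) -> str:
    """Render a type like i64[8@x,2], the notation used in the new error."""
    dims = [f"{d}@{a}" if a is not None else str(d) for d, a in zip(shape, spec)]
    return f"{dtype}[{','.join(dims)}]"


def matmul_result(dtype, lhs_shape, lhs_spec, rhs_shape, rhs_spec):
    """Hypothetical sharding rule for a 2-D matmul [m, k] x [k, n] -> [m, n].

    The result keeps the lhs row sharding and the rhs column sharding, then
    rejects any result that maps one mesh axis to more than one dimension.
    Contracting-dimension sharding is ignored in this sketch.
    """
    (m, k1), (k2, n) = lhs_shape, rhs_shape
    if k1 != k2:
        raise TypeError(f"contracting dimensions differ: {k1} vs {k2}")
    out_shape = (m, n)
    out_spec = (lhs_spec[0], rhs_spec[1])
    used = [a for a in out_spec if a is not None]
    if len(used) != len(set(used)):  # same mesh axis on two dimensions
        raise TypeError(
            "dot_general operation with inputs: "
            f"{describe(dtype, lhs_shape, lhs_spec)}, "
            f"{describe(dtype, rhs_shape, rhs_spec)} "
            "produces an illegally sharded result: "
            f"{describe(dtype, out_shape, out_spec)}"
        )
    return out_shape, out_spec


if __name__ == "__main__":
    # Legal: only the rows of the result are sharded over x.
    print(matmul_result("i64", (8, 2), ("x", None), (2, 8), (None, None)))
    # Illegal: x would shard both result dimensions.
    try:
        matmul_result("i64", (8, 2), ("x", None), (2, 8), (None, "x"))
    except TypeError as e:
        print(e)
```

Reporting the operand and result types, rather than only the offending spec, points the user at the operation that introduced the conflict instead of at an internally constructed `PartitionSpec` they never wrote.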