Yash Katariya da1cc0a50e [sharding_in_types] out_sharding argument on einsum should only apply to the last einsum and not intermediate einsums.
For example, consider this einsum: `jnp.einsum('bthD, bthi, bthj->ijD', dy, i, j, out_sharding=P('data', None, None))`

This decomposes into two einsums, where the intermediate einsum output has rank `5`:
  * `'bthj,bthD->bthjD'`
  * `'bthjD,bthi->ijD'`

The specified out_sharding (`P('data', None, None)`) is not compatible with the intermediate einsum `'bthj,bthD->bthjD'`, since the `length of spec (3) != out_aval.ndim (5)`.
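
The rank mismatch described above can be reproduced without JAX. The sketch below is only an illustration: `einsum_output_rank` and `spec_fits` are hypothetical helpers (not JAX APIs), and a plain tuple stands in for `P('data', None, None)`. It models just the length check quoted in the error message.

```python
def einsum_output_rank(subscripts: str) -> int:
    """Rank of the output of an explicit-form einsum such as 'ab,bc->ac'."""
    _, out = subscripts.replace(' ', '').split('->')
    return len(out)


def spec_fits(spec: tuple, subscripts: str) -> bool:
    """True if the spec has exactly one entry per output dimension."""
    return len(spec) == einsum_output_rank(subscripts)


# Stand-in for P('data', None, None); an assumption for illustration only.
out_spec = ('data', None, None)

# The two pairwise contractions listed in this commit message.
steps = ['bthj,bthD->bthjD', 'bthjD,bthi->ijD']

for s in steps:
    print(s, 'rank =', einsum_output_rank(s), 'fits =', spec_fits(out_spec, s))
```

Only the second contraction, which produces the final `ijD` output, has a rank that matches the three-entry spec.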

This change applies out_sharding only to the contraction that produces the final output. **If there are conflicts in intermediate einsums, then the user has to reshard the input or split into multiple einsums (and maybe provide out_sharding) so that conflicts don't exist.**
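
As a rough sketch of what "split into multiple einsums" could look like for the example above, the snippet below writes the two contractions by hand and checks that they agree with the fused call. The shapes and random inputs are made up for illustration, and `out_sharding` is deliberately left out because it needs a mesh set up for explicit sharding; the comments only mark where a user might pass it.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only; not taken from the commit.
b, t, h, D, I, J = 2, 3, 4, 5, 6, 7
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
dy = jax.random.normal(k1, (b, t, h, D))
i = jax.random.normal(k2, (b, t, h, I))
j = jax.random.normal(k3, (b, t, h, J))

# Fused form: a single call that JAX decomposes internally.
fused = jnp.einsum('bthD,bthi,bthj->ijD', dy, i, j)

# Manual split: the intermediate is rank 5, so any out_sharding given
# here would need a 5-entry spec (assumption: chosen by the user).
tmp = jnp.einsum('bthj,bthD->bthjD', j, dy)

# Final contraction: this is where a 3-entry spec such as
# P('data', None, None) would be passed as out_sharding.
split = jnp.einsum('bthjD,bthi->ijD', tmp, i)

print(tmp.ndim, split.shape)
print(bool(jnp.allclose(fused, split, rtol=1e-4, atol=1e-4)))
```

Splitting by hand gives the user a place to attach a sharding to each step, at the cost of fixing the contraction order themselves.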

Note: We won't drop into auto mode for intermediate einsums. The user will have to split the einsum if any conflict is detected.
PiperOrigin-RevId: 732205849
2025-02-28 11:39:14 -08:00