mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 04:26:07 +00:00

For example: Consider this einsum:

`jnp.einsum('bthD, bthi, bthj->ijD', dy, i, j, out_sharding=P('data', None, None))`

This will decompose into 2 einsums where the intermediate einsum output will be of rank `5`:

* `'bthj,bthD->bthjD'`
* `'bthjD,bthi->ijD'`

The out_sharding specified (`P('data', None, None)`) is not compatible with the intermediate einsum: `'bthj,bthD->bthjD'` since the `length of spec (3) != out_aval.ndim (5)`.

This change makes it so that out_sharding is only applied to the contraction that leads to the final output.

**If there are conflicts in intermediate einsums, then the user has to reshard the input or split into multiple einsums (and maybe provide out_sharding) so that conflicts don't exist.**

Note: We won't drop into auto mode for intermediate einsums. The user will have to split the einsum if any conflict is detected.

PiperOrigin-RevId: 732205849
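The rank mismatch described above can be checked with a small pure-Python sketch. It models only the bookkeeping (output rank of each pairwise einsum versus the length of the spec) and does not call JAX or touch any sharding machinery. The helper names `einsum_output_rank` and `spec_fits` are hypothetical, the tuple `("data", None, None)` merely stands in for `P('data', None, None)`, and the strict length equality is an assumption taken from the error text quoted in this message.

```python
# Illustrative sketch only: hypothetical helpers, not part of JAX.

def einsum_output_rank(subscripts: str) -> int:
    """Rank of the output of an explicit-form einsum such as 'ab,bc->ac'."""
    _, output = subscripts.replace(" ", "").split("->")
    return len(output)


def spec_fits(spec: tuple, subscripts: str) -> bool:
    """Assumption: a spec fits only if it has one entry per output dimension."""
    return len(spec) == einsum_output_rank(subscripts)


# Stand-in for P('data', None, None); a plain tuple, not a real PartitionSpec.
out_spec = ("data", None, None)

# The two pairwise contractions named in the commit message.
intermediate = "bthj,bthD->bthjD"
final = "bthjD,bthi->ijD"

intermediate_rank = einsum_output_rank(intermediate)
final_rank = einsum_output_rank(final)
fits_intermediate = spec_fits(out_spec, intermediate)
fits_final = spec_fits(out_spec, final)

print(f"intermediate rank={intermediate_rank}, spec fits: {fits_intermediate}")
print(f"final rank={final_rank}, spec fits: {fits_final}")
```

Under these assumptions the spec matches only the final contraction, which is the one that out_sharding is applied to after this change.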