Yash Katariya e6851e6b22 Fix the AOT check for sharding consistency, which skipped checking the devices of the sharding.
Previously, for a computation compiled for TPU, a user could pass in an array committed to CPU and JAX would not raise an error, which is wrong.

This change fixes that. Also, `is_equivalent_to` is expected to check devices, HloSharding, and memory_kind, so the now-redundant separate `memory_kind` check is removed as well.
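The check described above can be illustrated with a small, self-contained sketch. It does not use JAX's real classes. `ToySharding` and `check_arg_shardings` are hypothetical stand-ins, and the `committed` flag (skip the check for uncommitted arguments) is an assumption made only for illustration. The sketch shows why an equivalence test that compares devices, the partitioning, and the memory kind together rejects a CPU-committed argument for a TPU-compiled computation and makes a separate `memory_kind` comparison unnecessary.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class ToySharding:
    """Hypothetical stand-in for a sharding; not JAX's actual class."""
    devices: Tuple[str, ...]          # device identifiers the data lives on
    hlo_sharding: str                 # simplified label for the partitioning
    memory_kind: Optional[str] = None

    def is_equivalent_to(self, other: "ToySharding") -> bool:
        # All three properties must match, so callers need no extra
        # memory_kind comparison of their own.
        return (
            self.devices == other.devices
            and self.hlo_sharding == other.hlo_sharding
            and self.memory_kind == other.memory_kind
        )


def check_arg_shardings(
    compiled: Sequence[ToySharding],
    args: Sequence[Tuple[ToySharding, bool]],
) -> None:
    """Raise ValueError if a committed argument's sharding differs from the
    sharding the computation was compiled with (hypothetical helper)."""
    for i, (expected, (actual, committed)) in enumerate(zip(compiled, args)):
        # Assumption for this sketch: only committed arguments are checked.
        if committed and not actual.is_equivalent_to(expected):
            raise ValueError(
                f"argument {i}: sharding {actual} does not match "
                f"compiled sharding {expected}"
            )
```

With `tpu = ToySharding(("tpu:0",), "replicated")` and `cpu = ToySharding(("cpu:0",), "replicated")`, a check that compared only `hlo_sharding` would accept `cpu` for `tpu` because both are `"replicated"`. Including `devices` in `is_equivalent_to` makes `check_arg_shardings([tpu], [(cpu, True)])` raise.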

PiperOrigin-RevId: 658794885
2024-08-02 08:15:32 -07:00