Yash Katariya e6303244bf If the memory kind is the default kind throughout the jaxpr, revert to the previous `device_put` behavior, which was a no-op inside `jit`.
This matches the behavior for arguments and outputs, where we don't insert `mhlo.memory_kind` attributes in the StableHLO if the entire jaxpr only has the default memory kind.

PiperOrigin-RevId: 660913387
2024-08-08 11:24:25 -07:00