This matches the existing behavior for arguments and outputs: we don't insert `mhlo.memory_kind` attributes into the StableHLO if the entire jaxpr uses only the default memory kind.
PiperOrigin-RevId: 660913387
pyupgrade --py310-plus
xmap
jax.experimental.maps
XLACompatibleSharding
jax.sharding.Sharding
Specialized
Traced
specialize
trace
KeyPath
jax.tree_util