```
with mesh:
  out = pjit(lambda: 1)()
```

The sharding of `out` was a `GSPMDSharding`, which is not ideal. This change fixes that and returns a `NamedSharding` instead.

This is also required for `Shardy` integration.

PiperOrigin-RevId: 658842350
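A minimal sketch for checking the behavior described above, assuming a JAX version in which `jax.experimental.pjit.pjit` is still importable and `Mesh` can be used as a context manager. The one-dimensional mesh and its axis name `"x"` are illustrative choices, not taken from the change itself. The printed class name depends on the installed JAX version, so no expected output is shown.

```python
import jax
import numpy as np
from jax.experimental.pjit import pjit
from jax.sharding import Mesh

# Illustrative 1-D mesh over all local devices; the axis name "x" is an assumption.
mesh = Mesh(np.array(jax.devices()), ("x",))

with mesh:
  # A pjit-ed function with no inputs, as in the snippet above.
  out = pjit(lambda: 1)()

# Inspect which Sharding subclass the output carries.
print(type(out.sharding).__name__)
```

With the change applied, the name printed should be `NamedSharding`; on versions that predate it, `GSPMDSharding` would be expected instead.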
pyupgrade --py310-plus
xmap
jax.experimental.maps
XLACompatibleSharding
jax.sharding.Sharding
Specialized
Traced
specialize
trace
KeyPath
jax.tree_util