diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 8507e3e91..d3e8c22cf 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1030,6 +1030,8 @@ def _to_physical_op_sharding(
 ) -> xc.OpSharding | SdyArraySharding | None:
   if sharding is None:
     return None
+  if all_unconstrained(sharding, aval):
+    return None
   if isinstance(sharding, AUTO):
     if config.use_shardy_partitioner.value:
       return sharding._to_sdy_sharding(aval.ndim)  # type: ignore
@@ -1071,10 +1073,8 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
 
 
 def contains_unconstrained(s):
-  return (
-      isinstance(s, NamedSharding)
-      and PartitionSpec.UNCONSTRAINED in s._parsed_pspec
-  )
+  return (isinstance(s, NamedSharding)
+          and PartitionSpec.UNCONSTRAINED in s._parsed_pspec)
 
 
 def all_unconstrained(s, aval):
@@ -1084,12 +1084,19 @@ def all_unconstrained(s, aval):
     return all(p is PartitionSpec.UNCONSTRAINED for p in s._parsed_pspec)
   return False
 
-def _get_unconstrained_dimensions(s, aval):
+class UnconstrainedVariants(NamedTuple):
+  contains_unconstrained: bool
+  all_unconstrained: bool
+  unconstrained_dims: set[int] | None
+
+def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants:
   us = contains_unconstrained(s)
-  return (
-    us, all_unconstrained(s, aval),
-    ({i for i, p in enumerate(s._parsed_pspec)
-      if p is PartitionSpec.UNCONSTRAINED} if us else None))
+  unconstrained_dims = ({i for i, p in enumerate(s._parsed_pspec)
+                         if p is PartitionSpec.UNCONSTRAINED} if us else None)
+  return UnconstrainedVariants(
+      contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval),
+      unconstrained_dims=unconstrained_dims)
+
 
 def lower_jaxpr_to_module(
     module_name: str,
@@ -1511,13 +1518,13 @@ def lower_jaxpr_to_fun(
          for is_donated, types in zip(xla_donated_args, input_types)])
 
   ir_result_shardings = None
-  unconstrained_shardings = None
+  unconstrained_variants = None
   if result_shardings is not None:
     ir_result_shardings = util.flatten(
         [[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types)
          for a, s, types in zip(output_avals, result_shardings, output_types)])
-    unconstrained_shardings = util.flatten(
-        [[_get_unconstrained_dimensions(s, a)] * len_ir_types(types)
+    unconstrained_variants = util.flatten(
+        [[_get_unconstrained_variants(s, a)] * len_ir_types(types)
          for a, s, types in zip(output_avals, result_shardings, output_types)])
 
   ir_result_memory_kinds = None
@@ -1633,9 +1640,9 @@ def lower_jaxpr_to_fun(
         attrs['jax.result_info'] = ir.StringAttr.get(name_)
 
   if use_sharding_annotations and ir_result_shardings is not None:
-    for attrs, sharding, us in zip(result_attrs, ir_result_shardings,
-                                   unconstrained_shardings):  # type: ignore
-      if sharding is not None and not us[0]:
+    for attrs, sharding, uv in zip(result_attrs, ir_result_shardings,
+                                   unconstrained_variants):  # type: ignore
+      if sharding is not None and not uv.contains_unconstrained:
         if config.use_shardy_partitioner.value:
           attrs["sdy.sharding"] = get_sharding_attr(sharding)
         else:
@@ -1716,13 +1723,17 @@
 
     if ir_result_shardings is not None:
       temp_flat_outputs = []
-      for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
-                                  output_avals, unconstrained_shardings):  # type: ignore
-        if us[0] and not us[1]:
+      for o, s, o_aval, uv in zip(flat_outputs, ir_result_shardings,
+                                  output_avals, unconstrained_variants):  # type: ignore
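+        # Only wrap partially-unconstrained outputs; fully-unconstrained ones
+        # already lower with no op sharding (see _to_physical_op_sharding).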
+        if (s is not None and uv.contains_unconstrained and
+            not uv.all_unconstrained):
           if config.use_shardy_partitioner.value:
             s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
           temp_flat_outputs.append(wrap_with_sharding_op(
-              entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
+              entry_lowering_ctx, o, o_aval, s,
+              unspecified_dims=uv.unconstrained_dims))
         else:
           temp_flat_outputs.append(o)
       flat_outputs = temp_flat_outputs
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 07571c7ec..e6c7ac986 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2157,6 +2157,13 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
   return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
           donated_invars, out_shardings, out_layouts)
 
+@lru_cache(maxsize=1024)
+def _abstract_to_concrete_mesh(abstract_mesh, device_assignment):
+  np_dev = np.vectorize(lambda i: device_assignment[i],
+                        otypes=[object])(np.arange(len(device_assignment)))
+  return Mesh(np_dev.reshape(abstract_mesh.axis_sizes),
+              abstract_mesh.axis_names, axis_types=abstract_mesh.axis_types)
+
 def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
                                        out_mem_kinds):
   if device_assignment is None:
@@ -2164,27 +2171,20 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
   if len(device_assignment) == 1:
     return shardings
 
-  np_dev = np.vectorize(lambda i: device_assignment[i],
-                        otypes=[object])(np.arange(len(device_assignment)))
-
-  @lru_cache(maxsize=128)
-  def _abstract_to_concrete_mesh(abstract_mesh):
-    return Mesh(
-        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
-        axis_types=abstract_mesh.axis_types)
-
   out = []
   for s, a, mem_kind in zip(shardings, avals, out_mem_kinds):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       if a.sharding.mesh.empty:
         out.append(s)
+      elif a.sharding.mesh._are_all_axes_auto:
+        out.append(s)
       else:
         spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
                                 for sp in a.sharding.spec])
                 if a.sharding.mesh._any_axis_auto else a.sharding.spec)
         out.append(NamedSharding(
-            _abstract_to_concrete_mesh(a.sharding.mesh), spec,
-            memory_kind=mem_kind))
+            _abstract_to_concrete_mesh(a.sharding.mesh, device_assignment),
+            spec, memory_kind=mem_kind))
     else:
       out.append(s)
   return tuple(out)
@@ -2534,15 +2534,22 @@ def _get_mesh_pspec_shardings_from_executable(
 _orig_out_sharding_handlers = {}
 
 def _gspmd_to_named_sharding(
-    out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: NamedSharding) -> NamedSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, NamedSharding)
   assert isinstance(orig_in_s.mesh, Mesh)
-  return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
+  if (out_aval is not None and not out_aval.sharding.mesh.empty and
+      out_aval.sharding.mesh._are_all_axes_auto):
+    mesh = _abstract_to_concrete_mesh(
+        out_aval.sharding.mesh, out_s._device_assignment)
+  else:
+    mesh = orig_in_s.mesh
+  return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, mesh)
 _orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding
 
 def _gspmd_to_positional_sharding(
-    out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: PositionalSharding
+    ) -> PositionalSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, PositionalSharding)
   return sharding_impls._op_sharding_to_pos_sharding(
@@ -2550,7 +2557,8 @@ def _gspmd_to_positional_sharding(
 _orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding  # type: ignore
 
 def _gspmd_to_single_device_sharding(
-    out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
+    out_s: GSPMDSharding, out_aval, orig_in_s: SingleDeviceSharding
+    ) -> SingleDeviceSharding:
   assert isinstance(out_s, GSPMDSharding)
   assert isinstance(orig_in_s, SingleDeviceSharding)
   return SingleDeviceSharding(
@@ -2565,15 +2573,17 @@ def _get_out_sharding_from_orig_sharding(
   for o, out_aval in safe_zip(out_shardings, out_avals):
     if (isinstance(o, sharding_impls.GSPMDSharding) and
         out_aval is not core.abstract_token):
-      if (orig_aval is not None and out_aval is not None and
-          out_aval.ndim == orig_aval.ndim
-          and sharding_impls.are_op_shardings_equal(
-              o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
-          and o.memory_kind == orig_in_s.memory_kind):
+      # TODO(yashkatariya): Remove this condition and ask users to drop into
+      # explicit mode.
+      if (orig_aval is not None and out_aval is not None
+          and out_aval.ndim == orig_aval.ndim
+          and isinstance(orig_in_s, NamedSharding)
+          and out_aval.sharding.mesh == orig_in_s.mesh.abstract_mesh
+          and o.is_equivalent_to(orig_in_s, orig_aval.ndim)):
         out.append(orig_in_s)
       else:
         try:
-          out.append(orig_handler(o, orig_in_s))
+          out.append(orig_handler(o, out_aval, orig_in_s))
         except:
           out.append(o)
     else:
@@ -2581,40 +2591,9 @@ def _get_out_sharding_from_orig_sharding(
   return out
 
 
-def try_matching_out_with_in_spec_for_all_auto(
-    orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
-  recover_in_s, recover_in_aval = None, None
-  for in_s, in_aval in safe_zip(in_shardings, in_avals):
-    if isinstance(in_s, NamedSharding):
-      recover_in_s, recover_in_aval = in_s, in_aval
-      break
-  if recover_in_s is None:
-    return new_out_shardings
-
-  res = []
-  for orig_out_s, out_s, out_aval in safe_zip(
-      orig_out_shardings, new_out_shardings, out_avals):
-    if (out_aval is not core.abstract_token and
-        mlir.all_unconstrained(orig_out_s, out_aval) and
-        isinstance(orig_out_s, NamedSharding) and
-        isinstance(out_s, NamedSharding) and
-        orig_out_s.mesh._are_all_axes_auto and out_s.mesh._are_all_axes_auto and
-        out_aval.ndim == recover_in_aval.ndim and
-        out_s.is_equivalent_to(recover_in_s, out_aval.ndim)):
-      res.append(out_s.with_spec(recover_in_s.spec))
-    else:
-      res.append(out_s)
-  return res
-
-
 def maybe_recover_user_shardings(
     old_shardings, new_shardings, old_avals, new_avals,
-    intermediate_shardings=None, context_mesh: Mesh | None = None,
-    orig_out_shardings=None):
-  if orig_out_shardings is not None:
-    new_shardings = try_matching_out_with_in_spec_for_all_auto(
-        orig_out_shardings, new_shardings, new_avals, old_shardings, old_avals)
-
+    intermediate_shardings=None, context_mesh: Mesh | None = None):
   if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
     return new_shardings
 
@@ -2831,7 +2810,7 @@ def _maybe_get_and_check_out_shardings(
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         xla_s = sharding_impls.logical_sharding(aval, xla_s)
       try:
-        new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig))  # pytype: disable=wrong-arg-types
+        new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig))  # pytype: disable=wrong-arg-types
       except:
         new_out_shardings.append(xla_s)
     else:
@@ -3004,7 +2983,7 @@ class UnloadedMeshExecutable:
 
     out_shardings = maybe_recover_user_shardings(
         in_shardings, out_shardings, global_in_avals, global_out_avals,
-        intermediate_shardings, context_mesh, orig_out_shardings)
+        intermediate_shardings, context_mesh)
 
     in_shardings = finalize_shardings(in_shardings, da)
     out_shardings = finalize_shardings(out_shardings, da)
diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py
index 9c6cf0861..2ad3e089e 100644
--- a/tests/shard_alike_test.py
+++ b/tests/shard_alike_test.py
@@ -131,7 +131,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
       return shard_alike(x, y)[1]
 
     out = f(inp)
-    self.assertEqual(out.sharding, s)
+    self.assertTrue(out.sharding.is_equivalent_to(s, out.ndim))
     self.assertArraysEqual(out, np_inp)
 
   def test_shard_map(self):
@@ -268,7 +268,9 @@
 
     x = jax.device_put(np.arange(8), s)
     _, y = shard_alike(x, jnp.arange(8))
-    self.assertEqual(y.sharding, s)
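+    # The sharding may use an equivalent mesh; compare by equivalence.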
+    self.assertTrue(y.sharding.is_equivalent_to(s, y.ndim))
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())