diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e6dda92f8..7d622393e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2574,23 +2574,23 @@ def _get_out_sharding_from_orig_sharding( out.append(o) return out -def maybe_get_orig_out_sharding( - in_shardings, out_shardings, in_avals, out_avals): - if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in out_shardings): - return out_shardings +def maybe_recover_user_shardings( + old_shardings, new_shardings, old_avals, new_avals): + if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings): + return new_shardings orig_in_s = None orig_aval = None - for oi, aval in safe_zip(in_shardings, in_avals): + for oi, aval in safe_zip(old_shardings, old_avals): if type(oi) in _orig_out_sharding_handlers: orig_in_s = oi orig_aval = aval break if orig_in_s is not None: return _get_out_sharding_from_orig_sharding( - out_shardings, out_avals, orig_in_s, orig_aval) + new_shardings, new_avals, orig_in_s, orig_aval) - return out_shardings + return new_shardings def _get_layouts_from_executable( @@ -2744,6 +2744,10 @@ def _maybe_get_and_check_in_shardings( f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " "(User sharding)") new_in_shardings.append(orig) + + new_in_shardings = maybe_recover_user_shardings( + in_shardings, new_in_shardings, global_in_avals, global_in_avals) + return new_in_shardings @@ -2921,7 +2925,7 @@ class UnloadedMeshExecutable: assert all(i is None for i in in_layouts) assert all(o is None for o in out_layouts) - out_shardings = maybe_get_orig_out_sharding( + out_shardings = maybe_recover_user_shardings( in_shardings, out_shardings, global_in_avals, global_out_avals) out_shardings = finalize_out_shardings(out_shardings, da) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ae6304a7c..f9963066d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3999,6 +3999,19 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertArraysEqual(out4, np_inp * 3) self.assertArraysEqual(out5, np_inp.T) + def test_input_shardings_aot(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x, y): + return x * 2, y.T + + arg_shardings, _ = f.lower(arr, np_inp).compile().input_shardings + for s in arg_shardings: + self.assertIsInstance(s, NamedSharding) + def test_parameter_tupled_jit(self): if not jtu.test_device_matches(["tpu"]): self.skipTest('Parameters are tupled only on TPU if >2000 parameters')