diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6c9e54441..2164c1a91 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2747,11 +2747,11 @@ def _maybe_get_and_check_out_shardings( return new_out_shardings -def finalize_out_shardings(out_shardings, device_assignment): +def finalize_shardings(shardings, device_assignment): if len(device_assignment) == 1: return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind) - if isinstance(o, GSPMDSharding) else o for o in out_shardings] - return out_shardings + if isinstance(o, GSPMDSharding) else o for o in shardings] + return shardings @dataclasses.dataclass @@ -2892,7 +2892,8 @@ class UnloadedMeshExecutable: in_shardings, out_shardings, global_in_avals, global_out_avals, intermediate_shardings, context_mesh) - out_shardings = finalize_out_shardings(out_shardings, da) + in_shardings = finalize_shardings(in_shardings, da) + out_shardings = finalize_shardings(out_shardings, da) return UnloadedMeshExecutable( xla_executable=xla_executable, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index be1f9cfc2..0c1c28809 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4624,6 +4624,14 @@ class ArrayPjitTest(jtu.JaxTestCase): jax.jit(f, out_shardings=s)(np.arange(8)) self.assertEqual(count[0], 1) + def test_input_shardings_single_device(self): + @jax.jit + def f(x): + return x * 2 + + ins, _ = f.lower(np.arange(8)).compile().input_shardings + self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0])) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)")