Return SingleDeviceSharding instead of GSPMDShardings when there is only 1 device during compiled.input_shardings call.

PiperOrigin-RevId: 697683233
This commit is contained in:
Yash Katariya 2024-11-18 10:44:59 -08:00 committed by jax authors
parent 297a4e5ef5
commit 6fe7b1713a
2 changed files with 13 additions and 4 deletions

View File

@ -2747,11 +2747,11 @@ def _maybe_get_and_check_out_shardings(
return new_out_shardings
def finalize_out_shardings(out_shardings, device_assignment):
def finalize_shardings(shardings, device_assignment):
if len(device_assignment) == 1:
return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
if isinstance(o, GSPMDSharding) else o for o in out_shardings]
return out_shardings
if isinstance(o, GSPMDSharding) else o for o in shardings]
return shardings
@dataclasses.dataclass
@ -2892,7 +2892,8 @@ class UnloadedMeshExecutable:
in_shardings, out_shardings, global_in_avals, global_out_avals,
intermediate_shardings, context_mesh)
out_shardings = finalize_out_shardings(out_shardings, da)
in_shardings = finalize_shardings(in_shardings, da)
out_shardings = finalize_shardings(out_shardings, da)
return UnloadedMeshExecutable(
xla_executable=xla_executable,

View File

@ -4624,6 +4624,14 @@ class ArrayPjitTest(jtu.JaxTestCase):
jax.jit(f, out_shardings=s)(np.arange(8))
self.assertEqual(count[0], 1)
def test_input_shardings_single_device(self):
@jax.jit
def f(x):
return x * 2
ins, _ = f.lower(np.arange(8)).compile().input_shardings
self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")