mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Return SingleDeviceSharding instead of GSPMDShardings when there is only 1 device during compiled.input_shardings
call.
PiperOrigin-RevId: 697683233
This commit is contained in:
parent
297a4e5ef5
commit
6fe7b1713a
@ -2747,11 +2747,11 @@ def _maybe_get_and_check_out_shardings(
|
||||
return new_out_shardings
|
||||
|
||||
|
||||
def finalize_out_shardings(out_shardings, device_assignment):
|
||||
def finalize_shardings(shardings, device_assignment):
|
||||
if len(device_assignment) == 1:
|
||||
return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
|
||||
if isinstance(o, GSPMDSharding) else o for o in out_shardings]
|
||||
return out_shardings
|
||||
if isinstance(o, GSPMDSharding) else o for o in shardings]
|
||||
return shardings
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -2892,7 +2892,8 @@ class UnloadedMeshExecutable:
|
||||
in_shardings, out_shardings, global_in_avals, global_out_avals,
|
||||
intermediate_shardings, context_mesh)
|
||||
|
||||
out_shardings = finalize_out_shardings(out_shardings, da)
|
||||
in_shardings = finalize_shardings(in_shardings, da)
|
||||
out_shardings = finalize_shardings(out_shardings, da)
|
||||
|
||||
return UnloadedMeshExecutable(
|
||||
xla_executable=xla_executable,
|
||||
|
@ -4624,6 +4624,14 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
jax.jit(f, out_shardings=s)(np.arange(8))
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
def test_input_shardings_single_device(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
ins, _ = f.lower(np.arange(8)).compile().input_shardings
|
||||
self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))
|
||||
|
||||
|
||||
def spec_regex(s):
|
||||
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
|
||||
|
Loading…
x
Reference in New Issue
Block a user