Preserve single-device NamedSharding/PositionalSharding on the output instead of always returning SingleDeviceShardings.

Fixes https://github.com/google/jax/issues/19459

PiperOrigin-RevId: 600999853
This commit is contained in:
Yash Katariya 2024-01-23 21:28:33 -08:00 committed by jax authors
parent eb59716e27
commit 6f96c963ff
3 changed files with 45 additions and 6 deletions

View File

@ -69,7 +69,8 @@ from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources
is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
SingleDeviceSharding, GSPMDSharding
)
from jax._src.util import (safe_map, safe_zip, partition_list,
wrap_name, tuple_update, tuple_delete,
@ -2340,11 +2341,9 @@ def get_gspmd_shardings_from_executable(
assert len(omk) == num_out_avals, (len(omk), num_out_avals)
# When the device assignment only has 1 device, SPMD partitioner will not run.
# Hence the op shardings will not be set on the `hlo_module`. In that case,
# just return SingleDeviceShardings since we know the computation is running
# only on 1 device.
# Hence the op shardings will not be set on the `hlo_module`.
if len(device_assignment) == 1:
return [sharding_impls.SingleDeviceSharding(device_assignment[0], memory_kind=mk)
return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk)
for mk in omk]
_, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
@ -2413,13 +2412,23 @@ _register_out_sharding_handler(
def _gspmd_to_positional_sharding(
    out_s: sharding_impls.GSPMDSharding,
    orig_in_s: sharding_impls.PositionalSharding
) -> sharding_impls.PositionalSharding:
  """Converts a GSPMD output sharding back to a PositionalSharding.

  The result reuses the device assignment of the original PositionalSharding
  input so the caller gets back a sharding of the same user-facing type.

  Args:
    out_s: the GSPMD sharding computed for a computation output.
    orig_in_s: the PositionalSharding the user originally supplied; only its
      device assignment is consulted.

  Returns:
    A PositionalSharding equivalent to ``out_s`` on ``orig_in_s``'s devices,
    preserving ``out_s``'s memory kind.
  """
  # NOTE: the diff residue in the original text kept both the pre- and
  # post-change signature lines; this is the reconstructed post-change form.
  return sharding_impls._op_sharding_to_pos_sharding(
      out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)


_register_out_sharding_handler(
    sharding_impls.PositionalSharding, _gspmd_to_positional_sharding)
def _gspmd_to_single_device_sharding(
    out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
  """Collapses a GSPMD output sharding to a SingleDeviceSharding.

  Used when the user's input sharding was a SingleDeviceSharding: the output
  is pinned to the (single) device of ``out_s``, keeping its memory kind.
  """
  assert isinstance(orig_in_s, SingleDeviceSharding)
  device = out_s._device_assignment[0]
  return SingleDeviceSharding(device, memory_kind=out_s.memory_kind)


_register_out_sharding_handler(
    SingleDeviceSharding, _gspmd_to_single_device_sharding)
def _get_out_sharding_from_orig_sharding(
out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla):
@ -2615,6 +2624,15 @@ def _get_shardings_from_executable(
return out_shardings, are_out_shardings_from_xla
def finalize_out_shardings(out_shardings, are_out_shardings_from_xla,
                           device_assignment):
  """Rewrites GSPMD output shardings for single-device computations.

  When the computation runs on exactly one device, any GSPMDSharding in
  ``out_shardings`` is replaced with an equivalent SingleDeviceSharding on
  that device (memory kind preserved); every other sharding — e.g. a
  user-provided NamedSharding/PositionalSharding — passes through untouched
  so it is preserved on the output. With more than one device everything is
  returned unchanged.
  """
  if len(device_assignment) != 1:
    return out_shardings, are_out_shardings_from_xla
  only_device = device_assignment[0]
  finalized = []
  for s in out_shardings:
    if isinstance(s, GSPMDSharding):
      finalized.append(
          SingleDeviceSharding(only_device, memory_kind=s.memory_kind))
    else:
      finalized.append(s)
  return finalized, are_out_shardings_from_xla
@dataclasses.dataclass
class UnloadedMeshExecutable:
xla_executable: Any
@ -2762,6 +2780,9 @@ class UnloadedMeshExecutable:
in_shardings, out_shardings, are_out_shardings_from_xla,
global_in_avals, global_out_avals)
out_shardings, are_out_shardings_from_xla = finalize_out_shardings(
out_shardings, are_out_shardings_from_xla, da)
return UnloadedMeshExecutable(
xla_executable=xla_executable,
device_assignment=da, # type: ignore

View File

@ -3802,6 +3802,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
g(np.arange(8))
self.assertEqual(count[0], 2)
def test_single_device_named_sharding_preserved(self):
  # Regression test for google/jax#19459: jit over a one-device mesh must
  # return the original NamedSharding rather than collapsing it to a
  # SingleDeviceSharding.
  single_dev_mesh = jax.sharding.Mesh([jax.devices()[0]], 'x')
  named = NamedSharding(single_dev_mesh, P('x'))
  host_data = np.arange(8)
  arr = jax.device_put(host_data, named)
  result = jax.jit(lambda a: a)(arr)
  self.assertEqual(result.sharding, named)
  self.assertArraysEqual(result, host_data)
class TempSharding(Sharding):

View File

@ -291,6 +291,14 @@ class ShardAlikeTest(jtu.JaxTestCase):
self.assertArraysEqual(out1, np_inp)
self.assertArraysEqual(out2, np_inp2.T)
# NOTE(review): "preserverd" is a typo for "preserved" in the test name;
# kept as-is since renaming would change the externally visible test id.
def test_sharding_preserverd_single_device(self):
  # shard_alike on a one-device mesh should propagate the NamedSharding of
  # the first operand onto the second, not a SingleDeviceSharding.
  one_dev_mesh = jax.sharding.Mesh([jax.devices()[0]], "x")
  named = NamedSharding(one_dev_mesh, P("x"))
  sharded = jax.device_put(np.arange(8), named)
  _, aliked = shard_alike(sharded, jnp.arange(8))
  self.assertEqual(aliked.sharding, named)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())