diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7e94df959..ec3395049 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -69,7 +69,8 @@ from jax._src.partition_spec import PartitionSpec from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto, - is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources + is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources, + SingleDeviceSharding, GSPMDSharding ) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, @@ -2340,11 +2341,9 @@ def get_gspmd_shardings_from_executable( assert len(omk) == num_out_avals, (len(omk), num_out_avals) # When the device assignment only has 1 device, SPMD partitioner will not run. - # Hence the op shardings will not be set on the `hlo_module`. In that case, - # just return SingleDeviceShardings since we know the computation is running - # only on 1 device. + # Hence the op shardings will not be set on the `hlo_module`. if len(device_assignment) == 1: - return [sharding_impls.SingleDeviceSharding(device_assignment[0], memory_kind=mk) + return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk) for mk in omk] _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable) @@ -2413,13 +2412,23 @@ _register_out_sharding_handler( def _gspmd_to_positional_sharding( out_s: sharding_impls.GSPMDSharding, - orig_in_s: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding: + orig_in_s: sharding_impls.PositionalSharding + ) -> sharding_impls.PositionalSharding: return sharding_impls._op_sharding_to_pos_sharding( out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind) _register_out_sharding_handler( sharding_impls.PositionalSharding, _gspmd_to_positional_sharding) +def _gspmd_to_single_device_sharding( + out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding: + assert isinstance(orig_in_s, SingleDeviceSharding) + return SingleDeviceSharding( + out_s._device_assignment[0], memory_kind=out_s.memory_kind) + +_register_out_sharding_handler( + SingleDeviceSharding, _gspmd_to_single_device_sharding) + def _get_out_sharding_from_orig_sharding( out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla): @@ -2615,6 +2624,15 @@ def _get_shardings_from_executable( return out_shardings, are_out_shardings_from_xla +def finalize_out_shardings(out_shardings, are_out_shardings_from_xla, + device_assignment): + if len(device_assignment) == 1: + return ([SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind) + if isinstance(o, GSPMDSharding) else o for o in out_shardings], + are_out_shardings_from_xla) + return out_shardings, are_out_shardings_from_xla + + @dataclasses.dataclass class UnloadedMeshExecutable: xla_executable: Any @@ -2762,6 +2780,9 @@ class UnloadedMeshExecutable: in_shardings, out_shardings, are_out_shardings_from_xla, global_in_avals, global_out_avals) + out_shardings, are_out_shardings_from_xla = finalize_out_shardings( + out_shardings, are_out_shardings_from_xla, da) + return UnloadedMeshExecutable( xla_executable=xla_executable, device_assignment=da, # type: ignore diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2f5a44ef0..1d67425b7 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3802,6 +3802,16 @@ class ArrayPjitTest(jtu.JaxTestCase): g(np.arange(8)) self.assertEqual(count[0], 2) + def test_single_device_named_sharding_preserved(self): + mesh = jax.sharding.Mesh([jax.devices()[0]], 'x') + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(8) + inp = jax.device_put(np_inp, s) + + out = jax.jit(lambda x: x)(inp) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np_inp) + class TempSharding(Sharding): diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 814fd7089..291ee5360 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -291,6 +291,14 @@ class ShardAlikeTest(jtu.JaxTestCase): self.assertArraysEqual(out1, np_inp) self.assertArraysEqual(out2, np_inp2.T) + def test_sharding_preserverd_single_device(self): + mesh = jax.sharding.Mesh([jax.devices()[0]], "x") + s = NamedSharding(mesh, P("x")) + + x = jax.device_put(np.arange(8), s) + _, y = shard_alike(x, jnp.arange(8)) + self.assertEqual(y.sharding, s) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())