mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Preserve single-device NamedSharding/PositionalSharding on the output instead of always returning SingleDeviceSharding.
Fixes https://github.com/google/jax/issues/19459
PiperOrigin-RevId: 600999853
This commit is contained in:
parent
eb59716e27
commit
6f96c963ff
@ -69,7 +69,8 @@ from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
|
||||
AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
|
||||
is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources
|
||||
is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
|
||||
SingleDeviceSharding, GSPMDSharding
|
||||
)
|
||||
from jax._src.util import (safe_map, safe_zip, partition_list,
|
||||
wrap_name, tuple_update, tuple_delete,
|
||||
@ -2340,11 +2341,9 @@ def get_gspmd_shardings_from_executable(
|
||||
assert len(omk) == num_out_avals, (len(omk), num_out_avals)
|
||||
|
||||
# When the device assignment only has 1 device, SPMD partitioner will not run.
|
||||
# Hence the op shardings will not be set on the `hlo_module`. In that case,
|
||||
# just return SingleDeviceShardings since we know the computation is running
|
||||
# only on 1 device.
|
||||
# Hence the op shardings will not be set on the `hlo_module`.
|
||||
if len(device_assignment) == 1:
|
||||
return [sharding_impls.SingleDeviceSharding(device_assignment[0], memory_kind=mk)
|
||||
return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk)
|
||||
for mk in omk]
|
||||
|
||||
_, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
|
||||
@ -2413,13 +2412,23 @@ _register_out_sharding_handler(
|
||||
|
||||
def _gspmd_to_positional_sharding(
|
||||
out_s: sharding_impls.GSPMDSharding,
|
||||
orig_in_s: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
|
||||
orig_in_s: sharding_impls.PositionalSharding
|
||||
) -> sharding_impls.PositionalSharding:
|
||||
return sharding_impls._op_sharding_to_pos_sharding(
|
||||
out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)
|
||||
|
||||
_register_out_sharding_handler(
|
||||
sharding_impls.PositionalSharding, _gspmd_to_positional_sharding)
|
||||
|
||||
def _gspmd_to_single_device_sharding(
|
||||
out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
|
||||
assert isinstance(orig_in_s, SingleDeviceSharding)
|
||||
return SingleDeviceSharding(
|
||||
out_s._device_assignment[0], memory_kind=out_s.memory_kind)
|
||||
|
||||
_register_out_sharding_handler(
|
||||
SingleDeviceSharding, _gspmd_to_single_device_sharding)
|
||||
|
||||
|
||||
def _get_out_sharding_from_orig_sharding(
|
||||
out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla):
|
||||
@ -2615,6 +2624,15 @@ def _get_shardings_from_executable(
|
||||
return out_shardings, are_out_shardings_from_xla
|
||||
|
||||
|
||||
def finalize_out_shardings(out_shardings, are_out_shardings_from_xla,
|
||||
device_assignment):
|
||||
if len(device_assignment) == 1:
|
||||
return ([SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
|
||||
if isinstance(o, GSPMDSharding) else o for o in out_shardings],
|
||||
are_out_shardings_from_xla)
|
||||
return out_shardings, are_out_shardings_from_xla
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UnloadedMeshExecutable:
|
||||
xla_executable: Any
|
||||
@ -2762,6 +2780,9 @@ class UnloadedMeshExecutable:
|
||||
in_shardings, out_shardings, are_out_shardings_from_xla,
|
||||
global_in_avals, global_out_avals)
|
||||
|
||||
out_shardings, are_out_shardings_from_xla = finalize_out_shardings(
|
||||
out_shardings, are_out_shardings_from_xla, da)
|
||||
|
||||
return UnloadedMeshExecutable(
|
||||
xla_executable=xla_executable,
|
||||
device_assignment=da, # type: ignore
|
||||
|
@ -3802,6 +3802,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
g(np.arange(8))
|
||||
self.assertEqual(count[0], 2)
|
||||
|
||||
def test_single_device_named_sharding_preserved(self):
|
||||
mesh = jax.sharding.Mesh([jax.devices()[0]], 'x')
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
np_inp = np.arange(8)
|
||||
inp = jax.device_put(np_inp, s)
|
||||
|
||||
out = jax.jit(lambda x: x)(inp)
|
||||
self.assertEqual(out.sharding, s)
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
@ -291,6 +291,14 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out1, np_inp)
|
||||
self.assertArraysEqual(out2, np_inp2.T)
|
||||
|
||||
def test_sharding_preserverd_single_device(self):
|
||||
mesh = jax.sharding.Mesh([jax.devices()[0]], "x")
|
||||
s = NamedSharding(mesh, P("x"))
|
||||
|
||||
x = jax.device_put(np.arange(8), s)
|
||||
_, y = shard_alike(x, jnp.arange(8))
|
||||
self.assertEqual(y.sharding, s)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user