[shardy] Fix cases in shardy where you have a nullary function with partially specified out_shardings (i.e. some out_sharding's are None and others are NamedShardings).

In this case, the returned out_shardings should all be NamedSharding (because of NamedSharding's presence in some out_sharding's).

PiperOrigin-RevId: 714681941
This commit is contained in:
Yash Katariya 2025-01-12 08:08:35 -08:00 committed by jax authors
parent 4e7293162b
commit 6b253b2f75
2 changed files with 29 additions and 20 deletions

View File

@ -25,7 +25,7 @@ import functools
import itertools as it
import logging
import math
from typing import Any, NamedTuple, TypeVar, Union, cast
from typing import Any, NamedTuple, Union, cast
import warnings
import numpy as np
@ -2525,39 +2525,29 @@ def _get_mesh_pspec_shardings_from_executable(
_orig_out_sharding_handlers = {}
_ShardingT = TypeVar("_ShardingT", bound=JSharding)
def _register_out_sharding_handler(
sharding_cls: type[_ShardingT],
handler: Callable[[GSPMDSharding, _ShardingT], _ShardingT],
) -> None:
_orig_out_sharding_handlers[sharding_cls] = handler
def _gspmd_to_named_sharding(
out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
assert isinstance(out_s, GSPMDSharding)
assert isinstance(orig_in_s, NamedSharding)
assert isinstance(orig_in_s.mesh, Mesh)
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
_register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)
_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding # type: ignore
def _gspmd_to_positional_sharding(
out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
assert isinstance(out_s, GSPMDSharding)
assert isinstance(orig_in_s, PositionalSharding)
return sharding_impls._op_sharding_to_pos_sharding(
out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)
_register_out_sharding_handler(
PositionalSharding, _gspmd_to_positional_sharding)
_orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding # type: ignore
def _gspmd_to_single_device_sharding(
out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
assert isinstance(out_s, GSPMDSharding)
assert isinstance(orig_in_s, SingleDeviceSharding)
return SingleDeviceSharding(
out_s._device_assignment[0], memory_kind=out_s.memory_kind)
_register_out_sharding_handler(
SingleDeviceSharding, _gspmd_to_single_device_sharding)
_orig_out_sharding_handlers[SingleDeviceSharding] = _gspmd_to_single_device_sharding # type: ignore
def _get_out_sharding_from_orig_sharding(
@ -2601,7 +2591,13 @@ def maybe_recover_user_shardings(
for i in intermediate_shardings:
if i is not None and type(i) in _orig_out_sharding_handlers:
return _get_out_sharding_from_orig_sharding(
new_shardings, new_avals, i, None)
new_shardings, [None] * len(new_shardings), i, None)
# For nullary cases like: `jit(lambda: ..., out_shardings=(None, sharding))`
for oi in new_shardings:
if oi is not None and type(oi) in _orig_out_sharding_handlers:
return _get_out_sharding_from_orig_sharding(
new_shardings, [None] * len(new_shardings), oi, None)
if context_mesh is not None and not context_mesh.empty:
return [sharding_impls._gspmd_to_named_sharding_via_mesh(n, context_mesh)

View File

@ -4346,6 +4346,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
out2 = compiled2(jnp.arange(8))
self.assertArraysEqual(out2, np.arange(8) * 2)
def test_nullary_out_sharding_partial(self):
mesh = jtu.create_mesh((jax.device_count(),), 'x')
@partial(jax.jit, out_shardings=(None, NamedSharding(mesh, P())))
def init():
tensor = jnp.zeros(shape=(1,))
other_tensor = jnp.zeros(shape=(1,))
return tensor, other_tensor
out = init()
self.assertIsInstance(out[0].sharding, NamedSharding)
self.assertIsInstance(out[1].sharding, NamedSharding)
def test_device_put_efficient_reshard_single_host(self):
if jax.device_count() < 4:
self.skipTest('Requires >= 4 devices')