From 6b253b2f754588fcb0bd3b9851f7a1e536919094 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sun, 12 Jan 2025 08:08:35 -0800 Subject: [PATCH] [shardy] Fix cases in shardy where you have a nullary function with partially specified out_shardings (i.e. some out_sharding's are None and others are NamedShardings). In this case, the returned out_shardings should all be NamedSharding (because of NamedSharding's presence in some out_sharding's). PiperOrigin-RevId: 714681941 --- jax/_src/interpreters/pxla.py | 36 ++++++++++++++++------------------- tests/pjit_test.py | 13 +++++++++++++ 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 150db8a44..ffb1662d8 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -25,7 +25,7 @@ import functools import itertools as it import logging import math -from typing import Any, NamedTuple, TypeVar, Union, cast +from typing import Any, NamedTuple, Union, cast import warnings import numpy as np @@ -2525,39 +2525,29 @@ def _get_mesh_pspec_shardings_from_executable( _orig_out_sharding_handlers = {} -_ShardingT = TypeVar("_ShardingT", bound=JSharding) - - -def _register_out_sharding_handler( - sharding_cls: type[_ShardingT], - handler: Callable[[GSPMDSharding, _ShardingT], _ShardingT], -) -> None: - _orig_out_sharding_handlers[sharding_cls] = handler - def _gspmd_to_named_sharding( out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding: + assert isinstance(out_s, GSPMDSharding) + assert isinstance(orig_in_s, NamedSharding) assert isinstance(orig_in_s.mesh, Mesh) return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) - -_register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding) - +_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding # type: ignore def _gspmd_to_positional_sharding( out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding: + assert isinstance(out_s, GSPMDSharding) + assert isinstance(orig_in_s, PositionalSharding) return sharding_impls._op_sharding_to_pos_sharding( out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind) - -_register_out_sharding_handler( - PositionalSharding, _gspmd_to_positional_sharding) +_orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding # type: ignore def _gspmd_to_single_device_sharding( out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding: + assert isinstance(out_s, GSPMDSharding) assert isinstance(orig_in_s, SingleDeviceSharding) return SingleDeviceSharding( out_s._device_assignment[0], memory_kind=out_s.memory_kind) - -_register_out_sharding_handler( - SingleDeviceSharding, _gspmd_to_single_device_sharding) +_orig_out_sharding_handlers[SingleDeviceSharding] = _gspmd_to_single_device_sharding # type: ignore def _get_out_sharding_from_orig_sharding( @@ -2601,7 +2591,13 @@ def maybe_recover_user_shardings( for i in intermediate_shardings: if i is not None and type(i) in _orig_out_sharding_handlers: return _get_out_sharding_from_orig_sharding( - new_shardings, new_avals, i, None) + new_shardings, [None] * len(new_shardings), i, None) + + # For nullary cases like: `jit(lambda: ..., out_shardings=(None, sharding))` + for oi in new_shardings: + if oi is not None and type(oi) in _orig_out_sharding_handlers: + return _get_out_sharding_from_orig_sharding( + new_shardings, [None] * len(new_shardings), oi, None) if context_mesh is not None and not context_mesh.empty: return [sharding_impls._gspmd_to_named_sharding_via_mesh(n, context_mesh) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 95e639e17..3b11a0cba 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4346,6 +4346,19 @@ class ArrayPjitTest(jtu.JaxTestCase): out2 = compiled2(jnp.arange(8)) self.assertArraysEqual(out2, np.arange(8) * 2) + def test_nullary_out_sharding_partial(self): + mesh = jtu.create_mesh((jax.device_count(),), 'x') + + @partial(jax.jit, out_shardings=(None, NamedSharding(mesh, P()))) + def init(): + tensor = jnp.zeros(shape=(1,)) + other_tensor = jnp.zeros(shape=(1,)) + return tensor, other_tensor + + out = init() + self.assertIsInstance(out[0].sharding, NamedSharding) + self.assertIsInstance(out[1].sharding, NamedSharding) + def test_device_put_efficient_reshard_single_host(self): if jax.device_count() < 4: self.skipTest('Requires >= 4 devices')