mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[shardy] Fix cases in shardy where a nullary function has partially specified out_shardings (i.e. some out_shardings are None and others are NamedShardings).
In this case, the returned out_shardings should all be NamedShardings (because a NamedSharding is present among the specified out_shardings). PiperOrigin-RevId: 714681941
This commit is contained in:
parent
4e7293162b
commit
6b253b2f75
@ -25,7 +25,7 @@ import functools
|
||||
import itertools as it
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, NamedTuple, TypeVar, Union, cast
|
||||
from typing import Any, NamedTuple, Union, cast
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -2525,39 +2525,29 @@ def _get_mesh_pspec_shardings_from_executable(
|
||||
|
||||
_orig_out_sharding_handlers = {}
|
||||
|
||||
_ShardingT = TypeVar("_ShardingT", bound=JSharding)
|
||||
|
||||
|
||||
def _register_out_sharding_handler(
|
||||
sharding_cls: type[_ShardingT],
|
||||
handler: Callable[[GSPMDSharding, _ShardingT], _ShardingT],
|
||||
) -> None:
|
||||
_orig_out_sharding_handlers[sharding_cls] = handler
|
||||
|
||||
def _gspmd_to_named_sharding(
|
||||
out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
|
||||
assert isinstance(out_s, GSPMDSharding)
|
||||
assert isinstance(orig_in_s, NamedSharding)
|
||||
assert isinstance(orig_in_s.mesh, Mesh)
|
||||
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
|
||||
|
||||
_register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)
|
||||
|
||||
_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding # type: ignore
|
||||
|
||||
def _gspmd_to_positional_sharding(
|
||||
out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
|
||||
assert isinstance(out_s, GSPMDSharding)
|
||||
assert isinstance(orig_in_s, PositionalSharding)
|
||||
return sharding_impls._op_sharding_to_pos_sharding(
|
||||
out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)
|
||||
|
||||
_register_out_sharding_handler(
|
||||
PositionalSharding, _gspmd_to_positional_sharding)
|
||||
_orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding # type: ignore
|
||||
|
||||
def _gspmd_to_single_device_sharding(
|
||||
out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
|
||||
assert isinstance(out_s, GSPMDSharding)
|
||||
assert isinstance(orig_in_s, SingleDeviceSharding)
|
||||
return SingleDeviceSharding(
|
||||
out_s._device_assignment[0], memory_kind=out_s.memory_kind)
|
||||
|
||||
_register_out_sharding_handler(
|
||||
SingleDeviceSharding, _gspmd_to_single_device_sharding)
|
||||
_orig_out_sharding_handlers[SingleDeviceSharding] = _gspmd_to_single_device_sharding # type: ignore
|
||||
|
||||
|
||||
def _get_out_sharding_from_orig_sharding(
|
||||
@ -2601,7 +2591,13 @@ def maybe_recover_user_shardings(
|
||||
for i in intermediate_shardings:
|
||||
if i is not None and type(i) in _orig_out_sharding_handlers:
|
||||
return _get_out_sharding_from_orig_sharding(
|
||||
new_shardings, new_avals, i, None)
|
||||
new_shardings, [None] * len(new_shardings), i, None)
|
||||
|
||||
# For nullary cases like: `jit(lambda: ..., out_shardings=(None, sharding))`
|
||||
for oi in new_shardings:
|
||||
if oi is not None and type(oi) in _orig_out_sharding_handlers:
|
||||
return _get_out_sharding_from_orig_sharding(
|
||||
new_shardings, [None] * len(new_shardings), oi, None)
|
||||
|
||||
if context_mesh is not None and not context_mesh.empty:
|
||||
return [sharding_impls._gspmd_to_named_sharding_via_mesh(n, context_mesh)
|
||||
|
@ -4346,6 +4346,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out2 = compiled2(jnp.arange(8))
|
||||
self.assertArraysEqual(out2, np.arange(8) * 2)
|
||||
|
||||
def test_nullary_out_sharding_partial(self):
|
||||
mesh = jtu.create_mesh((jax.device_count(),), 'x')
|
||||
|
||||
@partial(jax.jit, out_shardings=(None, NamedSharding(mesh, P())))
|
||||
def init():
|
||||
tensor = jnp.zeros(shape=(1,))
|
||||
other_tensor = jnp.zeros(shape=(1,))
|
||||
return tensor, other_tensor
|
||||
|
||||
out = init()
|
||||
self.assertIsInstance(out[0].sharding, NamedSharding)
|
||||
self.assertIsInstance(out[1].sharding, NamedSharding)
|
||||
|
||||
def test_device_put_efficient_reshard_single_host(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest('Requires >= 4 devices')
|
||||
|
Loading…
x
Reference in New Issue
Block a user