Preserve PositionalSharding on the output of pjits if the inputs had PositionalSharding on them by converting GSPMDSharding to PositionalSharding

PiperOrigin-RevId: 523535581
This commit is contained in:
Yash Katariya 2023-04-11 16:27:08 -07:00 committed by jax authors
parent 1802923b56
commit cf2f182a6c
3 changed files with 55 additions and 38 deletions

View File

@ -25,7 +25,7 @@ import logging
import math
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable,
TYPE_CHECKING, cast)
TYPE_CHECKING, cast, TypeVar)
import numpy as np
@ -2575,38 +2575,58 @@ def _get_mesh_pspec_shardings_from_executable(
[sharding_impls.NamedSharding(mesh, o) for o in out_pspec])
def _get_out_sharding_from_named_sharding(
out_shardings, ns, are_out_sharding_from_xla):
from jax._src import pjit
SubClassT = TypeVar("SubClassT", bound=sharding_impls.XLACompatibleSharding)
OrigHandlerType = Dict[Type[SubClassT],
Callable[[xc.OpSharding, SubClassT], SubClassT]]
orig_out_sharding_handlers: OrigHandlerType = {}
def _gspmd_to_named_sharding(
op_sharding: xc.OpSharding,
self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
from jax._src.pjit import parse_flatten_op_sharding
return sharding_impls.NamedSharding._from_parsed_pspec(
self.mesh, parse_flatten_op_sharding(op_sharding, self.mesh)[0])
orig_out_sharding_handlers[sharding_impls.NamedSharding] = _gspmd_to_named_sharding
def _gspmd_to_positional_sharding(
op_sharding: xc.OpSharding,
self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
return sharding_impls._from_op_sharding_to_pos_sharding(
op_sharding, self._device_assignment)
orig_out_sharding_handlers[sharding_impls.PositionalSharding] = _gspmd_to_positional_sharding
def _get_out_sharding_from_orig_sharding(
out_shardings, orig_s, are_out_sharding_from_xla):
out = []
orig_handler = orig_out_sharding_handlers[type(orig_s)]
for o, from_xla in safe_zip(out_shardings, are_out_sharding_from_xla):
if isinstance(o, sharding_impls.GSPMDSharding):
try:
out.append((sharding_impls.NamedSharding._from_parsed_pspec(
ns.mesh, pjit.parse_flatten_op_sharding(o._op_sharding, ns.mesh)[0]), False))
out.append((orig_handler(o._op_sharding, orig_s), False))
except:
out.append((o, from_xla))
else:
out.append((o, from_xla))
return out
def maybe_get_orig_out_sharding(
in_shardings, out_shardings, are_out_shardings_from_xla):
if all(hasattr(o, '_original_sharding') for o in out_shardings):
return ([o._original_sharding for o in out_shardings],
(False,) * len(out_shardings))
# TODO(yashkatariya): Handle other shardings too here.
ns = None
orig_s = None
for i in in_shardings:
oi = getattr(i, '_original_sharding', None)
if isinstance(oi, sharding_impls.NamedSharding):
ns = oi
if type(oi) in orig_out_sharding_handlers:
orig_s = oi
break
if ns is not None:
return zip(*_get_out_sharding_from_named_sharding(
out_shardings, ns, are_out_shardings_from_xla))
if orig_s is not None:
return zip(*_get_out_sharding_from_orig_sharding(
out_shardings, orig_s, are_out_shardings_from_xla))
return out_shardings, are_out_shardings_from_xla

View File

@ -524,11 +524,17 @@ class PositionalSharding(XLACompatibleSharding):
# Hashable
def __hash__(self) -> int:
return id(self._devices)
if not hasattr(self, '_hash'):
self._hash = hash(tuple(self._devices))
return self._hash
def __eq__(self, other) -> bool:
return (isinstance(other, PositionalSharding) and
id(self._devices) == id(other._devices) and
if not isinstance(other, PositionalSharding):
return False
if (id(self._devices) == id(other._devices) and
bool(np.all(self._ids == other._ids))):
return True
return (self._devices == other._devices and
bool(np.all(self._ids == other._ids)))
# Sharding interface

View File

@ -2979,9 +2979,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out4 = jnp.squeeze(arr2, axis=-1)
cache_info4 = pxla._cached_compilation.cache_info()
# TODO(yashkatariya): Handle PositionalSharding inside pxla so that
# GSPMDShardings can be converted to PositionalSharding.
self.assertIsInstance(out4.sharding, GSPMDSharding)
self.assertIsInstance(out4.sharding, PositionalSharding)
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
self.assertEqual(cache_info4.misses, cache_info3.misses)
@ -3083,27 +3081,22 @@ class ArrayPjitTest(jtu.JaxTestCase):
pl_cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out.sharding, NamedSharding)
out2 = f(arr2)
cache_info2 = pxla._cached_compilation.cache_info()
pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
# TODO(yashkatariya): Handle PositionalSharding inside pxla so that
# GSPMDShardings can be converted to PositionalSharding.
self.assertIsInstance(out2.sharding, GSPMDSharding)
with jtu.count_pjit_cpp_cache_miss() as count:
out2 = f(arr2)
cache_info2 = pxla._cached_compilation.cache_info()
pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out2.sharding, PositionalSharding)
out3 = f(out2)
cache_info3 = pxla._cached_compilation.cache_info()
pl_cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
self.assertIsInstance(out3.sharding, GSPMDSharding)
# This will hit the cpp cache.
out3 = f(out2)
self.assertIsInstance(out3.sharding, PositionalSharding)
self.assertEqual(count[0], 1)
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
self.assertEqual(cache_info3.misses, cache_info2.misses)
# TODO(yashkatariya): We will get hits here after we can convert
# GSPMDSharding to PositionalSharding.
self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits)
self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1)
self.assertEqual(pl_cache_info3.misses, pl_cache_info2.misses + 1)
out4 = jnp.sum(arr)
self.assertIsInstance(out4.sharding, NamedSharding)
@ -3159,9 +3152,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
out2 = jnp.copy(arr2)
# TODO(yashkatariya): Handle PositionalSharding inside pxla so that
# GSPMDShardings can be converted to PositionalSharding.
self.assertIsInstance(out2.sharding, GSPMDSharding)
self.assertIsInstance(out2.sharding, PositionalSharding)
arr3 = jnp.arange(8)
out3 = jnp.copy(arr3)