mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Preserve PositionalSharding on pjit outputs when the inputs had PositionalSharding, by converting the outputs' GSPMDSharding to PositionalSharding.
PiperOrigin-RevId: 523535581
This commit is contained in:
parent
1802923b56
commit
cf2f182a6c
@ -25,7 +25,7 @@ import logging
|
||||
import math
|
||||
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
Sequence, Set, Tuple, Type, Union, Iterable,
|
||||
TYPE_CHECKING, cast)
|
||||
TYPE_CHECKING, cast, TypeVar)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -2575,38 +2575,58 @@ def _get_mesh_pspec_shardings_from_executable(
|
||||
[sharding_impls.NamedSharding(mesh, o) for o in out_pspec])
|
||||
|
||||
|
||||
def _get_out_sharding_from_named_sharding(
|
||||
out_shardings, ns, are_out_sharding_from_xla):
|
||||
from jax._src import pjit
|
||||
SubClassT = TypeVar("SubClassT", bound=sharding_impls.XLACompatibleSharding)
|
||||
OrigHandlerType = Dict[Type[SubClassT],
|
||||
Callable[[xc.OpSharding, SubClassT], SubClassT]]
|
||||
|
||||
orig_out_sharding_handlers: OrigHandlerType = {}
|
||||
|
||||
def _gspmd_to_named_sharding(
|
||||
op_sharding: xc.OpSharding,
|
||||
self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
|
||||
from jax._src.pjit import parse_flatten_op_sharding
|
||||
return sharding_impls.NamedSharding._from_parsed_pspec(
|
||||
self.mesh, parse_flatten_op_sharding(op_sharding, self.mesh)[0])
|
||||
orig_out_sharding_handlers[sharding_impls.NamedSharding] = _gspmd_to_named_sharding
|
||||
|
||||
|
||||
def _gspmd_to_positional_sharding(
|
||||
op_sharding: xc.OpSharding,
|
||||
self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
|
||||
return sharding_impls._from_op_sharding_to_pos_sharding(
|
||||
op_sharding, self._device_assignment)
|
||||
orig_out_sharding_handlers[sharding_impls.PositionalSharding] = _gspmd_to_positional_sharding
|
||||
|
||||
|
||||
def _get_out_sharding_from_orig_sharding(
|
||||
out_shardings, orig_s, are_out_sharding_from_xla):
|
||||
out = []
|
||||
orig_handler = orig_out_sharding_handlers[type(orig_s)]
|
||||
for o, from_xla in safe_zip(out_shardings, are_out_sharding_from_xla):
|
||||
if isinstance(o, sharding_impls.GSPMDSharding):
|
||||
try:
|
||||
out.append((sharding_impls.NamedSharding._from_parsed_pspec(
|
||||
ns.mesh, pjit.parse_flatten_op_sharding(o._op_sharding, ns.mesh)[0]), False))
|
||||
out.append((orig_handler(o._op_sharding, orig_s), False))
|
||||
except:
|
||||
out.append((o, from_xla))
|
||||
else:
|
||||
out.append((o, from_xla))
|
||||
return out
|
||||
|
||||
|
||||
def maybe_get_orig_out_sharding(
|
||||
in_shardings, out_shardings, are_out_shardings_from_xla):
|
||||
if all(hasattr(o, '_original_sharding') for o in out_shardings):
|
||||
return ([o._original_sharding for o in out_shardings],
|
||||
(False,) * len(out_shardings))
|
||||
|
||||
# TODO(yashkatariya): Handle other shardings too here.
|
||||
ns = None
|
||||
orig_s = None
|
||||
for i in in_shardings:
|
||||
oi = getattr(i, '_original_sharding', None)
|
||||
if isinstance(oi, sharding_impls.NamedSharding):
|
||||
ns = oi
|
||||
if type(oi) in orig_out_sharding_handlers:
|
||||
orig_s = oi
|
||||
break
|
||||
if ns is not None:
|
||||
return zip(*_get_out_sharding_from_named_sharding(
|
||||
out_shardings, ns, are_out_shardings_from_xla))
|
||||
if orig_s is not None:
|
||||
return zip(*_get_out_sharding_from_orig_sharding(
|
||||
out_shardings, orig_s, are_out_shardings_from_xla))
|
||||
|
||||
return out_shardings, are_out_shardings_from_xla
|
||||
|
||||
|
@ -524,11 +524,17 @@ class PositionalSharding(XLACompatibleSharding):
|
||||
# Hashable
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self._devices)
|
||||
if not hasattr(self, '_hash'):
|
||||
self._hash = hash(tuple(self._devices))
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return (isinstance(other, PositionalSharding) and
|
||||
id(self._devices) == id(other._devices) and
|
||||
if not isinstance(other, PositionalSharding):
|
||||
return False
|
||||
if (id(self._devices) == id(other._devices) and
|
||||
bool(np.all(self._ids == other._ids))):
|
||||
return True
|
||||
return (self._devices == other._devices and
|
||||
bool(np.all(self._ids == other._ids)))
|
||||
|
||||
# Sharding interface
|
||||
|
@ -2979,9 +2979,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
out4 = jnp.squeeze(arr2, axis=-1)
|
||||
cache_info4 = pxla._cached_compilation.cache_info()
|
||||
# TODO(yashkatariya): Handle PositionalSharding inside pxla so that
|
||||
# GSPMDShardings can be converted to PositionalSharding.
|
||||
self.assertIsInstance(out4.sharding, GSPMDSharding)
|
||||
self.assertIsInstance(out4.sharding, PositionalSharding)
|
||||
|
||||
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
|
||||
self.assertEqual(cache_info4.misses, cache_info3.misses)
|
||||
@ -3083,27 +3081,22 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
pl_cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertIsInstance(out.sharding, NamedSharding)
|
||||
|
||||
out2 = f(arr2)
|
||||
cache_info2 = pxla._cached_compilation.cache_info()
|
||||
pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
# TODO(yashkatariya): Handle PositionalSharding inside pxla so that
|
||||
# GSPMDShardings can be converted to PositionalSharding.
|
||||
self.assertIsInstance(out2.sharding, GSPMDSharding)
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
out2 = f(arr2)
|
||||
cache_info2 = pxla._cached_compilation.cache_info()
|
||||
pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertIsInstance(out2.sharding, PositionalSharding)
|
||||
|
||||
out3 = f(out2)
|
||||
cache_info3 = pxla._cached_compilation.cache_info()
|
||||
pl_cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
self.assertIsInstance(out3.sharding, GSPMDSharding)
|
||||
# This will hit the cpp cache.
|
||||
out3 = f(out2)
|
||||
self.assertIsInstance(out3.sharding, PositionalSharding)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
||||
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
||||
self.assertEqual(cache_info3.misses, cache_info2.misses)
|
||||
|
||||
# TODO(yashkatariya): We will get hits here after we can convert
|
||||
# GSPMDSharding to PositionalSharding.
|
||||
self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits)
|
||||
self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1)
|
||||
self.assertEqual(pl_cache_info3.misses, pl_cache_info2.misses + 1)
|
||||
|
||||
out4 = jnp.sum(arr)
|
||||
self.assertIsInstance(out4.sharding, NamedSharding)
|
||||
@ -3159,9 +3152,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1)
|
||||
arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps)
|
||||
out2 = jnp.copy(arr2)
|
||||
# TODO(yashkatariya): Handle PositionalSharding inside pxla so that
|
||||
# GSPMDShardings can be converted to PositionalSharding.
|
||||
self.assertIsInstance(out2.sharding, GSPMDSharding)
|
||||
self.assertIsInstance(out2.sharding, PositionalSharding)
|
||||
|
||||
arr3 = jnp.arange(8)
|
||||
out3 = jnp.copy(arr3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user