diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 3fb0366a7..12334646e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -25,7 +25,7 @@ import logging import math from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, Sequence, Set, Tuple, Type, Union, Iterable, - TYPE_CHECKING, cast) + TYPE_CHECKING, cast, TypeVar) import numpy as np @@ -2575,38 +2575,58 @@ def _get_mesh_pspec_shardings_from_executable( [sharding_impls.NamedSharding(mesh, o) for o in out_pspec]) -def _get_out_sharding_from_named_sharding( - out_shardings, ns, are_out_sharding_from_xla): - from jax._src import pjit +SubClassT = TypeVar("SubClassT", bound=sharding_impls.XLACompatibleSharding) +OrigHandlerType = Dict[Type[SubClassT], + Callable[[xc.OpSharding, SubClassT], SubClassT]] + +orig_out_sharding_handlers: OrigHandlerType = {} + +def _gspmd_to_named_sharding( + op_sharding: xc.OpSharding, + self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding: + from jax._src.pjit import parse_flatten_op_sharding + return sharding_impls.NamedSharding._from_parsed_pspec( + self.mesh, parse_flatten_op_sharding(op_sharding, self.mesh)[0]) +orig_out_sharding_handlers[sharding_impls.NamedSharding] = _gspmd_to_named_sharding + + +def _gspmd_to_positional_sharding( + op_sharding: xc.OpSharding, + self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding: + return sharding_impls._from_op_sharding_to_pos_sharding( + op_sharding, self._device_assignment) +orig_out_sharding_handlers[sharding_impls.PositionalSharding] = _gspmd_to_positional_sharding + + +def _get_out_sharding_from_orig_sharding( + out_shardings, orig_s, are_out_sharding_from_xla): out = [] + orig_handler = orig_out_sharding_handlers[type(orig_s)] for o, from_xla in safe_zip(out_shardings, are_out_sharding_from_xla): if isinstance(o, sharding_impls.GSPMDSharding): try: - out.append((sharding_impls.NamedSharding._from_parsed_pspec( - ns.mesh, pjit.parse_flatten_op_sharding(o._op_sharding, ns.mesh)[0]), False)) + out.append((orig_handler(o._op_sharding, orig_s), False)) except: out.append((o, from_xla)) else: out.append((o, from_xla)) return out - def maybe_get_orig_out_sharding( in_shardings, out_shardings, are_out_shardings_from_xla): if all(hasattr(o, '_original_sharding') for o in out_shardings): return ([o._original_sharding for o in out_shardings], (False,) * len(out_shardings)) - # TODO(yashkatariya): Handle other shardings too here. - ns = None + orig_s = None for i in in_shardings: oi = getattr(i, '_original_sharding', None) - if isinstance(oi, sharding_impls.NamedSharding): - ns = oi + if type(oi) in orig_out_sharding_handlers: + orig_s = oi break - if ns is not None: - return zip(*_get_out_sharding_from_named_sharding( - out_shardings, ns, are_out_shardings_from_xla)) + if orig_s is not None: + return zip(*_get_out_sharding_from_orig_sharding( + out_shardings, orig_s, are_out_shardings_from_xla)) return out_shardings, are_out_shardings_from_xla diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 9f37445c0..009b531ea 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -524,11 +524,17 @@ class PositionalSharding(XLACompatibleSharding): # Hashable def __hash__(self) -> int: - return id(self._devices) + if not hasattr(self, '_hash'): + self._hash = hash(tuple(self._devices)) + return self._hash def __eq__(self, other) -> bool: - return (isinstance(other, PositionalSharding) and - id(self._devices) == id(other._devices) and + if not isinstance(other, PositionalSharding): + return False + if (id(self._devices) == id(other._devices) and + bool(np.all(self._ids == other._ids))): + return True + return (self._devices == other._devices and bool(np.all(self._ids == other._ids))) # Sharding interface diff --git a/tests/pjit_test.py b/tests/pjit_test.py index bd35b81e2..1fe0ff73b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2979,9 +2979,7 @@ class ArrayPjitTest(jtu.JaxTestCase): out4 = jnp.squeeze(arr2, axis=-1) cache_info4 = pxla._cached_compilation.cache_info() - # TODO(yashkatariya): Handle PositionalSharding inside pxla so that - # GSPMDShardings can be converted to PositionalSharding. - self.assertIsInstance(out4.sharding, GSPMDSharding) + self.assertIsInstance(out4.sharding, PositionalSharding) self.assertEqual(cache_info4.hits, cache_info3.hits + 1) self.assertEqual(cache_info4.misses, cache_info3.misses) @@ -3083,27 +3081,22 @@ class ArrayPjitTest(jtu.JaxTestCase): pl_cache_info1 = pjit_lib._pjit_lower_cached.cache_info() self.assertIsInstance(out.sharding, NamedSharding) - out2 = f(arr2) - cache_info2 = pxla._cached_compilation.cache_info() - pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info() - # TODO(yashkatariya): Handle PositionalSharding inside pxla so that - # GSPMDShardings can be converted to PositionalSharding. - self.assertIsInstance(out2.sharding, GSPMDSharding) + with jtu.count_pjit_cpp_cache_miss() as count: + out2 = f(arr2) + cache_info2 = pxla._cached_compilation.cache_info() + pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + self.assertIsInstance(out2.sharding, PositionalSharding) - out3 = f(out2) - cache_info3 = pxla._cached_compilation.cache_info() - pl_cache_info3 = pjit_lib._pjit_lower_cached.cache_info() - self.assertIsInstance(out3.sharding, GSPMDSharding) + # This will hit the cpp cache. + out3 = f(out2) + self.assertIsInstance(out3.sharding, PositionalSharding) + self.assertEqual(count[0], 1) self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - self.assertEqual(cache_info3.hits, cache_info2.hits + 1) self.assertEqual(cache_info2.misses, cache_info1.misses) - self.assertEqual(cache_info3.misses, cache_info2.misses) - # TODO(yashkatariya): We will get hits here after we can convert - # GSPMDSharding to PositionalSharding. + self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits) self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1) - self.assertEqual(pl_cache_info3.misses, pl_cache_info2.misses + 1) out4 = jnp.sum(arr) self.assertIsInstance(out4.sharding, NamedSharding) @@ -3159,9 +3152,7 @@ class ArrayPjitTest(jtu.JaxTestCase): ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) out2 = jnp.copy(arr2) - # TODO(yashkatariya): Handle PositionalSharding inside pxla so that - # GSPMDShardings can be converted to PositionalSharding. - self.assertIsInstance(out2.sharding, GSPMDSharding) + self.assertIsInstance(out2.sharding, PositionalSharding) arr3 = jnp.arange(8) out3 = jnp.copy(arr3)