From 1ae37b413132c4e0b6c046c7cfe4d62afe6940d8 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 4 Aug 2023 16:26:31 -0700 Subject: [PATCH] Canonicalize to default memory in init of Shardings only on the backends that support memories right now. PiperOrigin-RevId: 553942534 --- jax/BUILD | 1 + jax/_src/interpreters/pxla.py | 25 +++++-------- jax/_src/maps.py | 7 ++-- jax/_src/pjit.py | 7 ++-- jax/_src/sharding_impls.py | 67 ++++++++++++++--------------------- 5 files changed, 42 insertions(+), 65 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 57bf6266f..d281a247a 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -567,6 +567,7 @@ pytype_strict_library( srcs = ["_src/sharding_impls.py"], visibility = [":internal"] + jax_visibility("sharding_impls"), deps = [ + ":config", ":mesh", ":op_shardings", ":partition_spec", diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 53f3be505..3a0219c47 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -24,7 +24,6 @@ from functools import partial, lru_cache, cached_property import itertools as it import logging import math -import os from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar) import numpy as np @@ -38,7 +37,6 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu -from jax._src import config as jax_config from jax._src import mesh as mesh_lib from jax._src import op_shardings from jax._src import sharding_specs @@ -71,13 +69,6 @@ from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_delete, distributed_debug_log, unzip2, HashableFunction, weakref_lru_cache) -# TODO(yashkatariya): Remove this flag after the host runtime is linked by -# default and works on cloud TPU. -_FETCH_MEMORY_KIND_ON_EXECUTABLE = jax_config.DEFINE_bool( - 'jax_fetch_memory_kind_on_executable', - bool(os.getenv('JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE', '')), - help=("If True, will allow fetching memory kinds available on executable " - "and annotate Shardings with it.")) # Built in Python lists don't support weak refs but subclasses of lists do. class WeakRefList(list): @@ -1700,7 +1691,7 @@ class SemanticallyEqualShardings: return False return all( (op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding) - and sharding_impls.are_mem_kind_of_shardings_equal(s, o)) + and s.memory_kind == o.memory_kind) if (isinstance(s, sharding_impls.GSPMDSharding) and isinstance(o, sharding_impls.GSPMDSharding)) else s == o @@ -2225,7 +2216,7 @@ def get_gspmd_shardings_from_executable( ) -> Sequence[sharding_impls.XLACompatibleSharding]: from jax._src import pjit - if _FETCH_MEMORY_KIND_ON_EXECUTABLE.value: + if sharding_impls._ENABLE_MEMORY_KIND.value: try: omk = xla_executable.get_output_memory_kinds()[0] except: @@ -2318,10 +2309,10 @@ def _get_out_sharding_from_orig_sharding( # replicated then, it doesn't encode the ndim in it. The devices # will be the same at this point because those checks happen before. 
if (orig_aval is not None and out_aval is not None and - out_aval.ndim == orig_aval.ndim and - sharding_impls.are_op_shardings_equal( - o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) and - sharding_impls.are_mem_kind_of_shardings_equal(o, orig_in_s)): + out_aval.ndim == orig_aval.ndim + and sharding_impls.are_op_shardings_equal( + o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) + and o.memory_kind == orig_in_s.memory_kind): out.append((orig_in_s, False)) else: out.append((orig_handler(o, orig_in_s), False)) @@ -2553,7 +2544,7 @@ class UnloadedMeshExecutable: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or - not sharding_impls.are_mem_kind_of_shardings_equal(xla_s, orig)): # type: ignore + xla_s.memory_kind != orig.memory_kind): # type: ignore raise AssertionError( f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} " "(User sharding)") @@ -2873,7 +2864,7 @@ def check_gda_or_array_xla_sharding_match( continue # Raise memory kind mismatch error even if the arg is uncommitted. - if not sharding_impls.are_mem_kind_of_shardings_equal(arg.sharding, xs): + if arg.sharding.memory_kind != xs.memory_kind: errors.append( f"Got Array sharding: {arg.sharding} and input sharding: {xs} for " f"arg {name} with shape: {arg.aval.str_short()}") diff --git a/jax/_src/maps.py b/jax/_src/maps.py index d1be7f1ce..20cf8490b 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -1742,10 +1742,9 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env, @lru_cache def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor): if (not op_shardings.are_op_shardings_equal( - in_sharding._to_xla_hlo_sharding(ndim), - xmap_sharding._to_xla_hlo_sharding(ndim)) or - not sharding_impls.are_mem_kind_of_shardings_equal( - in_sharding, xmap_sharding)): + in_sharding._to_xla_hlo_sharding(ndim), + xmap_sharding._to_xla_hlo_sharding(ndim)) or + in_sharding.memory_kind != xmap_sharding.memory_kind): raise ValueError( f"Got an input {arr_flavor} to xmap with different partitioning than " "specified in xmap. The partitioning must match. " diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 7bfcca8ce..33af3151a 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -59,8 +59,7 @@ from jax._src.sharding_impls import ( XLADeviceAssignment, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified, - is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding, - are_mem_kind_of_shardings_equal) + is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, @@ -1095,7 +1094,7 @@ def _resolve_in_shardings( # jax.jit does not allow resharding across different memory kinds even # if the argument is uncommitted. Use jax.device_put for those cases, # either outside or inside jax.jit. - if not are_mem_kind_of_shardings_equal(pjit_in_s, arg_s): # type: ignore + if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore raise ValueError( 'Memory kinds passed to jax.jit does not match memory kind on the' f' respective arg. 
Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore @@ -1242,7 +1241,7 @@ class SameDeviceAssignmentTuple: if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding): eq.append( op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding) - and are_mem_kind_of_shardings_equal(s, o)) + and s.memory_kind == o.memory_kind) else: eq.append(s == o) return all(eq) and self.device_assignment == other.device_assignment diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 09292de86..24889aa54 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -22,6 +22,7 @@ import enum import functools import itertools import math +import os from typing import Any, NamedTuple, Union, cast, Optional from jax._src import mesh as mesh_lib @@ -30,11 +31,13 @@ from jax._src.op_shardings import ( op_sharding_to_indices) from jax._src import sharding from jax._src import sharding_specs +from jax._src import config as jax_config from jax._src import tree_util from jax._src import util from jax._src import xla_bridge from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.partition_spec import PartitionSpec import numpy as np @@ -46,6 +49,15 @@ Index = tuple[slice, ...] XLADeviceAssignment = tuple[Device, ...] +# TODO(yashkatariya): Remove this flag after the host runtime is linked by +# default and works on cloud TPU. +_ENABLE_MEMORY_KIND = jax_config.DEFINE_bool( + 'jax_enable_memory_kind', + bool(os.getenv('JAX_ENABLE_MEMORY_KIND', '')), + help=("If True, will allow fetching memory kinds available on executable " + "and annotate Shardings with it.")) + + # Shardings that inherit from XLACompatibleSharding should implement the # `_device_assignment` property and `_to_xla_hlo_sharding` method. @use_cpp_class(xc.XLACompatibleSharding) @@ -109,7 +121,7 @@ class XLACompatibleSharding(sharding.Sharding): return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim), other._to_xla_hlo_sharding(ndim)) and self._device_assignment == other._device_assignment and - are_mem_kind_of_shardings_equal(self, other)) + self.memory_kind == other.memory_kind) # NotImplementedError is raised by PmapSharding because it can't lower # to OpSharding. So if `other` is a PmapSharding, default to a strict # equality check. @@ -163,13 +175,6 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int] def _mem_kinds(client: xc.Client) -> set[str]: return set(m.kind for m in client.local_devices()[0].addressable_memories()) -@functools.lru_cache -def _default_mem_kind(client: xc.Client) -> str | None: - try: - return client.local_devices()[0].default_memory().kind - except: - return None - def _check_mem_kind(device: xc.Device, mk): mem_kinds = _mem_kinds(device.client) if mk not in mem_kinds: @@ -179,26 +184,6 @@ def _check_mem_kind(device: xc.Device, mk): f' {mem_kinds}. Got memory kind: {mk}') -def get_canonicalized_memory_kind(s: XLACompatibleSharding) -> Optional[str]: - # TODO(yashkatariya): Remove try;except when CPU and GPU support memories. - try: - client = s._device_assignment[0].client - return _default_mem_kind(client) if s.memory_kind is None else s.memory_kind - except: - return None - -# TODO(yashkatariya): Remove this when the canonicalization happens in __init__ -# of Shardings. That can be done after OSS support is also added for memories. 
-def are_mem_kind_of_shardings_equal(s1: XLACompatibleSharding, - s2: XLACompatibleSharding) -> bool: - if s1.memory_kind is None and s2.memory_kind is None: - return True - - mk1 = get_canonicalized_memory_kind(s1) - mk2 = get_canonicalized_memory_kind(s2) - return mk1 == mk2 - - @use_cpp_class(xc.NamedSharding) class NamedSharding(XLACompatibleSharding): r"""A :class:`NamedSharding` expresses sharding using named axes. @@ -281,8 +266,7 @@ class NamedSharding(XLACompatibleSharding): def __hash__(self): if not hasattr(self, '_hash'): - self._hash = hash((self.mesh, get_canonicalized_memory_kind(self), - self._parsed_pspec)) + self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec)) return self._hash def __eq__(self, other): @@ -291,7 +275,7 @@ class NamedSharding(XLACompatibleSharding): if id(self) == id(other): return True parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec - mem_kind_equal = are_mem_kind_of_shardings_equal(self, other) + mem_kind_equal = self.memory_kind == other.memory_kind if (id(self.mesh) == id(other.mesh) and mem_kind_equal and parsed_pspec_equal): return True @@ -408,7 +392,7 @@ class SingleDeviceSharding(XLACompatibleSharding): def __hash__(self): if not hasattr(self, '_hash'): - self._hash = hash((self._device, get_canonicalized_memory_kind(self))) + self._hash = hash((self._device, self.memory_kind)) return self._hash def __eq__(self, other): @@ -417,7 +401,7 @@ class SingleDeviceSharding(XLACompatibleSharding): if id(self) == id(other): return True return (self._device == other._device and - are_mem_kind_of_shardings_equal(self, other)) + self.memory_kind == other.memory_kind) @property def device_set(self) -> set[Device]: @@ -628,6 +612,9 @@ class PositionalSharding(XLACompatibleSharding): if self._memory_kind is not None: # Will error if memory_kind does not exist on the device. _check_mem_kind(self._devices[0], self._memory_kind) + if xla_extension_version >= 177: + self._memory_kind = xc.canonicalize_memory_kind( + self._memory_kind, self._devices[0]) @property def shape(self): @@ -673,7 +660,7 @@ class PositionalSharding(XLACompatibleSharding): def __hash__(self) -> int: if not hasattr(self, '_hash'): - self._hash = hash((self._devices, get_canonicalized_memory_kind(self))) + self._hash = hash((self._devices, self.memory_kind)) return self._hash def __eq__(self, other) -> bool: @@ -682,7 +669,7 @@ class PositionalSharding(XLACompatibleSharding): if id(self) == id(other): return True all_ids_equal = np.array_equal(self._ids,other._ids) - mem_kind_equal = are_mem_kind_of_shardings_equal(self, other) + mem_kind_equal = self.memory_kind == other.memory_kind if (id(self._devices) == id(other._devices) and mem_kind_equal and all_ids_equal): return True @@ -779,9 +766,9 @@ class GSPMDSharding(XLACompatibleSharding): self._memory_kind = memory_kind def _preprocess(self): - if self.memory_kind is not None: + if self._memory_kind is not None: # Will error if memory_kind does not exist on the device. 
- _check_mem_kind(self._devices[0], self.memory_kind) + _check_mem_kind(self._devices[0], self._memory_kind) def __reduce__(self): return (type(self), (self._devices, self._hlo_sharding.to_proto()), @@ -798,12 +785,12 @@ class GSPMDSharding(XLACompatibleSharding): return True return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding) and self._devices == other._devices and - are_mem_kind_of_shardings_equal(self, other)) + self.memory_kind == other.memory_kind) def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash((self._devices, self._hlo_sharding_hash, - get_canonicalized_memory_kind(self))) + self.memory_kind)) return self._hash def __repr__(self): @@ -850,7 +837,7 @@ class GSPMDSharding(XLACompatibleSharding): @classmethod def get_replicated(cls, device_assignment, *, memory_kind: str | None = None): return cls(tuple(device_assignment), get_replicated_hlo_sharding(), - memory_kind=memory_kind) + memory_kind=memory_kind) class AUTO:
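Note (illustrative, not part of the patch): the sketch below uses simplified stand-in names — ShardingSketch, canonicalize_memory_kind, and DEFAULT_MEMORY_KIND are assumptions for the example, not JAX's API — to show the idea the patch implements: canonicalize the memory kind once at construction time, so later equality checks reduce to a direct memory_kind comparison instead of the removed are_mem_kind_of_shardings_equal helper.

    from typing import Optional

    # Assumed default kind for this sketch; real backends report their own default.
    DEFAULT_MEMORY_KIND = "device"

    def canonicalize_memory_kind(kind: Optional[str]) -> str:
        # None means "use the backend's default memory"; resolve it eagerly.
        return DEFAULT_MEMORY_KIND if kind is None else kind

    class ShardingSketch:
        def __init__(self, devices, memory_kind: Optional[str] = None):
            self.devices = tuple(devices)
            # Canonicalization happens at construction time, mirroring what the
            # patch does in Sharding __init__ on backends that support memories.
            self.memory_kind = canonicalize_memory_kind(memory_kind)

        def __eq__(self, other):
            # A plain attribute comparison replaces the removed
            # are_mem_kind_of_shardings_equal() helper.
            return (self.devices == other.devices
                    and self.memory_kind == other.memory_kind)

    # An unspecified memory kind now compares equal to the explicit default:
    assert ShardingSketch(["dev0"]) == ShardingSketch(["dev0"], memory_kind="device")

Eager canonicalization is also what lets the patched __hash__ methods hash self.memory_kind directly instead of calling the removed get_canonicalized_memory_kind at hash time.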