Canonicalize the memory kind to the default memory in __init__ of Shardings, but only on backends that currently support memories.

PiperOrigin-RevId: 553942534
Yash Katariya, 2023-08-04 16:26:31 -07:00 (committed by jax authors)
parent 73d1b26cf6
commit 1ae37b4131
5 changed files with 42 additions and 65 deletions
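
For context, the sketch below illustrates what "canonicalize to the default memory at construction time" means. It is illustrative only and is not the code added by this commit: the real change routes through xc.canonicalize_memory_kind when xla_extension_version >= 177 (see the PositionalSharding hunk below), the helper name here is hypothetical, and it assumes devices that expose addressable_memories() and default_memory(). Once an unspecified memory kind resolves to the device's default in __init__, shardings can compare s.memory_kind == o.memory_kind directly, which is why the are_mem_kind_of_shardings_equal helper is deleted.

    from typing import Optional

    import jax

    def _canonical_memory_kind(memory_kind: Optional[str],
                               device) -> Optional[str]:
      """Hypothetical helper mirroring the canonicalization this commit adds."""
      try:
        # Backends without memory support (CPU and GPU at the time of this
        # commit) raise here, and the memory kind is left untouched.
        available = {m.kind for m in device.addressable_memories()}
        default = device.default_memory().kind
      except Exception:
        return memory_kind
      if memory_kind is None:
        # An unspecified kind resolves to the device's default memory.
        return default
      if memory_kind not in available:
        raise ValueError(f'Memory kind {memory_kind} not found on {device}. '
                         f'Available kinds: {available}')
      return memory_kind

    if __name__ == '__main__':
      dev = jax.devices()[0]
      # Prints the backend's default memory kind on TPU; None on CPU/GPU, where
      # canonicalization is skipped.
      print(_canonical_memory_kind(None, dev))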


@@ -567,6 +567,7 @@ pytype_strict_library(
srcs = ["_src/sharding_impls.py"],
visibility = [":internal"] + jax_visibility("sharding_impls"),
deps = [
":config",
":mesh",
":op_shardings",
":partition_spec",


@@ -24,7 +24,6 @@ from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import os
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
import numpy as np
@@ -38,7 +37,6 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import config as jax_config
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
@@ -71,13 +69,6 @@ from jax._src.util import (safe_map, safe_zip, partition_list,
wrap_name, tuple_delete, distributed_debug_log,
unzip2, HashableFunction, weakref_lru_cache)
# TODO(yashkatariya): Remove this flag after the host runtime is linked by
# default and works on cloud TPU.
_FETCH_MEMORY_KIND_ON_EXECUTABLE = jax_config.DEFINE_bool(
'jax_fetch_memory_kind_on_executable',
bool(os.getenv('JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE', '')),
help=("If True, will allow fetching memory kinds available on executable "
"and annotate Shardings with it."))
# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
@@ -1700,7 +1691,7 @@ class SemanticallyEqualShardings:
return False
return all(
(op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
and sharding_impls.are_mem_kind_of_shardings_equal(s, o))
and s.memory_kind == o.memory_kind)
if (isinstance(s, sharding_impls.GSPMDSharding) and
isinstance(o, sharding_impls.GSPMDSharding))
else s == o
@@ -2225,7 +2216,7 @@ def get_gspmd_shardings_from_executable(
) -> Sequence[sharding_impls.XLACompatibleSharding]:
from jax._src import pjit
if _FETCH_MEMORY_KIND_ON_EXECUTABLE.value:
if sharding_impls._ENABLE_MEMORY_KIND.value:
try:
omk = xla_executable.get_output_memory_kinds()[0]
except:
@@ -2318,10 +2309,10 @@ def _get_out_sharding_from_orig_sharding(
# replicated then, it doesn't encode the ndim in it. The devices
# will be the same at this point because those checks happen before.
if (orig_aval is not None and out_aval is not None and
out_aval.ndim == orig_aval.ndim and
sharding_impls.are_op_shardings_equal(
o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) and
sharding_impls.are_mem_kind_of_shardings_equal(o, orig_in_s)):
out_aval.ndim == orig_aval.ndim
and sharding_impls.are_op_shardings_equal(
o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
and o.memory_kind == orig_in_s.memory_kind):
out.append((orig_in_s, False))
else:
out.append((orig_handler(o, orig_in_s), False))
@@ -2553,7 +2544,7 @@ class UnloadedMeshExecutable:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
not sharding_impls.are_mem_kind_of_shardings_equal(xla_s, orig)): # type: ignore
xla_s.memory_kind != orig.memory_kind): # type: ignore
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
@@ -2873,7 +2864,7 @@ def check_gda_or_array_xla_sharding_match(
continue
# Raise memory kind mismatch error even if the arg is uncommitted.
if not sharding_impls.are_mem_kind_of_shardings_equal(arg.sharding, xs):
if arg.sharding.memory_kind != xs.memory_kind:
errors.append(
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
f"arg {name} with shape: {arg.aval.str_short()}")


@@ -1744,8 +1744,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
if (not op_shardings.are_op_shardings_equal(
in_sharding._to_xla_hlo_sharding(ndim),
xmap_sharding._to_xla_hlo_sharding(ndim)) or
not sharding_impls.are_mem_kind_of_shardings_equal(
in_sharding, xmap_sharding)):
in_sharding.memory_kind != xmap_sharding.memory_kind):
raise ValueError(
f"Got an input {arr_flavor} to xmap with different partitioning than "
"specified in xmap. The partitioning must match. "


@@ -59,8 +59,7 @@ from jax._src.sharding_impls import (
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding,
are_mem_kind_of_shardings_equal)
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
@@ -1095,7 +1094,7 @@ def _resolve_in_shardings(
# jax.jit does not allow resharding across different memory kinds even
# if the argument is uncommitted. Use jax.device_put for those cases,
# either outside or inside jax.jit.
if not are_mem_kind_of_shardings_equal(pjit_in_s, arg_s): # type: ignore
if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore
raise ValueError(
'Memory kinds passed to jax.jit does not match memory kind on the'
f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore
@@ -1242,7 +1241,7 @@ class SameDeviceAssignmentTuple:
if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding):
eq.append(
op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
and are_mem_kind_of_shardings_equal(s, o))
and s.memory_kind == o.memory_kind)
else:
eq.append(s == o)
return all(eq) and self.device_assignment == other.device_assignment


@@ -22,6 +22,7 @@ import enum
import functools
import itertools
import math
import os
from typing import Any, NamedTuple, Union, cast, Optional
from jax._src import mesh as mesh_lib
@@ -30,11 +31,13 @@ from jax._src.op_shardings import (
op_sharding_to_indices)
from jax._src import sharding
from jax._src import sharding_specs
from jax._src import config as jax_config
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.partition_spec import PartitionSpec
import numpy as np
@@ -46,6 +49,15 @@ Index = tuple[slice, ...]
XLADeviceAssignment = tuple[Device, ...]
# TODO(yashkatariya): Remove this flag after the host runtime is linked by
# default and works on cloud TPU.
_ENABLE_MEMORY_KIND = jax_config.DEFINE_bool(
'jax_enable_memory_kind',
bool(os.getenv('JAX_ENABLE_MEMORY_KIND', '')),
help=("If True, will allow fetching memory kinds available on executable "
"and annotate Shardings with it."))
# Shardings that inherit from XLACompatibleSharding should implement the
# `_device_assignment` property and `_to_xla_hlo_sharding` method.
@use_cpp_class(xc.XLACompatibleSharding)
@@ -109,7 +121,7 @@ class XLACompatibleSharding(sharding.Sharding):
return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
other._to_xla_hlo_sharding(ndim))
and self._device_assignment == other._device_assignment and
are_mem_kind_of_shardings_equal(self, other))
self.memory_kind == other.memory_kind)
# NotImplementedError is raised by PmapSharding because it can't lower
# to OpSharding. So if `other` is a PmapSharding, default to a strict
# equality check.
@@ -163,13 +175,6 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
def _mem_kinds(client: xc.Client) -> set[str]:
return set(m.kind for m in client.local_devices()[0].addressable_memories())
@functools.lru_cache
def _default_mem_kind(client: xc.Client) -> str | None:
try:
return client.local_devices()[0].default_memory().kind
except:
return None
def _check_mem_kind(device: xc.Device, mk):
mem_kinds = _mem_kinds(device.client)
if mk not in mem_kinds:
@@ -179,26 +184,6 @@ def _check_mem_kind(device: xc.Device, mk):
f' {mem_kinds}. Got memory kind: {mk}')
def get_canonicalized_memory_kind(s: XLACompatibleSharding) -> Optional[str]:
# TODO(yashkatariya): Remove try;except when CPU and GPU support memories.
try:
client = s._device_assignment[0].client
return _default_mem_kind(client) if s.memory_kind is None else s.memory_kind
except:
return None
# TODO(yashkatariya): Remove this when the canonicalization happens in __init__
# of Shardings. That can be done after OSS support is also added for memories.
def are_mem_kind_of_shardings_equal(s1: XLACompatibleSharding,
s2: XLACompatibleSharding) -> bool:
if s1.memory_kind is None and s2.memory_kind is None:
return True
mk1 = get_canonicalized_memory_kind(s1)
mk2 = get_canonicalized_memory_kind(s2)
return mk1 == mk2
@use_cpp_class(xc.NamedSharding)
class NamedSharding(XLACompatibleSharding):
r"""A :class:`NamedSharding` expresses sharding using named axes.
@@ -281,8 +266,7 @@ class NamedSharding(XLACompatibleSharding):
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self.mesh, get_canonicalized_memory_kind(self),
self._parsed_pspec))
self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
return self._hash
def __eq__(self, other):
@@ -291,7 +275,7 @@ class NamedSharding(XLACompatibleSharding):
if id(self) == id(other):
return True
parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
mem_kind_equal = self.memory_kind == other.memory_kind
if (id(self.mesh) == id(other.mesh) and mem_kind_equal and
parsed_pspec_equal):
return True
@@ -408,7 +392,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self._device, get_canonicalized_memory_kind(self)))
self._hash = hash((self._device, self.memory_kind))
return self._hash
def __eq__(self, other):
@@ -417,7 +401,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
if id(self) == id(other):
return True
return (self._device == other._device and
are_mem_kind_of_shardings_equal(self, other))
self.memory_kind == other.memory_kind)
@property
def device_set(self) -> set[Device]:
@@ -628,6 +612,9 @@ class PositionalSharding(XLACompatibleSharding):
if self._memory_kind is not None:
# Will error if memory_kind does not exist on the device.
_check_mem_kind(self._devices[0], self._memory_kind)
if xla_extension_version >= 177:
self._memory_kind = xc.canonicalize_memory_kind(
self._memory_kind, self._devices[0])
@property
def shape(self):
@@ -673,7 +660,7 @@ class PositionalSharding(XLACompatibleSharding):
def __hash__(self) -> int:
if not hasattr(self, '_hash'):
self._hash = hash((self._devices, get_canonicalized_memory_kind(self)))
self._hash = hash((self._devices, self.memory_kind))
return self._hash
def __eq__(self, other) -> bool:
@@ -682,7 +669,7 @@ class PositionalSharding(XLACompatibleSharding):
if id(self) == id(other):
return True
all_ids_equal = np.array_equal(self._ids,other._ids)
mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
mem_kind_equal = self.memory_kind == other.memory_kind
if (id(self._devices) == id(other._devices) and mem_kind_equal and
all_ids_equal):
return True
@@ -779,9 +766,9 @@ class GSPMDSharding(XLACompatibleSharding):
self._memory_kind = memory_kind
def _preprocess(self):
if self.memory_kind is not None:
if self._memory_kind is not None:
# Will error if memory_kind does not exist on the device.
_check_mem_kind(self._devices[0], self.memory_kind)
_check_mem_kind(self._devices[0], self._memory_kind)
def __reduce__(self):
return (type(self), (self._devices, self._hlo_sharding.to_proto()),
@@ -798,12 +785,12 @@ class GSPMDSharding(XLACompatibleSharding):
return True
return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
and self._devices == other._devices and
are_mem_kind_of_shardings_equal(self, other))
self.memory_kind == other.memory_kind)
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self._devices, self._hlo_sharding_hash,
get_canonicalized_memory_kind(self)))
self.memory_kind))
return self._hash
def __repr__(self):