mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Canonicalize to the default memory kind in the __init__ of Shardings, but only on the backends that support memories right now.
PiperOrigin-RevId: 553942534
This commit is contained in:
parent 73d1b26cf6
commit 1ae37b4131
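
What changed, in brief: a Sharding's memory_kind may be None, meaning "the backend's default memory". Previously every comparison canonicalized lazily through helper functions (removed below); now canonicalization happens once in __init__, and only on backends that expose memories, so equality and hashing reduce to plain attribute comparison. A minimal, self-contained sketch of the idea with stand-in names (ShardingSketch and DEFAULT_MEM_KIND are illustrative, not JAX APIs):

    from typing import Optional

    DEFAULT_MEM_KIND = "device"  # hypothetical default for a backend with memories

    def old_mem_kinds_equal(mk1: Optional[str], mk2: Optional[str]) -> bool:
      # Old scheme: None means "the default", so canonicalize on every comparison.
      if mk1 is None and mk2 is None:
        return True
      c1 = DEFAULT_MEM_KIND if mk1 is None else mk1
      c2 = DEFAULT_MEM_KIND if mk2 is None else mk2
      return c1 == c2

    class ShardingSketch:
      def __init__(self, memory_kind: Optional[str], backend_has_memories: bool):
        # New scheme: canonicalize once at construction, but only on backends
        # that actually support memories (CPU/GPU keep None for now).
        if backend_has_memories and memory_kind is None:
          memory_kind = DEFAULT_MEM_KIND
        self.memory_kind = memory_kind

    # None and the explicit default now compare equal with a plain ==:
    a = ShardingSketch(None, backend_has_memories=True)
    b = ShardingSketch("device", backend_has_memories=True)
    assert a.memory_kind == b.memory_kind
    assert old_mem_kinds_equal(None, "device")  # the old helper agreed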
@@ -567,6 +567,7 @@ pytype_strict_library(
     srcs = ["_src/sharding_impls.py"],
     visibility = [":internal"] + jax_visibility("sharding_impls"),
     deps = [
+        ":config",
         ":mesh",
         ":op_shardings",
         ":partition_spec",
@@ -24,7 +24,6 @@ from functools import partial, lru_cache, cached_property
 import itertools as it
 import logging
 import math
-import os
 from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)

 import numpy as np
@@ -38,7 +37,6 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import effects
 from jax._src import linear_util as lu
-from jax._src import config as jax_config
 from jax._src import mesh as mesh_lib
 from jax._src import op_shardings
 from jax._src import sharding_specs
@@ -71,13 +69,6 @@ from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
                            unzip2, HashableFunction, weakref_lru_cache)

-# TODO(yashkatariya): Remove this flag after the host runtime is linked by
-# default and works on cloud TPU.
-_FETCH_MEMORY_KIND_ON_EXECUTABLE = jax_config.DEFINE_bool(
-    'jax_fetch_memory_kind_on_executable',
-    bool(os.getenv('JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE', '')),
-    help=("If True, will allow fetching memory kinds available on executable "
-          "and annotate Shardings with it."))
-
 # Built in Python lists don't support weak refs but subclasses of lists do.
 class WeakRefList(list):
@@ -1700,7 +1691,7 @@ class SemanticallyEqualShardings:
       return False
     return all(
         (op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
-         and sharding_impls.are_mem_kind_of_shardings_equal(s, o))
+         and s.memory_kind == o.memory_kind)
         if (isinstance(s, sharding_impls.GSPMDSharding) and
             isinstance(o, sharding_impls.GSPMDSharding))
         else s == o
@@ -2225,7 +2216,7 @@ def get_gspmd_shardings_from_executable(
 ) -> Sequence[sharding_impls.XLACompatibleSharding]:
   from jax._src import pjit

-  if _FETCH_MEMORY_KIND_ON_EXECUTABLE.value:
+  if sharding_impls._ENABLE_MEMORY_KIND.value:
     try:
       omk = xla_executable.get_output_memory_kinds()[0]
     except:
@@ -2318,10 +2309,10 @@ def _get_out_sharding_from_orig_sharding(
       # replicated then, it doesn't encode the ndim in it. The devices
       # will be the same at this point because those checks happen before.
       if (orig_aval is not None and out_aval is not None and
-          out_aval.ndim == orig_aval.ndim and
-          sharding_impls.are_op_shardings_equal(
-              o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) and
-          sharding_impls.are_mem_kind_of_shardings_equal(o, orig_in_s)):
+          out_aval.ndim == orig_aval.ndim
+          and sharding_impls.are_op_shardings_equal(
+              o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
+          and o.memory_kind == orig_in_s.memory_kind):
         out.append((orig_in_s, False))
       else:
         out.append((orig_handler(o, orig_in_s), False))
@@ -2553,7 +2544,7 @@ class UnloadedMeshExecutable:
         xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
         orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
         if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
-            not sharding_impls.are_mem_kind_of_shardings_equal(xla_s, orig)):  # type: ignore
+            xla_s.memory_kind != orig.memory_kind):  # type: ignore
           raise AssertionError(
               f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
               "(User sharding)")
@@ -2873,7 +2864,7 @@ def check_gda_or_array_xla_sharding_match(
       continue

     # Raise memory kind mismatch error even if the arg is uncommitted.
-    if not sharding_impls.are_mem_kind_of_shardings_equal(arg.sharding, xs):
+    if arg.sharding.memory_kind != xs.memory_kind:
       errors.append(
           f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
           f"arg {name} with shape: {arg.aval.str_short()}")
@@ -1744,8 +1744,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
     if (not op_shardings.are_op_shardings_equal(
             in_sharding._to_xla_hlo_sharding(ndim),
             xmap_sharding._to_xla_hlo_sharding(ndim)) or
-        not sharding_impls.are_mem_kind_of_shardings_equal(
-            in_sharding, xmap_sharding)):
+        in_sharding.memory_kind != xmap_sharding.memory_kind):
       raise ValueError(
           f"Got an input {arr_flavor} to xmap with different partitioning than "
           "specified in xmap. The partitioning must match. "
@@ -59,8 +59,7 @@ from jax._src.sharding_impls import (
     XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
     AUTO, UNSPECIFIED, UnspecifiedValue,
     ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
-    is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding,
-    are_mem_kind_of_shardings_equal)
+    is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import (
     tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
@@ -1095,7 +1094,7 @@ def _resolve_in_shardings(
         # jax.jit does not allow resharding across different memory kinds even
         # if the argument is uncommitted. Use jax.device_put for those cases,
         # either outside or inside jax.jit.
-        if not are_mem_kind_of_shardings_equal(pjit_in_s, arg_s):  # type: ignore
+        if pjit_in_s.memory_kind != arg_s.memory_kind:  # type: ignore
           raise ValueError(
               'Memory kinds passed to jax.jit does not match memory kind on the'
               f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, '  # type: ignore
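
The comment in this hunk directs users to jax.device_put when an argument must change memory kinds. A hedged sketch of that pattern; the public API for selecting a specific memory kind was still evolving at this revision, so this example just commits the argument to an explicit sharding and leaves memory_kind at its default:

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding

    dev = jax.devices()[0]
    x = jnp.ones((4,))
    # Commit the argument to an explicit sharding (and, on supporting backends,
    # a memory kind) before calling jit, rather than asking jit's in_shardings
    # to reshard across memory kinds:
    x = jax.device_put(x, SingleDeviceSharding(dev))
    print(jax.jit(lambda a: a + 1)(x))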
@@ -1242,7 +1241,7 @@ class SameDeviceAssignmentTuple:
       if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding):
         eq.append(
             op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
-            and are_mem_kind_of_shardings_equal(s, o))
+            and s.memory_kind == o.memory_kind)
       else:
         eq.append(s == o)
     return all(eq) and self.device_assignment == other.device_assignment
@@ -22,6 +22,7 @@ import enum
 import functools
 import itertools
 import math
+import os
 from typing import Any, NamedTuple, Union, cast, Optional

 from jax._src import mesh as mesh_lib
@@ -30,11 +31,13 @@ from jax._src.op_shardings import (
     op_sharding_to_indices)
 from jax._src import sharding
 from jax._src import sharding_specs
+from jax._src import config as jax_config
 from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.partition_spec import PartitionSpec

 import numpy as np
@@ -46,6 +49,15 @@ Index = tuple[slice, ...]
 XLADeviceAssignment = tuple[Device, ...]


+# TODO(yashkatariya): Remove this flag after the host runtime is linked by
+# default and works on cloud TPU.
+_ENABLE_MEMORY_KIND = jax_config.DEFINE_bool(
+    'jax_enable_memory_kind',
+    bool(os.getenv('JAX_ENABLE_MEMORY_KIND', '')),
+    help=("If True, will allow fetching memory kinds available on executable "
+          "and annotate Shardings with it."))
+
+
 # Shardings that inherit from XLACompatibleSharding should implement the
 # `_device_assignment` property and `_to_xla_hlo_sharding` method.
 @use_cpp_class(xc.XLACompatibleSharding)
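
As with other jax_config.DEFINE_bool flags, the flag above can presumably be toggled via the environment variable (read once at import time, per the os.getenv default above) or via jax.config.update. A sketch, assuming the standard JAX config mechanism; the flag name comes from this diff:

    import os
    os.environ['JAX_ENABLE_MEMORY_KIND'] = '1'  # must be set before importing jax

    import jax
    jax.config.update('jax_enable_memory_kind', True)  # or toggle at runtime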
@@ -109,7 +121,7 @@ class XLACompatibleSharding(sharding.Sharding):
       return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
                                      other._to_xla_hlo_sharding(ndim))
               and self._device_assignment == other._device_assignment and
-              are_mem_kind_of_shardings_equal(self, other))
+              self.memory_kind == other.memory_kind)
     # NotImplementedError is raised by PmapSharding because it can't lower
     # to OpSharding. So if `other` is a PmapSharding, default to a strict
     # equality check.
@@ -163,13 +175,6 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
 def _mem_kinds(client: xc.Client) -> set[str]:
   return set(m.kind for m in client.local_devices()[0].addressable_memories())

-@functools.lru_cache
-def _default_mem_kind(client: xc.Client) -> str | None:
-  try:
-    return client.local_devices()[0].default_memory().kind
-  except:
-    return None
-
 def _check_mem_kind(device: xc.Device, mk):
   mem_kinds = _mem_kinds(device.client)
   if mk not in mem_kinds:
@@ -179,26 +184,6 @@ def _check_mem_kind(device: xc.Device, mk):
       f' {mem_kinds}. Got memory kind: {mk}')


-def get_canonicalized_memory_kind(s: XLACompatibleSharding) -> Optional[str]:
-  # TODO(yashkatariya): Remove try;except when CPU and GPU support memories.
-  try:
-    client = s._device_assignment[0].client
-    return _default_mem_kind(client) if s.memory_kind is None else s.memory_kind
-  except:
-    return None
-
-# TODO(yashkatariya): Remove this when the canonicalization happens in __init__
-# of Shardings. That can be done after OSS support is also added for memories.
-def are_mem_kind_of_shardings_equal(s1: XLACompatibleSharding,
-                                    s2: XLACompatibleSharding) -> bool:
-  if s1.memory_kind is None and s2.memory_kind is None:
-    return True
-
-  mk1 = get_canonicalized_memory_kind(s1)
-  mk2 = get_canonicalized_memory_kind(s2)
-  return mk1 == mk2
-
-
 @use_cpp_class(xc.NamedSharding)
 class NamedSharding(XLACompatibleSharding):
   r"""A :class:`NamedSharding` expresses sharding using named axes.
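
With the lazy helpers above deleted, the __hash__ rewrites in the following hunks can hash self.memory_kind directly and remain consistent with the == used in __eq__, because the attribute is already canonical after __init__. A stand-in illustration of that hash/equality contract (MemKindKey is illustrative, not a JAX class):

    class MemKindKey:
      def __init__(self, memory_kind, default="device"):
        # memory_kind is made canonical here, mirroring the Shardings' __init__.
        self.memory_kind = default if memory_kind is None else memory_kind
      def __eq__(self, other):
        return self.memory_kind == other.memory_kind
      def __hash__(self):
        return hash(self.memory_kind)

    # Equal objects hash equally, so Shardings stay usable as dict/cache keys:
    assert MemKindKey(None) == MemKindKey("device")
    assert hash(MemKindKey(None)) == hash(MemKindKey("device"))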
@@ -281,8 +266,7 @@ class NamedSharding(XLACompatibleSharding):

   def __hash__(self):
     if not hasattr(self, '_hash'):
-      self._hash = hash((self.mesh, get_canonicalized_memory_kind(self),
-                         self._parsed_pspec))
+      self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
     return self._hash

   def __eq__(self, other):
@@ -291,7 +275,7 @@ class NamedSharding(XLACompatibleSharding):
     if id(self) == id(other):
       return True
     parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
-    mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
+    mem_kind_equal = self.memory_kind == other.memory_kind
     if (id(self.mesh) == id(other.mesh) and mem_kind_equal and
         parsed_pspec_equal):
       return True
@@ -408,7 +392,7 @@ class SingleDeviceSharding(XLACompatibleSharding):

   def __hash__(self):
     if not hasattr(self, '_hash'):
-      self._hash = hash((self._device, get_canonicalized_memory_kind(self)))
+      self._hash = hash((self._device, self.memory_kind))
     return self._hash

   def __eq__(self, other):
@@ -417,7 +401,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
     if id(self) == id(other):
       return True
     return (self._device == other._device and
-            are_mem_kind_of_shardings_equal(self, other))
+            self.memory_kind == other.memory_kind)

   @property
   def device_set(self) -> set[Device]:
@@ -628,6 +612,9 @@ class PositionalSharding(XLACompatibleSharding):
     if self._memory_kind is not None:
       # Will error if memory_kind does not exist on the device.
       _check_mem_kind(self._devices[0], self._memory_kind)
+    if xla_extension_version >= 177:
+      self._memory_kind = xc.canonicalize_memory_kind(
+          self._memory_kind, self._devices[0])

   @property
   def shape(self):
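
The xc.canonicalize_memory_kind call above is a new xla_client helper, hence the xla_extension_version >= 177 guard. Judging from the _default_mem_kind helper this commit deletes, its behavior is roughly the stand-in below; this pure-Python sketch is an assumption, not the actual implementation:

    def canonicalize_memory_kind_sketch(memory_kind, device):
      # Explicit kinds pass through unchanged.
      if memory_kind is not None:
        return memory_kind
      # None resolves to the device's default memory on backends that support
      # memories, and stays None everywhere else.
      try:
        return device.default_memory().kind
      except Exception:
        return None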
@@ -673,7 +660,7 @@ class PositionalSharding(XLACompatibleSharding):

   def __hash__(self) -> int:
     if not hasattr(self, '_hash'):
-      self._hash = hash((self._devices, get_canonicalized_memory_kind(self)))
+      self._hash = hash((self._devices, self.memory_kind))
     return self._hash

   def __eq__(self, other) -> bool:
@@ -682,7 +669,7 @@ class PositionalSharding(XLACompatibleSharding):
     if id(self) == id(other):
       return True
     all_ids_equal = np.array_equal(self._ids, other._ids)
-    mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
+    mem_kind_equal = self.memory_kind == other.memory_kind
     if (id(self._devices) == id(other._devices) and mem_kind_equal and
         all_ids_equal):
       return True
@@ -779,9 +766,9 @@ class GSPMDSharding(XLACompatibleSharding):
     self._memory_kind = memory_kind

   def _preprocess(self):
-    if self.memory_kind is not None:
+    if self._memory_kind is not None:
       # Will error if memory_kind does not exist on the device.
-      _check_mem_kind(self._devices[0], self.memory_kind)
+      _check_mem_kind(self._devices[0], self._memory_kind)

   def __reduce__(self):
     return (type(self), (self._devices, self._hlo_sharding.to_proto()),
@@ -798,12 +785,12 @@ class GSPMDSharding(XLACompatibleSharding):
       return True
     return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
             and self._devices == other._devices and
-            are_mem_kind_of_shardings_equal(self, other))
+            self.memory_kind == other.memory_kind)

   def __hash__(self):
     if not hasattr(self, '_hash'):
       self._hash = hash((self._devices, self._hlo_sharding_hash,
-                         get_canonicalized_memory_kind(self)))
+                         self.memory_kind))
     return self._hash

   def __repr__(self):