[Memories] Add Memories support to jax.jit and jax.device_put!

The changes in this commit are:

* Add a temporary flag (`JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE`) that controls whether memory kinds are fetched from the executable. It is not meant to be set by users, but it is needed by the C++ pjrt-ifrt code. Setting it to True requires the host runtime dependency to be linked in; it should eventually also work in OSS, but more work is needed for that. For now only the tests set it to True while jax memories support is under development (see the first sketch after this list).

* Add a `with_memory_kind` method on `Sharding` to make it easier to create a copy of a sharding with a different memory kind (see the sketch after this list).

* Add lowering rules for `device_put` and `jax.jit` (usage sketch below).
  * For `device_put`, we always emit an annotation describing the transfer to a memory, plus a sharding annotation.
  * For `jax.jit`, an argument that lives in host memory gets an extra `_xla_buffer_placement` attribute.

* Compute the correct output shardings in pxla.py by extracting the memory kinds from the executable.

* Handle pjit caching by canonicalizing memory kinds so that `NS(mesh, pspec) == NS(mesh, pspec, memory_kind='tpu_hbm')`. Also canonicalize `memory_kind` in `__hash__` and `__eq__` of shardings (see the last sketch after this list).
  * This avoids changing the StableHLO to include device placement annotations for now, since the host-aware passes are not enabled by default and work is in progress to make them work everywhere.
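For reference, a minimal sketch of how the temporary flag could be flipped in a test. This assumes the flag is exposed through the usual `jax.config` mechanism (it is registered with `DEFINE_bool` in this change); it is not meant to be set by users, and it is not expected to work in OSS yet.

```python
# Hypothetical test-only setup; requires the host runtime dependency to be
# linked in.
import os
os.environ['JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE'] = '1'  # read at import time

import jax
# Equivalent after import, via the registered config option:
jax.config.update('jax_fetch_memory_kind_on_executable', True)
```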
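A sketch of `with_memory_kind`, assuming a TPU backend whose devices expose the `tpu_hbm` (default) and `unpinned_host` memory kinds referenced elsewhere in this change:

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('x',))
s_hbm = NamedSharding(mesh, P('x'))               # no memory kind -> default device memory
s_host = s_hbm.with_memory_kind('unpinned_host')  # same mesh/spec, host memory
```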
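A sketch of the user-facing entry points these lowering rules serve, reusing `mesh`, `s_hbm`, and `s_host` from the previous sketch. Running this end to end still assumes the temporary flag above and host-aware XLA passes, which are work in progress.

```python
import jax.numpy as jnp

x = jnp.arange(8.)

# device_put to a sharding with a memory kind lowers to a device-placement
# annotation plus a sharding annotation.
x_host = jax.device_put(x, s_host)

# In jax.jit, an argument placed in host memory is tagged with the
# _xla_buffer_placement frontend attribute during lowering.
f = jax.jit(lambda a: a * 2, in_shardings=s_host, out_shardings=s_hbm)
```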
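And a sketch of the canonicalization described above, assuming a backend whose default memory kind is `tpu_hbm` and reusing `mesh` from the earlier sketch:

```python
# Leaving memory_kind unset compares (and hashes) equal to spelling out the
# default explicitly, so pjit cache keys are unchanged.
assert NamedSharding(mesh, P('x')) == NamedSharding(mesh, P('x'), memory_kind='tpu_hbm')
```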

PiperOrigin-RevId: 553833344
Author: Yash Katariya, 2023-08-04 09:43:39 -07:00 (committed by jax authors)
parent d7940ee9a1
commit 4fb8cdb019
7 changed files with 287 additions and 93 deletions


@@ -345,6 +345,12 @@ def jaxpr_shardings(
         return PartitionSpec(*(names.get(i) for i in range(ndmin)))
       yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
                   for names in [*eqn.params['in_names'], *eqn.params['out_names']])
+    elif eqn.primitive is device_put_p:
+      s = eqn.params['device']
+      if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None:
+        source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
+                                 eqn.primitive.name)
+        yield (s, source_info)
   for subjaxpr in core.subjaxprs(jaxpr):
     yield from jaxpr_shardings(subjaxpr)
@@ -699,5 +705,12 @@ ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)

 def _device_put_lowering(ctx, x, *, device, src):
+  if isinstance(device, XLACompatibleSharding) and device.memory_kind is not None:
+    aval, = ctx.avals_in
+    out_aval, = ctx.avals_out
+    x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
+    x = mlir.wrap_with_sharding_op(
+        ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
+    return [x]
   return [x]
 mlir.register_lowering(device_put_p, _device_put_lowering)


@@ -25,7 +25,7 @@ import itertools
 import operator
 import re
 import typing
-from typing import (Any, Callable, NamedTuple, Protocol, Union)
+from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
 import warnings

 from jax._src import ad_util
@@ -643,6 +643,12 @@ def _to_logical_op_sharding(
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
   return sharding._to_xla_hlo_sharding(aval.ndim)

+def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
+  if s is None:
+    return None
+  assert isinstance(s, sharding_impls.XLACompatibleSharding)
+  return s.memory_kind
+

 def lower_jaxpr_to_module(
     module_name: str,
@@ -712,6 +718,11 @@ def lower_jaxpr_to_module(
       map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
       if result_shardings is not None else result_shardings)

+  arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
+                      if arg_shardings is not None else None)
+  result_memory_kinds = (map(_get_mem_kind, result_shardings)
+                         if result_shardings is not None else None)
+
   ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
                       keepalives, channel_iter, host_callbacks,
                       override_lowering_rules=override_lowering_rules,
@@ -733,7 +744,9 @@ def lower_jaxpr_to_module(
         result_shardings=result_op_shardings,
         input_output_aliases=input_output_aliases,
         arg_names=arg_names,
-        result_names=result_names)
+        result_names=result_names,
+        arg_memory_kinds=arg_memory_kinds,
+        result_memory_kinds=result_memory_kinds)

   try:
     if not ctx.module.operation.verify():
@@ -857,6 +870,8 @@ def lower_jaxpr_to_fun(
     api_name: str = "jit",
     arg_names: Sequence[str | None] | None = None,
     result_names: Sequence[str | None] | None = None,
+    arg_memory_kinds: Sequence[str | None] | None = None,
+    result_memory_kinds: Sequence[str | None] | None = None,
 ) -> func_dialect.FuncOp:
   """Lowers jaxpr and its callees to an IR function.
@@ -917,13 +932,15 @@ def lower_jaxpr_to_fun(
   input_types = [*dim_var_types, *token_types, *input_types]
   output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
   output_types = [*output_token_types, *token_types, *output_types]

   if input_output_aliases is not None:
     token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
     input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
     # Update the existing aliases to account for the new output values
     input_output_aliases = [None if a is None
                             else a + num_output_tokens + num_tokens
-                            for a in input_output_aliases]
+                            for a in input_output_aliases]  # type: ignore

   if arg_shardings is not None:
     token_shardings = [None] * (num_dim_vars + num_tokens)
     arg_shardings = [*token_shardings, *arg_shardings]
@@ -933,6 +950,13 @@ def lower_jaxpr_to_fun(
   if replicated_args is not None:
     token_replicated_args = [False] * (num_dim_vars + num_tokens)
     replicated_args = [*token_replicated_args, *replicated_args]
+  if arg_memory_kinds is not None:
+    token_memory_kinds = [None] * (num_dim_vars + num_tokens)
+    arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds]
+  if result_memory_kinds is not None:
+    token_memory_kinds = [None] * (num_tokens + num_output_tokens)
+    result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
+
   flat_input_types = util.flatten(input_types)
   flat_output_types = util.flatten(output_types)
   ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
@@ -940,6 +964,7 @@ def lower_jaxpr_to_fun(
   func_op.attributes["sym_visibility"] = ir.StringAttr.get(
       "public" if public else "private")
   ctx.symbol_table.insert(func_op)

   ir_arg_shardings = None
   if arg_shardings is not None:
     in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
@@ -947,6 +972,12 @@ def lower_jaxpr_to_fun(
         [[_to_physical_op_sharding(a, s)] * len(types)
          for a, s, types in zip(in_avals, arg_shardings, input_types)])
     del in_avals

+  ir_arg_memory_kinds = None
+  if arg_memory_kinds is not None:
+    ir_arg_memory_kinds = util.flatten(
+        [[mk] * len(types) for mk, types in zip(arg_memory_kinds, input_types)])
+
   ir_result_shardings = None
   if result_shardings is not None:
     out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
@@ -955,6 +986,11 @@ def lower_jaxpr_to_fun(
          for a, s, types in zip(out_avals, result_shardings, output_types)])
     del out_avals

+  ir_result_memory_kinds = None
+  if result_memory_kinds is not None:
+    ir_result_memory_kinds = util.flatten(
+        [[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)])
+
   if (
       replicated_args is not None
       or ir_arg_shardings is not None
@@ -1043,7 +1079,13 @@ def lower_jaxpr_to_fun(
           a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
           for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]

+    if ir_arg_memory_kinds is not None:
+      flat_args = [
+          a if mk is None else wrap_with_memory_kind(a, mk, a_aval, is_input=True)
+          for a, mk, a_aval in zip(flat_args, ir_arg_memory_kinds, input_avals)]
+
-    _, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)),
+    _, token_args, unflattened_args = util.split_list(
+        util.unflatten(flat_args, map(len, input_types)),
         [num_dim_vars, num_tokens])
     if create_tokens:
       tokens_in = TokenSet.create(effects)
@@ -1079,11 +1121,42 @@ def lower_jaxpr_to_fun(
           o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
           for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

+    if ir_result_memory_kinds is not None:
+      flat_outputs = [
+          o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
+          for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
+
     func_dialect.ReturnOp(flat_outputs)

   return func_op


+def get_compute_type(memory_kind: str) -> str:
+  if memory_kind == 'tpu_hbm':
+    return 'dense'
+  elif memory_kind == 'unpinned_host':
+    return 'host'
+  raise ValueError(f'Unknown memory_kind: {memory_kind}')
+
+
+def wrap_with_memory_kind(
+    x: ir.Value, memory_kind: str, aval_out: core.AbstractValue,
+    is_input: bool = False) -> ir.Value:
+  if aval_out is None:
+    result_type = x.type
+  else:
+    result_type = aval_to_ir_type(aval_out)
+  op = custom_call("annotate_device_placement", [result_type], [x],
+                   has_side_effect=False,
+                   api_version=1)
+  mka = get_compute_type(memory_kind)
+  dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)}
+  if is_input and mka == 'host':
+    dict_attr.update({"_xla_buffer_placement": ir.StringAttr.get("arg")})
+  op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
+  return op.result
+
+
 def _to_physical_op_sharding(
     aval: core.AbstractValue | None, sharding: xc.HloSharding | None
 ) -> xc.OpSharding | None:


@@ -24,6 +24,7 @@ from functools import partial, lru_cache, cached_property
 import itertools as it
 import logging
 import math
+import os
 from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)

 import numpy as np
@@ -37,6 +38,7 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import effects
 from jax._src import linear_util as lu
+from jax._src import config as jax_config
 from jax._src import mesh as mesh_lib
 from jax._src import op_shardings
 from jax._src import sharding_specs
@@ -69,6 +71,13 @@ from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
                            unzip2, HashableFunction, weakref_lru_cache)

+# TODO(yashkatariya): Remove this flag after the host runtime is linked by
+# default and works on cloud TPU.
+_FETCH_MEMORY_KIND_ON_EXECUTABLE = jax_config.DEFINE_bool(
+    'jax_fetch_memory_kind_on_executable',
+    bool(os.getenv('JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE', '')),
+    help=("If True, will allow fetching memory kinds available on executable "
+          "and annotate Shardings with it."))
+
 # Built in Python lists don't support weak refs but subclasses of lists do.
 class WeakRefList(list):
@@ -1689,10 +1698,14 @@ class SemanticallyEqualShardings:
   def __eq__(self, other):
     if not isinstance(other, SemanticallyEqualShardings):
       return False
-    return all(op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
-               if (isinstance(s, sharding_impls.GSPMDSharding) and
-                   isinstance(o, sharding_impls.GSPMDSharding))
-               else s == o for s, o in zip(self.shardings, other.shardings))
+    return all(
+        (op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
+         and sharding_impls.are_mem_kind_of_shardings_equal(s, o))
+        if (isinstance(s, sharding_impls.GSPMDSharding) and
+            isinstance(o, sharding_impls.GSPMDSharding))
+        else s == o
+        for s, o in zip(self.shardings, other.shardings)
+    )


 @weakref_lru_cache
@@ -2208,34 +2221,40 @@ def _get_input_indices(
 def get_gspmd_shardings_from_executable(
     xla_executable, device_assignment: Sequence[xc.Device],
-    num_in_avals: int, num_out_avals: int
-) -> tuple[Sequence[sharding_impls.XLACompatibleSharding],
-           Sequence[sharding_impls.XLACompatibleSharding]]:
+    num_out_avals: int
+) -> Sequence[sharding_impls.XLACompatibleSharding]:
   from jax._src import pjit

+  if _FETCH_MEMORY_KIND_ON_EXECUTABLE.value:
+    try:
+      omk = xla_executable.get_output_memory_kinds()[0]
+    except:
+      omk = [None] * num_out_avals
+  else:
+    omk = [None] * num_out_avals
+
   # When the device assignment only has 1 device, SPMD partitioner will not run.
   # Hence the op shardings will not be set on the `hlo_module`. In that case,
   # just return SingleDeviceShardings since we know the computation is running
   # only on 1 device.
   if len(device_assignment) == 1:
-    ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
-    return [ss] * num_in_avals, [ss] * num_out_avals
+    return [sharding_impls.SingleDeviceSharding(device_assignment[0], memory_kind=mk)
+            for mk in omk]

-  in_op_shardings, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
-
-  in_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, i)
-                      for i in in_op_shardings]
-  out_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, o)
-                       for o in out_op_shardings]
+  _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)

   # This condition happens when all the elements in the output tuple have the
   # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
   # put the sharding on ROOT instead of the tuple.
   # TODO(b/245667823): Remove this when XLA fixes this.
-  if len(out_shardings_xla) == 1 and len(out_shardings_xla) < num_out_avals:
-    out_shardings_xla = out_shardings_xla * num_out_avals
-  assert len(out_shardings_xla) == num_out_avals, (
-      len(out_shardings_xla), num_out_avals)
-  return in_shardings_xla, out_shardings_xla
+  if len(out_op_shardings) == 1 and len(out_op_shardings) < num_out_avals:
+    out_op_shardings = out_op_shardings * num_out_avals  # type: ignore
+
+  assert len(out_op_shardings) == num_out_avals == len(omk), (
+      len(out_op_shardings), num_out_avals, len(omk))
+
+  return [sharding_impls.GSPMDSharding(device_assignment, os, memory_kind=mk)
+          for os, mk in safe_zip(out_op_shardings, omk)]


 # TODO(yashkatariya): Remove this function after `AUTO` can return shardings
@@ -2258,39 +2277,38 @@ _ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding)
 def _register_out_sharding_handler(
     sharding_cls: type[_ShardingT],
-    handler: Callable[[xc.OpSharding, _ShardingT], _ShardingT],
+    handler: Callable[[sharding_impls.GSPMDSharding, _ShardingT], _ShardingT],
 ) -> None:
   _orig_out_sharding_handlers[sharding_cls] = handler


 def _gspmd_to_named_sharding(
-    op_sharding: xc.OpSharding,
-    self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
+    out_s: sharding_impls.GSPMDSharding,
+    orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
   parsed_pspec = sharding_impls.parse_flatten_op_sharding(
-      op_sharding, self.mesh)[0]
+      out_s._hlo_sharding, orig_in_s.mesh)[0]
   return create_mesh_pspec_sharding(
-      self.mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
+      orig_in_s.mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
+      out_s.memory_kind)

 _register_out_sharding_handler(
-    sharding_impls.NamedSharding, _gspmd_to_named_sharding
-)
+    sharding_impls.NamedSharding, _gspmd_to_named_sharding)


 def _gspmd_to_positional_sharding(
-    op_sharding: xc.OpSharding,
-    self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
+    out_s: sharding_impls.GSPMDSharding,
+    orig_in_s: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
   return sharding_impls._op_sharding_to_pos_sharding(
-      op_sharding, self._device_assignment)
+      out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)

 _register_out_sharding_handler(
-    sharding_impls.PositionalSharding, _gspmd_to_positional_sharding
-)
+    sharding_impls.PositionalSharding, _gspmd_to_positional_sharding)


 def _get_out_sharding_from_orig_sharding(
-    out_shardings, out_avals, orig_s, orig_aval, are_out_sharding_from_xla):
+    out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla):
   out = []
-  orig_handler = _orig_out_sharding_handlers[type(orig_s)]
+  orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
   for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
                                         are_out_sharding_from_xla):
     if isinstance(o, sharding_impls.GSPMDSharding):
@@ -2302,10 +2320,11 @@ def _get_out_sharding_from_orig_sharding(
         if (orig_aval is not None and out_aval is not None and
             out_aval.ndim == orig_aval.ndim and
             sharding_impls.are_op_shardings_equal(
-                o._hlo_sharding, orig_s._to_xla_hlo_sharding(orig_aval.ndim))):
-          out.append((orig_s, False))
+                o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) and
+            sharding_impls.are_mem_kind_of_shardings_equal(o, orig_in_s)):
+          out.append((orig_in_s, False))
         else:
-          out.append((orig_handler(o._hlo_sharding, orig_s), False))
+          out.append((orig_handler(o, orig_in_s), False))
       except:
         out.append((o, from_xla))
     else:
@@ -2319,17 +2338,17 @@ def maybe_get_orig_out_sharding(
     return ([o._original_sharding for o in out_shardings],
             (False,) * len(out_shardings))

-  orig_s = None
+  orig_in_s = None
   orig_aval = None
   for i, aval in safe_zip(in_shardings, in_avals):
     oi = getattr(i, '_original_sharding', None)
     if type(oi) in _orig_out_sharding_handlers:
-      orig_s = oi
+      orig_in_s = oi
       orig_aval = aval
       break

-  if orig_s is not None:
+  if orig_in_s is not None:
     return zip(*_get_out_sharding_from_orig_sharding(
-        out_shardings, out_avals, orig_s, orig_aval, are_out_shardings_from_xla))
+        out_shardings, out_avals, orig_in_s, orig_aval, are_out_shardings_from_xla))
   return out_shardings, are_out_shardings_from_xla
@@ -2521,9 +2540,8 @@ class UnloadedMeshExecutable:
       assert mesh is None
       device_assignment = da.device_assignment if isinstance(  # type: ignore
          da, _DeviceAssignment) else da
-      _, out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
-          xla_executable, device_assignment,  # type: ignore
-          len(global_in_avals), len(global_out_avals))
+      out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
+          xla_executable, device_assignment, len(global_out_avals))  # type: ignore
       orig_out_shardings = out_shardings
       out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
       for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
@@ -2532,9 +2550,10 @@ class UnloadedMeshExecutable:
           out_shardings.append(xla_s)
          are_out_shardings_from_xla.append(True)
         else:
-          if not op_shardings.are_op_shardings_equal(
-              xla_s._to_xla_hlo_sharding(aval.ndim),  # type: ignore
-              orig._to_xla_hlo_sharding(aval.ndim)):  # type: ignore
+          xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
+          orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
+          if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
+              not sharding_impls.are_mem_kind_of_shardings_equal(xla_s, orig)):  # type: ignore
             raise AssertionError(
                 f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
                 "(User sharding)")
@@ -2823,11 +2842,12 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
 @lru_cache
 def create_mesh_pspec_sharding(
-    mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None
-) -> sharding_impls.NamedSharding:
+    mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,
+    memory_kind: Optional[str] = None) -> sharding_impls.NamedSharding:
   if pspec is None:
     pspec, parsed_pspec = PartitionSpec(), None
-  return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec)
+  return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
+                                      memory_kind=memory_kind)


 def check_device_backend_on_shardings(shardings) -> bool:
@@ -2852,6 +2872,12 @@ def check_gda_or_array_xla_sharding_match(
     if not isinstance(arg, ArrayImpl):
       continue

+    # Raise memory kind mismatch error even if the arg is uncommitted.
+    if not sharding_impls.are_mem_kind_of_shardings_equal(arg.sharding, xs):
+      errors.append(
+          f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
+          f"arg {name} with shape: {arg.aval.str_short()}")
+
     # No need to cache this check since MeshExecutable has a C++ fast path
     # for AOT compiled call.
     if (not check_device_backend_on_shardings([xs]) and


@@ -1741,9 +1741,11 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
                                           args_flat):
   @lru_cache
   def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
-    if not op_shardings.are_op_shardings_equal(
-        in_sharding._to_xla_hlo_sharding(ndim),
-        xmap_sharding._to_xla_hlo_sharding(ndim)):
+    if (not op_shardings.are_op_shardings_equal(
+            in_sharding._to_xla_hlo_sharding(ndim),
+            xmap_sharding._to_xla_hlo_sharding(ndim)) or
+        not sharding_impls.are_mem_kind_of_shardings_equal(
+            in_sharding, xmap_sharding)):
       raise ValueError(
           f"Got an input {arr_flavor} to xmap with different partitioning than "
           "specified in xmap. The partitioning must match. "


@@ -59,7 +59,8 @@ from jax._src.sharding_impls import (
     XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
     AUTO, UNSPECIFIED, UnspecifiedValue,
     ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
-    is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
+    is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding,
+    are_mem_kind_of_shardings_equal)
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import (
     tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
@@ -261,7 +262,6 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
       donate_argnums, tree_util.default_registry,  # type: ignore
       _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
   cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
-  cpp_pjitted_f._fun = fun
   type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
@@ -1092,6 +1092,14 @@ def _resolve_in_shardings(
               'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
               f'Got arg shape: {arg.shape}, arg value: {arg}')
       if not is_unspecified(arg_s):
+        # jax.jit does not allow resharding across different memory kinds even
+        # if the argument is uncommitted. Use jax.device_put for those cases,
+        # either outside or inside jax.jit.
+        if not are_mem_kind_of_shardings_equal(pjit_in_s, arg_s):  # type: ignore
+          raise ValueError(
+              'Memory kinds passed to jax.jit does not match memory kind on the'
+              f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, '  # type: ignore
+              f'arg memory kind: {arg_s.memory_kind} for arg shape: {arg.shape}')  # type: ignore
         if (committed and
             not isinstance(arg_s, PmapSharding) and
             not op_shardings.are_op_shardings_equal(
@@ -1232,8 +1240,9 @@ class SameDeviceAssignmentTuple:
       s = getattr(s, "_original_sharding", s)
       o = getattr(o, "_original_sharding", o)
       if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding):
-        eq.append(op_shardings.are_op_shardings_equal(
-            s._hlo_sharding, o._hlo_sharding))
+        eq.append(
+            op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
+            and are_mem_kind_of_shardings_equal(s, o))
       else:
         eq.append(s == o)
     return all(eq) and self.device_assignment == other.device_assignment
@@ -1944,7 +1953,8 @@ def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int,
                       device_or_backend_set: bool = False) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
-  gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim))
+  gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
+                     memory_kind=s.memory_kind)
   gs._original_sharding = s
   if device_or_backend_set:
     gs._original_sharding._device_backend = device_or_backend_set


@@ -93,6 +93,10 @@ class Sharding:
     """Returns the memory kind of the sharding."""
     raise NotImplementedError('Subclasses should implement this method.')

+  def with_memory_kind(self, kind: str) -> Sharding:
+    """Returns a new Sharding instance with the specified memory kind."""
+    raise NotImplementedError('Subclasses should implement this method')
+
   #############################################################################
   # Default implementations below that all subclasses will inherit.


@@ -22,7 +22,7 @@ import enum
 import functools
 import itertools
 import math
-from typing import Any, NamedTuple, Union, cast
+from typing import Any, NamedTuple, Union, cast, Optional

 from jax._src import mesh as mesh_lib
 from jax._src.op_shardings import (
@@ -109,7 +109,7 @@ class XLACompatibleSharding(sharding.Sharding):
       return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
                                      other._to_xla_hlo_sharding(ndim))
              and self._device_assignment == other._device_assignment and
-             self.memory_kind == other.memory_kind)
+             are_mem_kind_of_shardings_equal(self, other))
     # NotImplementedError is raised by PmapSharding because it can't lower
     # to OpSharding. So if `other` is a PmapSharding, default to a strict
     # equality check.
@@ -154,6 +154,51 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
   return out


+# This is an optimization to get the memory kinds associated with the local
+# devices. This is because in McJAX, checking if the memory kind input by user
+# is correct requires doing `local_devices()[0].memory(inp)` which is expensive
+# because calculating the local devices is expensive. So cache on xc.Client and
+# find all the memories associated only once since the client does not change.
+@functools.lru_cache
+def _mem_kinds(client: xc.Client) -> set[str]:
+  return set(m.kind for m in client.local_devices()[0].addressable_memories())
+
+@functools.lru_cache
+def _default_mem_kind(client: xc.Client) -> str | None:
+  try:
+    return client.local_devices()[0].default_memory().kind
+  except:
+    return None
+
+def _check_mem_kind(device: xc.Device, mk):
+  mem_kinds = _mem_kinds(device.client)
+  if mk not in mem_kinds:
+    raise ValueError(
+        f'Could not find memory addressable by device {device.device_kind}.'
+        f' Device {device.device_kind} can address the following memory kinds:'
+        f' {mem_kinds}. Got memory kind: {mk}')
+
+def get_canonicalized_memory_kind(s: XLACompatibleSharding) -> Optional[str]:
+  # TODO(yashkatariya): Remove try;except when CPU and GPU support memories.
+  try:
+    client = s._device_assignment[0].client
+    return _default_mem_kind(client) if s.memory_kind is None else s.memory_kind
+  except:
+    return None
+
+# TODO(yashkatariya): Remove this when the canonicalization happens in __init__
+# of Shardings. That can be done after OSS support is also added for memories.
+def are_mem_kind_of_shardings_equal(s1: XLACompatibleSharding,
+                                    s2: XLACompatibleSharding) -> bool:
+  if s1.memory_kind is None and s2.memory_kind is None:
+    return True
+  mk1 = get_canonicalized_memory_kind(s1)
+  mk2 = get_canonicalized_memory_kind(s2)
+  return mk1 == mk2
+
+
 @use_cpp_class(xc.NamedSharding)
 class NamedSharding(XLACompatibleSharding):
   r"""A :class:`NamedSharding` expresses sharding using named axes.
@@ -207,16 +252,13 @@ class NamedSharding(XLACompatibleSharding):
     self._preprocess()

   def __reduce__(self):
-    return (
-        type(self),
-        (self.mesh, self.spec),
-        {'memory_kind': self.memory_kind},
-    )
+    return (type(self), (self.mesh, self.spec),
+            {'memory_kind': self.memory_kind})

   def _preprocess(self):
     if self.memory_kind is not None:
       # Will error if memory_kind does not exist on the device.
-      self.mesh.devices.flat[0].memory(self.memory_kind)
+      _check_mem_kind(self.mesh.devices.flat[0], self.memory_kind)

     # This split exists because you can pass `_parsed_pspec` that has been
     # modified from the original. For example: Adding extra dimension to
@@ -239,7 +281,8 @@ class NamedSharding(XLACompatibleSharding):
   def __hash__(self):
     if not hasattr(self, '_hash'):
-      self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
+      self._hash = hash((self.mesh, get_canonicalized_memory_kind(self),
+                         self._parsed_pspec))
     return self._hash

   def __eq__(self, other):
@@ -248,11 +291,11 @@ class NamedSharding(XLACompatibleSharding):
     if id(self) == id(other):
       return True
     parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
-    if (id(self.mesh) == id(other.mesh) and
-        self.memory_kind == other.memory_kind and parsed_pspec_equal):
+    mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
+    if (id(self.mesh) == id(other.mesh) and mem_kind_equal and
+        parsed_pspec_equal):
       return True
-    return (self.mesh == other.mesh and self.memory_kind == other.memory_kind
-            and parsed_pspec_equal)
+    return self.mesh == other.mesh and mem_kind_equal and parsed_pspec_equal

   def is_compatible_aval(self, aval_shape: Shape):
     assert self._parsed_pspec is not None
@@ -267,7 +310,7 @@ class NamedSharding(XLACompatibleSharding):
   @classmethod
   def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
     return cls(mesh, parsed_pspec.get_partition_spec(),
               memory_kind=memory_kind, _parsed_pspec=parsed_pspec)

   @property
   def device_set(self) -> set[Device]:
@@ -300,6 +343,9 @@ class NamedSharding(XLACompatibleSharding):
       num_partitions *= mesh_shape[name]
     return num_partitions == 1

+  def with_memory_kind(self, kind: str) -> NamedSharding:
+    return NamedSharding(self.mesh, self.spec, memory_kind=kind)
+
   def _get_sharding_spec(self, num_dimensions, axis_ctx):
     assert self._parsed_pspec is not None
     array_mapping = get_array_mapping(self._parsed_pspec)
@@ -362,7 +408,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
   def __hash__(self):
     if not hasattr(self, '_hash'):
-      self._hash = hash((self._device, self._memory_kind))
+      self._hash = hash((self._device, get_canonicalized_memory_kind(self)))
     return self._hash

   def __eq__(self, other):
@@ -371,7 +417,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
     if id(self) == id(other):
       return True
     return (self._device == other._device and
-            self._memory_kind == other._memory_kind)
+            are_mem_kind_of_shardings_equal(self, other))

   @property
   def device_set(self) -> set[Device]:
@@ -381,6 +427,9 @@ class SingleDeviceSharding(XLACompatibleSharding):
   def memory_kind(self) -> str | None:
     return self._memory_kind

+  def with_memory_kind(self, kind: str) -> SingleDeviceSharding:
+    return SingleDeviceSharding(self._device, memory_kind=kind)
+
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:  # type: ignore
     return {self._device: (slice(None),) * len(global_shape)}
@@ -493,6 +542,9 @@ class PmapSharding(XLACompatibleSharding):
   def memory_kind(self):
     return None

+  def with_memory_kind(self, kind: str):
+    return NotImplementedError("pmap does not support memories.")
+
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     raise NotImplementedError("pmap doesn't use OpSharding.")
@@ -524,13 +576,15 @@ class PmapSharding(XLACompatibleSharding):
 def _op_sharding_to_pos_sharding(
-    op_sharding: xc.OpSharding | xc.HloSharding,
-    device_assignment: Sequence[xc.Device]) -> PositionalSharding:
+    op_sharding: Union[xc.OpSharding, xc.HloSharding],
+    device_assignment: Sequence[xc.Device],
+    memory_kind: Optional[str] = None) -> PositionalSharding:
   if isinstance(op_sharding, xc.HloSharding):
     op_sharding = op_sharding.to_proto()  # type: ignore

   if op_sharding.type == xc.OpSharding.Type.REPLICATED:
-    return PositionalSharding(device_assignment).replicate()
+    return PositionalSharding(
+        device_assignment, memory_kind=memory_kind).replicate()

   if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
     replicate_on_last_tile_dim = True
@@ -543,7 +597,11 @@ def _op_sharding_to_pos_sharding(
   name = device_assignment[0].platform.upper()
   ids = np.array([DeviceIdSet(name, i)
                   for i in op_sharding.tile_assignment_devices])
-  p = PositionalSharding._remake(tuple(device_assignment), ids)
+  if memory_kind is not None:
+    # Will error if memory_kind does not exist on the device.
+    _check_mem_kind(device_assignment[0], memory_kind)
+  p = PositionalSharding._remake(tuple(device_assignment), ids,
+                                 memory_kind=memory_kind)
   p = p.reshape(op_sharding.tile_assignment_dimensions)
   if replicate_on_last_tile_dim:
     p = p.replicate(-1, keepdims=False)
@@ -569,7 +627,7 @@ class PositionalSharding(XLACompatibleSharding):
                               dtype='object').reshape(devices.shape)
     if self._memory_kind is not None:
       # Will error if memory_kind does not exist on the device.
-      self._devices[0].memory(self._memory_kind)
+      _check_mem_kind(self._devices[0], self._memory_kind)

   @property
   def shape(self):
@@ -615,7 +673,7 @@ class PositionalSharding(XLACompatibleSharding):
   def __hash__(self) -> int:
     if not hasattr(self, '_hash'):
-      self._hash = hash((self._devices, self._memory_kind))
+      self._hash = hash((self._devices, get_canonicalized_memory_kind(self)))
     return self._hash

   def __eq__(self, other) -> bool:
@@ -624,11 +682,11 @@ class PositionalSharding(XLACompatibleSharding):
     if id(self) == id(other):
       return True
     all_ids_equal = np.array_equal(self._ids,other._ids)
-    if (id(self._devices) == id(other._devices) and
-        self._memory_kind == other._memory_kind and all_ids_equal):
+    mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
+    if (id(self._devices) == id(other._devices) and mem_kind_equal and
+        all_ids_equal):
       return True
-    return (self._devices == other._devices and
-            self._memory_kind == other._memory_kind and all_ids_equal)
+    return self._devices == other._devices and mem_kind_equal and all_ids_equal

   # Sharding interface
@@ -640,6 +698,9 @@ class PositionalSharding(XLACompatibleSharding):
   def memory_kind(self) -> str | None:
     return self._memory_kind

+  def with_memory_kind(self, kind: str) -> PositionalSharding:
+    return PositionalSharding(self._devices, memory_kind=kind)
+
   @functools.cached_property
   def is_fully_replicated(self) -> bool:
     return self.shape == (1,) * self.ndim
@@ -717,12 +778,14 @@ class GSPMDSharding(XLACompatibleSharding):
     self._hlo_sharding = op_sharding
     self._memory_kind = memory_kind

+  def _preprocess(self):
+    if self.memory_kind is not None:
+      # Will error if memory_kind does not exist on the device.
+      _check_mem_kind(self._devices[0], self.memory_kind)
+
   def __reduce__(self):
-    return (
-        type(self),
-        (self._devices, self._hlo_sharding.to_proto()),
-        {'memory_kind': self._memory_kind},
-    )
+    return (type(self), (self._devices, self._hlo_sharding.to_proto()),
+            {'memory_kind': self._memory_kind})

   @functools.cached_property
   def _hlo_sharding_hash(self):
@@ -735,12 +798,12 @@ class GSPMDSharding(XLACompatibleSharding):
       return True
     return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
             and self._devices == other._devices and
-            self._memory_kind == other._memory_kind)
+            are_mem_kind_of_shardings_equal(self, other))

   def __hash__(self):
     if not hasattr(self, '_hash'):
       self._hash = hash((self._devices, self._hlo_sharding_hash,
-                         self._memory_kind))
+                         get_canonicalized_memory_kind(self)))
     return self._hash

   def __repr__(self):
@@ -763,6 +826,9 @@ class GSPMDSharding(XLACompatibleSharding):
   def memory_kind(self) -> str | None:
     return self._memory_kind

+  def with_memory_kind(self, kind: str) -> GSPMDSharding:
+    return GSPMDSharding(self._devices, self._hlo_sharding, memory_kind=kind)
+
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
     self.shard_shape(global_shape)  # raises a good error message