[Memories] Add Memories support to jax.jit and jax.device_put!

This change includes the following:

* Add a temporary flag (`JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE`) that controls whether memory kinds are fetched from the executable. It should not be set by users, but it is needed by the C++ pjrt-ifrt code. When it is True, the host runtime dependency needs to be linked in; it should eventually also work in OSS, but more work is needed for that. For now, only the tests set it to True while jax memories is under development.
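
  For context, a minimal sketch of how this flag could be enabled in a test. The environment-variable and config spellings follow the flag definition in pxla.py, but the exact test wiring is an assumption:

  ```python
  import os

  # Must be set before jax (specifically jax._src.interpreters.pxla) is
  # imported, because the default is read from the environment at import time.
  os.environ['JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE'] = '1'

  import jax

  # Alternatively, flip the boolean config option defined in pxla.py at runtime.
  jax.config.update('jax_fetch_memory_kind_on_executable', True)
  ```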

* Add a `with_memory_kind` method on `Sharding` to make it easier to create shardings with a different memory kind.
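
  For example, a hedged sketch of the intended usage; the memory-kind names (`tpu_hbm`, `unpinned_host`) are the TPU ones used elsewhere in this change and assume a backend that exposes them:

  ```python
  import jax
  import numpy as np
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
  s = NamedSharding(mesh, P('x'))               # default device memory
  s_host = s.with_memory_kind('unpinned_host')  # same mesh/spec, host memory

  assert s_host.memory_kind == 'unpinned_host'
  ```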

* Add lowering rules for `device_put` and `jax.jit`.
  * For `device_put`, we always add both the annotation that describes a transfer to a memory and a sharding annotation.
  * For `jax.jit`, if an argument lives in host memory, it gets an extra `_xla_buffer_placement` attribute.
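
  A hedged sketch of what triggers these lowering rules; since the host-aware compiler passes are not enabled by default at the time of this change, this illustrates the intended API rather than something expected to run end to end everywhere:

  ```python
  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
  s_dev = NamedSharding(mesh, P('x'))
  s_host = s_dev.with_memory_kind('unpinned_host')
  x = jnp.arange(8.)

  # device_put with a memory-kind sharding lowers to an
  # `annotate_device_placement` custom call plus a sharding annotation.
  y = jax.device_put(x, s_host)

  # For jax.jit, a host-memory argument additionally gets the
  # `_xla_buffer_placement` frontend attribute; the annotations are visible
  # in the lowered StableHLO.
  f = jax.jit(lambda a: a * 2, in_shardings=s_host, out_shardings=s_dev)
  print(f.lower(x).as_text())
  ```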

* Handle the output shardings correctly in pxla.py by extracting the memory kinds from the executable.
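
  Continuing the sketch above, a hedged illustration of what this enables at the user level (assuming the executable reports memory kinds, i.e. the flag from the first bullet is on):

  ```python
  # The sharding attached to a jit output now carries the memory kind that was
  # extracted from the compiled executable (None if the runtime does not
  # report it).
  out = f(x)
  print(out.sharding.memory_kind)
  ```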

* Keep the pjit caches working by canonicalizing memory kinds so that `NS(mesh, pspec) == NS(mesh, pspec, memory_kind='tpu_hbm')`. Also canonicalize `memory_kind` in the `__hash__` and `__eq__` of shardings.
  * This avoids changing the StableHLO to include device placement annotations for now, since the host-aware passes are not enabled by default and work is in progress to make them work everywhere.
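
  A hedged sketch of the canonicalization contract, assuming a TPU backend whose default memory kind is `tpu_hbm`:

  ```python
  import jax
  import numpy as np
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()[:1]), ('x',))

  # An unspecified memory_kind is canonicalized to the device's default memory
  # kind inside __eq__/__hash__, so both shardings hit the same pjit cache entry.
  s1 = NamedSharding(mesh, P('x'))
  s2 = NamedSharding(mesh, P('x'), memory_kind='tpu_hbm')
  assert s1 == s2 and hash(s1) == hash(s2)
  ```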

PiperOrigin-RevId: 553833344
Author: Yash Katariya, 2023-08-04 09:43:39 -07:00 (committed by jax authors)
Parent: d7940ee9a1
Commit: 4fb8cdb019
7 changed files with 287 additions and 93 deletions


@ -345,6 +345,12 @@ def jaxpr_shardings(
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
elif eqn.primitive is device_put_p:
s = eqn.params['device']
if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None:
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
eqn.primitive.name)
yield (s, source_info)
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_shardings(subjaxpr)
@ -699,5 +705,12 @@ ad.deflinear2(device_put_p, device_put_transpose_rule)
batching.defvectorized(device_put_p)
def _device_put_lowering(ctx, x, *, device, src):
if isinstance(device, XLACompatibleSharding) and device.memory_kind is not None:
aval, = ctx.avals_in
out_aval, = ctx.avals_out
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
x = mlir.wrap_with_sharding_op(
ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
return [x]
return [x]
mlir.register_lowering(device_put_p, _device_put_lowering)


@ -25,7 +25,7 @@ import itertools
import operator
import re
import typing
from typing import (Any, Callable, NamedTuple, Protocol, Union)
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
import warnings
from jax._src import ad_util
@ -643,6 +643,12 @@ def _to_logical_op_sharding(
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
return sharding._to_xla_hlo_sharding(aval.ndim)
def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
if s is None:
return None
assert isinstance(s, sharding_impls.XLACompatibleSharding)
return s.memory_kind
def lower_jaxpr_to_module(
module_name: str,
@ -712,6 +718,11 @@ def lower_jaxpr_to_module(
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)
arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
if arg_shardings is not None else None)
result_memory_kinds = (map(_get_mem_kind, result_shardings)
if result_shardings is not None else None)
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks,
override_lowering_rules=override_lowering_rules,
@ -733,7 +744,9 @@ def lower_jaxpr_to_module(
result_shardings=result_op_shardings,
input_output_aliases=input_output_aliases,
arg_names=arg_names,
result_names=result_names)
result_names=result_names,
arg_memory_kinds=arg_memory_kinds,
result_memory_kinds=result_memory_kinds)
try:
if not ctx.module.operation.verify():
@ -857,6 +870,8 @@ def lower_jaxpr_to_fun(
api_name: str = "jit",
arg_names: Sequence[str | None] | None = None,
result_names: Sequence[str | None] | None = None,
arg_memory_kinds: Sequence[str | None] | None = None,
result_memory_kinds: Sequence[str | None] | None = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@ -917,13 +932,15 @@ def lower_jaxpr_to_fun(
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if input_output_aliases is not None:
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
# Update the existing aliases to account for the new output values
input_output_aliases = [None if a is None
else a + num_output_tokens + num_tokens
for a in input_output_aliases]
for a in input_output_aliases] # type: ignore
if arg_shardings is not None:
token_shardings = [None] * (num_dim_vars + num_tokens)
arg_shardings = [*token_shardings, *arg_shardings]
@ -933,6 +950,13 @@ def lower_jaxpr_to_fun(
if replicated_args is not None:
token_replicated_args = [False] * (num_dim_vars + num_tokens)
replicated_args = [*token_replicated_args, *replicated_args]
if arg_memory_kinds is not None:
token_memory_kinds = [None] * (num_dim_vars + num_tokens)
arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds]
if result_memory_kinds is not None:
token_memory_kinds = [None] * (num_tokens + num_output_tokens)
result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
flat_input_types = util.flatten(input_types)
flat_output_types = util.flatten(output_types)
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
@ -940,6 +964,7 @@ def lower_jaxpr_to_fun(
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
"public" if public else "private")
ctx.symbol_table.insert(func_op)
ir_arg_shardings = None
if arg_shardings is not None:
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
@ -947,6 +972,12 @@ def lower_jaxpr_to_fun(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(in_avals, arg_shardings, input_types)])
del in_avals
ir_arg_memory_kinds = None
if arg_memory_kinds is not None:
ir_arg_memory_kinds = util.flatten(
[[mk] * len(types) for mk, types in zip(arg_memory_kinds, input_types)])
ir_result_shardings = None
if result_shardings is not None:
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
@ -955,6 +986,11 @@ def lower_jaxpr_to_fun(
for a, s, types in zip(out_avals, result_shardings, output_types)])
del out_avals
ir_result_memory_kinds = None
if result_memory_kinds is not None:
ir_result_memory_kinds = util.flatten(
[[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)])
if (
replicated_args is not None
or ir_arg_shardings is not None
@ -1043,7 +1079,13 @@ def lower_jaxpr_to_fun(
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)),
if ir_arg_memory_kinds is not None:
flat_args = [
a if mk is None else wrap_with_memory_kind(a, mk, a_aval, is_input=True)
for a, mk, a_aval in zip(flat_args, ir_arg_memory_kinds, input_avals)]
_, token_args, unflattened_args = util.split_list(
util.unflatten(flat_args, map(len, input_types)),
[num_dim_vars, num_tokens])
if create_tokens:
tokens_in = TokenSet.create(effects)
@ -1079,11 +1121,42 @@ def lower_jaxpr_to_fun(
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]
if ir_result_memory_kinds is not None:
flat_outputs = [
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
func_dialect.ReturnOp(flat_outputs)
return func_op
def get_compute_type(memory_kind: str) -> str:
if memory_kind == 'tpu_hbm':
return 'dense'
elif memory_kind == 'unpinned_host':
return 'host'
raise ValueError(f'Unknown memory_kind: {memory_kind}')
def wrap_with_memory_kind(
x: ir.Value, memory_kind: str, aval_out: core.AbstractValue,
is_input: bool = False) -> ir.Value:
if aval_out is None:
result_type = x.type
else:
result_type = aval_to_ir_type(aval_out)
op = custom_call("annotate_device_placement", [result_type], [x],
has_side_effect=False,
api_version=1)
mka = get_compute_type(memory_kind)
dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)}
if is_input and mka == 'host':
dict_attr.update({"_xla_buffer_placement": ir.StringAttr.get("arg")})
op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
return op.result
def _to_physical_op_sharding(
aval: core.AbstractValue | None, sharding: xc.HloSharding | None
) -> xc.OpSharding | None:


@ -24,6 +24,7 @@ from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import os
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
import numpy as np
@ -37,6 +38,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import config as jax_config
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
@ -69,6 +71,13 @@ from jax._src.util import (safe_map, safe_zip, partition_list,
wrap_name, tuple_delete, distributed_debug_log,
unzip2, HashableFunction, weakref_lru_cache)
# TODO(yashkatariya): Remove this flag after the host runtime is linked by
# default and works on cloud TPU.
_FETCH_MEMORY_KIND_ON_EXECUTABLE = jax_config.DEFINE_bool(
'jax_fetch_memory_kind_on_executable',
bool(os.getenv('JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE', '')),
help=("If True, will allow fetching memory kinds available on executable "
"and annotate Shardings with it."))
# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
@ -1689,10 +1698,14 @@ class SemanticallyEqualShardings:
def __eq__(self, other):
if not isinstance(other, SemanticallyEqualShardings):
return False
return all(op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
if (isinstance(s, sharding_impls.GSPMDSharding) and
isinstance(o, sharding_impls.GSPMDSharding))
else s == o for s, o in zip(self.shardings, other.shardings))
return all(
(op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
and sharding_impls.are_mem_kind_of_shardings_equal(s, o))
if (isinstance(s, sharding_impls.GSPMDSharding) and
isinstance(o, sharding_impls.GSPMDSharding))
else s == o
for s, o in zip(self.shardings, other.shardings)
)
@weakref_lru_cache
@ -2208,34 +2221,40 @@ def _get_input_indices(
def get_gspmd_shardings_from_executable(
xla_executable, device_assignment: Sequence[xc.Device],
num_in_avals: int, num_out_avals: int
) -> tuple[Sequence[sharding_impls.XLACompatibleSharding],
Sequence[sharding_impls.XLACompatibleSharding]]:
num_out_avals: int
) -> Sequence[sharding_impls.XLACompatibleSharding]:
from jax._src import pjit
if _FETCH_MEMORY_KIND_ON_EXECUTABLE.value:
try:
omk = xla_executable.get_output_memory_kinds()[0]
except:
omk = [None] * num_out_avals
else:
omk = [None] * num_out_avals
# When the device assignment only has 1 device, SPMD partitioner will not run.
# Hence the op shardings will not be set on the `hlo_module`. In that case,
# just return SingleDeviceShardings since we know the computation is running
# only on 1 device.
if len(device_assignment) == 1:
ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
return [ss] * num_in_avals, [ss] * num_out_avals
return [sharding_impls.SingleDeviceSharding(device_assignment[0], memory_kind=mk)
for mk in omk]
in_op_shardings, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
_, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
in_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, i)
for i in in_op_shardings]
out_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, o)
for o in out_op_shardings]
# This condition happens when all the elements in the output tuple have the
# same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
# put the sharding on ROOT instead of the tuple.
# TODO(b/245667823): Remove this when XLA fixes this.
if len(out_shardings_xla) == 1 and len(out_shardings_xla) < num_out_avals:
out_shardings_xla = out_shardings_xla * num_out_avals
assert len(out_shardings_xla) == num_out_avals, (
len(out_shardings_xla), num_out_avals)
return in_shardings_xla, out_shardings_xla
if len(out_op_shardings) == 1 and len(out_op_shardings) < num_out_avals:
out_op_shardings = out_op_shardings * num_out_avals # type: ignore
assert len(out_op_shardings) == num_out_avals == len(omk), (
len(out_op_shardings), num_out_avals, len(omk))
return [sharding_impls.GSPMDSharding(device_assignment, os, memory_kind=mk)
for os, mk in safe_zip(out_op_shardings, omk)]
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
@ -2258,39 +2277,38 @@ _ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding)
def _register_out_sharding_handler(
sharding_cls: type[_ShardingT],
handler: Callable[[xc.OpSharding, _ShardingT], _ShardingT],
handler: Callable[[sharding_impls.GSPMDSharding, _ShardingT], _ShardingT],
) -> None:
_orig_out_sharding_handlers[sharding_cls] = handler
def _gspmd_to_named_sharding(
op_sharding: xc.OpSharding,
self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
out_s: sharding_impls.GSPMDSharding,
orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
parsed_pspec = sharding_impls.parse_flatten_op_sharding(
op_sharding, self.mesh)[0]
out_s._hlo_sharding, orig_in_s.mesh)[0]
return create_mesh_pspec_sharding(
self.mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
orig_in_s.mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
out_s.memory_kind)
_register_out_sharding_handler(
sharding_impls.NamedSharding, _gspmd_to_named_sharding
)
sharding_impls.NamedSharding, _gspmd_to_named_sharding)
def _gspmd_to_positional_sharding(
op_sharding: xc.OpSharding,
self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
out_s: sharding_impls.GSPMDSharding,
orig_in_s: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
return sharding_impls._op_sharding_to_pos_sharding(
op_sharding, self._device_assignment)
out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)
_register_out_sharding_handler(
sharding_impls.PositionalSharding, _gspmd_to_positional_sharding
)
sharding_impls.PositionalSharding, _gspmd_to_positional_sharding)
def _get_out_sharding_from_orig_sharding(
out_shardings, out_avals, orig_s, orig_aval, are_out_sharding_from_xla):
out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla):
out = []
orig_handler = _orig_out_sharding_handlers[type(orig_s)]
orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
are_out_sharding_from_xla):
if isinstance(o, sharding_impls.GSPMDSharding):
@ -2302,10 +2320,11 @@ def _get_out_sharding_from_orig_sharding(
if (orig_aval is not None and out_aval is not None and
out_aval.ndim == orig_aval.ndim and
sharding_impls.are_op_shardings_equal(
o._hlo_sharding, orig_s._to_xla_hlo_sharding(orig_aval.ndim))):
out.append((orig_s, False))
o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) and
sharding_impls.are_mem_kind_of_shardings_equal(o, orig_in_s)):
out.append((orig_in_s, False))
else:
out.append((orig_handler(o._hlo_sharding, orig_s), False))
out.append((orig_handler(o, orig_in_s), False))
except:
out.append((o, from_xla))
else:
@ -2319,17 +2338,17 @@ def maybe_get_orig_out_sharding(
return ([o._original_sharding for o in out_shardings],
(False,) * len(out_shardings))
orig_s = None
orig_in_s = None
orig_aval = None
for i, aval in safe_zip(in_shardings, in_avals):
oi = getattr(i, '_original_sharding', None)
if type(oi) in _orig_out_sharding_handlers:
orig_s = oi
orig_in_s = oi
orig_aval = aval
break
if orig_s is not None:
if orig_in_s is not None:
return zip(*_get_out_sharding_from_orig_sharding(
out_shardings, out_avals, orig_s, orig_aval, are_out_shardings_from_xla))
out_shardings, out_avals, orig_in_s, orig_aval, are_out_shardings_from_xla))
return out_shardings, are_out_shardings_from_xla
@ -2521,9 +2540,8 @@ class UnloadedMeshExecutable:
assert mesh is None
device_assignment = da.device_assignment if isinstance( # type: ignore
da, _DeviceAssignment) else da
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment, # type: ignore
len(global_in_avals), len(global_out_avals))
out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment, len(global_out_avals)) # type: ignore
orig_out_shardings = out_shardings
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
@ -2532,9 +2550,10 @@ class UnloadedMeshExecutable:
out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:
if not op_shardings.are_op_shardings_equal(
xla_s._to_xla_hlo_sharding(aval.ndim), # type: ignore
orig._to_xla_hlo_sharding(aval.ndim)): # type: ignore
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
not sharding_impls.are_mem_kind_of_shardings_equal(xla_s, orig)): # type: ignore
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
@ -2823,11 +2842,12 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
@lru_cache
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None
) -> sharding_impls.NamedSharding:
mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,
memory_kind: Optional[str] = None) -> sharding_impls.NamedSharding:
if pspec is None:
pspec, parsed_pspec = PartitionSpec(), None
return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec)
return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
memory_kind=memory_kind)
def check_device_backend_on_shardings(shardings) -> bool:
@ -2852,6 +2872,12 @@ def check_gda_or_array_xla_sharding_match(
if not isinstance(arg, ArrayImpl):
continue
# Raise memory kind mismatch error even if the arg is uncommitted.
if not sharding_impls.are_mem_kind_of_shardings_equal(arg.sharding, xs):
errors.append(
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
f"arg {name} with shape: {arg.aval.str_short()}")
# No need to cache this check since MeshExecutable has a C++ fast path
# for AOT compiled call.
if (not check_device_backend_on_shardings([xs]) and


@ -1741,9 +1741,11 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
args_flat):
@lru_cache
def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
if not op_shardings.are_op_shardings_equal(
if (not op_shardings.are_op_shardings_equal(
in_sharding._to_xla_hlo_sharding(ndim),
xmap_sharding._to_xla_hlo_sharding(ndim)):
xmap_sharding._to_xla_hlo_sharding(ndim)) or
not sharding_impls.are_mem_kind_of_shardings_equal(
in_sharding, xmap_sharding)):
raise ValueError(
f"Got an input {arr_flavor} to xmap with different partitioning than "
"specified in xmap. The partitioning must match. "


@ -59,7 +59,8 @@ from jax._src.sharding_impls import (
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding,
are_mem_kind_of_shardings_equal)
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
@ -261,7 +262,6 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
donate_argnums, tree_util.default_registry, # type: ignore
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
@ -1092,6 +1092,14 @@ def _resolve_in_shardings(
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
f'Got arg shape: {arg.shape}, arg value: {arg}')
if not is_unspecified(arg_s):
# jax.jit does not allow resharding across different memory kinds even
# if the argument is uncommitted. Use jax.device_put for those cases,
# either outside or inside jax.jit.
if not are_mem_kind_of_shardings_equal(pjit_in_s, arg_s): # type: ignore
raise ValueError(
'Memory kinds passed to jax.jit does not match memory kind on the'
f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore
f'arg memory kind: {arg_s.memory_kind} for arg shape: {arg.shape}') # type: ignore
if (committed and
not isinstance(arg_s, PmapSharding) and
not op_shardings.are_op_shardings_equal(
@ -1232,8 +1240,9 @@ class SameDeviceAssignmentTuple:
s = getattr(s, "_original_sharding", s)
o = getattr(o, "_original_sharding", o)
if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding):
eq.append(op_shardings.are_op_shardings_equal(
s._hlo_sharding, o._hlo_sharding))
eq.append(
op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
and are_mem_kind_of_shardings_equal(s, o))
else:
eq.append(s == o)
return all(eq) and self.device_assignment == other.device_assignment
@ -1944,7 +1953,8 @@ def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int,
device_or_backend_set: bool = False) -> GSPMDSharding:
if isinstance(s, GSPMDSharding):
return s
gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim))
gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
memory_kind=s.memory_kind)
gs._original_sharding = s
if device_or_backend_set:
gs._original_sharding._device_backend = device_or_backend_set


@ -93,6 +93,10 @@ class Sharding:
"""Returns the memory kind of the sharding."""
raise NotImplementedError('Subclasses should implement this method.')
def with_memory_kind(self, kind: str) -> Sharding:
"""Returns a new Sharding instance with the specified memory kind."""
raise NotImplementedError('Subclasses should implement this method')
#############################################################################
# Default implementations below that all subclasses will inherit.


@ -22,7 +22,7 @@ import enum
import functools
import itertools
import math
from typing import Any, NamedTuple, Union, cast
from typing import Any, NamedTuple, Union, cast, Optional
from jax._src import mesh as mesh_lib
from jax._src.op_shardings import (
@ -109,7 +109,7 @@ class XLACompatibleSharding(sharding.Sharding):
return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
other._to_xla_hlo_sharding(ndim))
and self._device_assignment == other._device_assignment and
self.memory_kind == other.memory_kind)
are_mem_kind_of_shardings_equal(self, other))
# NotImplementedError is raised by PmapSharding because it can't lower
# to OpSharding. So if `other` is a PmapSharding, default to a strict
# equality check.
@ -154,6 +154,51 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
return out
# This is an optimization to get the memory kinds associated with the local
# devices. This is because in McJAX, checking if the memory kind input by user
# is correct requires doing `local_devices()[0].memory(inp)` which is expensive
# because calculating the local devices is expensive. So cache on xc.Client and
# find all the memories associated only once since the client does not change.
@functools.lru_cache
def _mem_kinds(client: xc.Client) -> set[str]:
return set(m.kind for m in client.local_devices()[0].addressable_memories())
@functools.lru_cache
def _default_mem_kind(client: xc.Client) -> str | None:
try:
return client.local_devices()[0].default_memory().kind
except:
return None
def _check_mem_kind(device: xc.Device, mk):
mem_kinds = _mem_kinds(device.client)
if mk not in mem_kinds:
raise ValueError(
f'Could not find memory addressable by device {device.device_kind}.'
f' Device {device.device_kind} can address the following memory kinds:'
f' {mem_kinds}. Got memory kind: {mk}')
def get_canonicalized_memory_kind(s: XLACompatibleSharding) -> Optional[str]:
# TODO(yashkatariya): Remove try;except when CPU and GPU support memories.
try:
client = s._device_assignment[0].client
return _default_mem_kind(client) if s.memory_kind is None else s.memory_kind
except:
return None
# TODO(yashkatariya): Remove this when the canonicalization happens in __init__
# of Shardings. That can be done after OSS support is also added for memories.
def are_mem_kind_of_shardings_equal(s1: XLACompatibleSharding,
s2: XLACompatibleSharding) -> bool:
if s1.memory_kind is None and s2.memory_kind is None:
return True
mk1 = get_canonicalized_memory_kind(s1)
mk2 = get_canonicalized_memory_kind(s2)
return mk1 == mk2
@use_cpp_class(xc.NamedSharding)
class NamedSharding(XLACompatibleSharding):
r"""A :class:`NamedSharding` expresses sharding using named axes.
@ -207,16 +252,13 @@ class NamedSharding(XLACompatibleSharding):
self._preprocess()
def __reduce__(self):
return (
type(self),
(self.mesh, self.spec),
{'memory_kind': self.memory_kind},
)
return (type(self), (self.mesh, self.spec),
{'memory_kind': self.memory_kind})
def _preprocess(self):
if self.memory_kind is not None:
# Will error if memory_kind does not exist on the device.
self.mesh.devices.flat[0].memory(self.memory_kind)
_check_mem_kind(self.mesh.devices.flat[0], self.memory_kind)
# This split exists because you can pass `_parsed_pspec` that has been
# modified from the original. For example: Adding extra dimension to
@ -239,7 +281,8 @@ class NamedSharding(XLACompatibleSharding):
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
self._hash = hash((self.mesh, get_canonicalized_memory_kind(self),
self._parsed_pspec))
return self._hash
def __eq__(self, other):
@ -248,11 +291,11 @@ class NamedSharding(XLACompatibleSharding):
if id(self) == id(other):
return True
parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
if (id(self.mesh) == id(other.mesh) and
self.memory_kind == other.memory_kind and parsed_pspec_equal):
mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
if (id(self.mesh) == id(other.mesh) and mem_kind_equal and
parsed_pspec_equal):
return True
return (self.mesh == other.mesh and self.memory_kind == other.memory_kind
and parsed_pspec_equal)
return self.mesh == other.mesh and mem_kind_equal and parsed_pspec_equal
def is_compatible_aval(self, aval_shape: Shape):
assert self._parsed_pspec is not None
@ -267,7 +310,7 @@ class NamedSharding(XLACompatibleSharding):
@classmethod
def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
return cls(mesh, parsed_pspec.get_partition_spec(),
memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
@property
def device_set(self) -> set[Device]:
@ -300,6 +343,9 @@ class NamedSharding(XLACompatibleSharding):
num_partitions *= mesh_shape[name]
return num_partitions == 1
def with_memory_kind(self, kind: str) -> NamedSharding:
return NamedSharding(self.mesh, self.spec, memory_kind=kind)
def _get_sharding_spec(self, num_dimensions, axis_ctx):
assert self._parsed_pspec is not None
array_mapping = get_array_mapping(self._parsed_pspec)
@ -362,7 +408,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self._device, self._memory_kind))
self._hash = hash((self._device, get_canonicalized_memory_kind(self)))
return self._hash
def __eq__(self, other):
@ -371,7 +417,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
if id(self) == id(other):
return True
return (self._device == other._device and
self._memory_kind == other._memory_kind)
are_mem_kind_of_shardings_equal(self, other))
@property
def device_set(self) -> set[Device]:
@ -381,6 +427,9 @@ class SingleDeviceSharding(XLACompatibleSharding):
def memory_kind(self) -> str | None:
return self._memory_kind
def with_memory_kind(self, kind: str) -> SingleDeviceSharding:
return SingleDeviceSharding(self._device, memory_kind=kind)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: # type: ignore
return {self._device: (slice(None),) * len(global_shape)}
@ -493,6 +542,9 @@ class PmapSharding(XLACompatibleSharding):
def memory_kind(self):
return None
def with_memory_kind(self, kind: str):
return NotImplementedError("pmap does not support memories.")
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")
@ -524,13 +576,15 @@ class PmapSharding(XLACompatibleSharding):
def _op_sharding_to_pos_sharding(
op_sharding: xc.OpSharding | xc.HloSharding,
device_assignment: Sequence[xc.Device]) -> PositionalSharding:
op_sharding: Union[xc.OpSharding, xc.HloSharding],
device_assignment: Sequence[xc.Device],
memory_kind: Optional[str] = None) -> PositionalSharding:
if isinstance(op_sharding, xc.HloSharding):
op_sharding = op_sharding.to_proto() # type: ignore
if op_sharding.type == xc.OpSharding.Type.REPLICATED:
return PositionalSharding(device_assignment).replicate()
return PositionalSharding(
device_assignment, memory_kind=memory_kind).replicate()
if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
replicate_on_last_tile_dim = True
@ -543,7 +597,11 @@ def _op_sharding_to_pos_sharding(
name = device_assignment[0].platform.upper()
ids = np.array([DeviceIdSet(name, i)
for i in op_sharding.tile_assignment_devices])
p = PositionalSharding._remake(tuple(device_assignment), ids)
if memory_kind is not None:
# Will error if memory_kind does not exist on the device.
_check_mem_kind(device_assignment[0], memory_kind)
p = PositionalSharding._remake(tuple(device_assignment), ids,
memory_kind=memory_kind)
p = p.reshape(op_sharding.tile_assignment_dimensions)
if replicate_on_last_tile_dim:
p = p.replicate(-1, keepdims=False)
@ -569,7 +627,7 @@ class PositionalSharding(XLACompatibleSharding):
dtype='object').reshape(devices.shape)
if self._memory_kind is not None:
# Will error if memory_kind does not exist on the device.
self._devices[0].memory(self._memory_kind)
_check_mem_kind(self._devices[0], self._memory_kind)
@property
def shape(self):
@ -615,7 +673,7 @@ class PositionalSharding(XLACompatibleSharding):
def __hash__(self) -> int:
if not hasattr(self, '_hash'):
self._hash = hash((self._devices, self._memory_kind))
self._hash = hash((self._devices, get_canonicalized_memory_kind(self)))
return self._hash
def __eq__(self, other) -> bool:
@ -624,11 +682,11 @@ class PositionalSharding(XLACompatibleSharding):
if id(self) == id(other):
return True
all_ids_equal = np.array_equal(self._ids,other._ids)
if (id(self._devices) == id(other._devices) and
self._memory_kind == other._memory_kind and all_ids_equal):
mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
if (id(self._devices) == id(other._devices) and mem_kind_equal and
all_ids_equal):
return True
return (self._devices == other._devices and
self._memory_kind == other._memory_kind and all_ids_equal)
return self._devices == other._devices and mem_kind_equal and all_ids_equal
# Sharding interface
@ -640,6 +698,9 @@ class PositionalSharding(XLACompatibleSharding):
def memory_kind(self) -> str | None:
return self._memory_kind
def with_memory_kind(self, kind: str) -> PositionalSharding:
return PositionalSharding(self._devices, memory_kind=kind)
@functools.cached_property
def is_fully_replicated(self) -> bool:
return self.shape == (1,) * self.ndim
@ -717,12 +778,14 @@ class GSPMDSharding(XLACompatibleSharding):
self._hlo_sharding = op_sharding
self._memory_kind = memory_kind
def _preprocess(self):
if self.memory_kind is not None:
# Will error if memory_kind does not exist on the device.
_check_mem_kind(self._devices[0], self.memory_kind)
def __reduce__(self):
return (
type(self),
(self._devices, self._hlo_sharding.to_proto()),
{'memory_kind': self._memory_kind},
)
return (type(self), (self._devices, self._hlo_sharding.to_proto()),
{'memory_kind': self._memory_kind})
@functools.cached_property
def _hlo_sharding_hash(self):
@ -735,12 +798,12 @@ class GSPMDSharding(XLACompatibleSharding):
return True
return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
and self._devices == other._devices and
self._memory_kind == other._memory_kind)
are_mem_kind_of_shardings_equal(self, other))
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self._devices, self._hlo_sharding_hash,
self._memory_kind))
get_canonicalized_memory_kind(self)))
return self._hash
def __repr__(self):
@ -763,6 +826,9 @@ class GSPMDSharding(XLACompatibleSharding):
def memory_kind(self) -> str | None:
return self._memory_kind
def with_memory_kind(self, kind: str) -> GSPMDSharding:
return GSPMDSharding(self._devices, self._hlo_sharding, memory_kind=kind)
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
self.shard_shape(global_shape) # raises a good error message