mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Memories] Add Memories support to jax.jit and jax.device_put!
The changes are as follows: * Add a temporary flag (`JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE`) that controls whether to fetch memory kinds from the executable (it should not be used by users, but it is needed in C++ in the pjrt-ifrt code). If it is set to True, the host runtime dep needs to be linked in and should also work in OSS (more work needs to happen for that). So only the test sets it to True for now, while JAX memories is under development. * Add a `with_memory_kind` method on Sharding to allow for easier creation of shardings with a different memory kind. * Add lowering rules for device_put and jax.jit. * For device_put, we always add the annotation that describes a transfer to a memory, plus a sharding annotation. * For jax.jit, if the argument is on host memory, it will have an extra attribute `_xla_buffer_placement`. * Handle the correct output sharding in pxla.py by extracting the memory kind from the executable. * Handle pjit caching by canonicalizing the memory_kinds so that `NS(mesh, pspec) == NS(mesh, pspec, memory_kind='tpu_hbm')`. Also canonicalize memory_kind in `__hash__` and `__eq__` of shardings. * This is to avoid changing the StableHLO to include device placement annotations right now, since the host-aware passes are not enabled by default and work is in progress to make them work everywhere. PiperOrigin-RevId: 553833344
This commit is contained in:
parent
d7940ee9a1
commit
4fb8cdb019
@ -345,6 +345,12 @@ def jaxpr_shardings(
|
|||||||
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
|
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
|
||||||
yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
|
yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
|
||||||
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
|
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
|
||||||
|
elif eqn.primitive is device_put_p:
|
||||||
|
s = eqn.params['device']
|
||||||
|
if isinstance(s, XLACompatibleSharding) and s.memory_kind is not None:
|
||||||
|
source_info = SourceInfo(source_info_util.summarize(eqn.source_info),
|
||||||
|
eqn.primitive.name)
|
||||||
|
yield (s, source_info)
|
||||||
for subjaxpr in core.subjaxprs(jaxpr):
|
for subjaxpr in core.subjaxprs(jaxpr):
|
||||||
yield from jaxpr_shardings(subjaxpr)
|
yield from jaxpr_shardings(subjaxpr)
|
||||||
|
|
||||||
@ -699,5 +705,12 @@ ad.deflinear2(device_put_p, device_put_transpose_rule)
|
|||||||
batching.defvectorized(device_put_p)
|
batching.defvectorized(device_put_p)
|
||||||
|
|
||||||
def _device_put_lowering(ctx, x, *, device, src):
  """MLIR lowering rule for `device_put_p`.

  When `device` is an `XLACompatibleSharding` whose `memory_kind` is not None,
  the operand is wrapped first with a memory-kind annotation and then with a
  sharding annotation derived from `device`. In every other case the operand
  is passed through untouched.

  Args:
    ctx: lowering rule context; `avals_in` and `avals_out` are each unpacked
      as exactly one abstract value.
    x: the IR value being lowered.
    device: the `device` parameter of the `device_put_p` equation.
    src: not read by this rule.

  Returns:
    A one-element list holding the (possibly annotated) IR value.
  """
  # Guard clause: nothing to annotate unless a memory kind was requested.
  if not (isinstance(device, XLACompatibleSharding)
          and device.memory_kind is not None):
    return [x]

  aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  # The memory-kind annotation is applied before the sharding annotation.
  x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
  x = mlir.wrap_with_sharding_op(
      ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
  return [x]
|
||||||
mlir.register_lowering(device_put_p, _device_put_lowering)
|
mlir.register_lowering(device_put_p, _device_put_lowering)
|
||||||
|
@ -25,7 +25,7 @@ import itertools
|
|||||||
import operator
|
import operator
|
||||||
import re
|
import re
|
||||||
import typing
|
import typing
|
||||||
from typing import (Any, Callable, NamedTuple, Protocol, Union)
|
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from jax._src import ad_util
|
from jax._src import ad_util
|
||||||
@ -643,6 +643,12 @@ def _to_logical_op_sharding(
|
|||||||
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
|
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
|
||||||
return sharding._to_xla_hlo_sharding(aval.ndim)
|
return sharding._to_xla_hlo_sharding(aval.ndim)
|
||||||
|
|
||||||
|
def _get_mem_kind(s: Optional[XLACompatibleSharding]) -> Optional[str]:
|
||||||
|
if s is None:
|
||||||
|
return None
|
||||||
|
assert isinstance(s, sharding_impls.XLACompatibleSharding)
|
||||||
|
return s.memory_kind
|
||||||
|
|
||||||
|
|
||||||
def lower_jaxpr_to_module(
|
def lower_jaxpr_to_module(
|
||||||
module_name: str,
|
module_name: str,
|
||||||
@ -712,6 +718,11 @@ def lower_jaxpr_to_module(
|
|||||||
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
|
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
|
||||||
if result_shardings is not None else result_shardings)
|
if result_shardings is not None else result_shardings)
|
||||||
|
|
||||||
|
arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
|
||||||
|
if arg_shardings is not None else None)
|
||||||
|
result_memory_kinds = (map(_get_mem_kind, result_shardings)
|
||||||
|
if result_shardings is not None else None)
|
||||||
|
|
||||||
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
|
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
|
||||||
keepalives, channel_iter, host_callbacks,
|
keepalives, channel_iter, host_callbacks,
|
||||||
override_lowering_rules=override_lowering_rules,
|
override_lowering_rules=override_lowering_rules,
|
||||||
@ -733,7 +744,9 @@ def lower_jaxpr_to_module(
|
|||||||
result_shardings=result_op_shardings,
|
result_shardings=result_op_shardings,
|
||||||
input_output_aliases=input_output_aliases,
|
input_output_aliases=input_output_aliases,
|
||||||
arg_names=arg_names,
|
arg_names=arg_names,
|
||||||
result_names=result_names)
|
result_names=result_names,
|
||||||
|
arg_memory_kinds=arg_memory_kinds,
|
||||||
|
result_memory_kinds=result_memory_kinds)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not ctx.module.operation.verify():
|
if not ctx.module.operation.verify():
|
||||||
@ -857,6 +870,8 @@ def lower_jaxpr_to_fun(
|
|||||||
api_name: str = "jit",
|
api_name: str = "jit",
|
||||||
arg_names: Sequence[str | None] | None = None,
|
arg_names: Sequence[str | None] | None = None,
|
||||||
result_names: Sequence[str | None] | None = None,
|
result_names: Sequence[str | None] | None = None,
|
||||||
|
arg_memory_kinds: Sequence[str | None] | None = None,
|
||||||
|
result_memory_kinds: Sequence[str | None] | None = None,
|
||||||
) -> func_dialect.FuncOp:
|
) -> func_dialect.FuncOp:
|
||||||
"""Lowers jaxpr and its callees to an IR function.
|
"""Lowers jaxpr and its callees to an IR function.
|
||||||
|
|
||||||
@ -917,13 +932,15 @@ def lower_jaxpr_to_fun(
|
|||||||
input_types = [*dim_var_types, *token_types, *input_types]
|
input_types = [*dim_var_types, *token_types, *input_types]
|
||||||
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
|
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
|
||||||
output_types = [*output_token_types, *token_types, *output_types]
|
output_types = [*output_token_types, *token_types, *output_types]
|
||||||
|
|
||||||
if input_output_aliases is not None:
|
if input_output_aliases is not None:
|
||||||
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
|
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
|
||||||
input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
|
input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
|
||||||
# Update the existing aliases to account for the new output values
|
# Update the existing aliases to account for the new output values
|
||||||
input_output_aliases = [None if a is None
|
input_output_aliases = [None if a is None
|
||||||
else a + num_output_tokens + num_tokens
|
else a + num_output_tokens + num_tokens
|
||||||
for a in input_output_aliases]
|
for a in input_output_aliases] # type: ignore
|
||||||
|
|
||||||
if arg_shardings is not None:
|
if arg_shardings is not None:
|
||||||
token_shardings = [None] * (num_dim_vars + num_tokens)
|
token_shardings = [None] * (num_dim_vars + num_tokens)
|
||||||
arg_shardings = [*token_shardings, *arg_shardings]
|
arg_shardings = [*token_shardings, *arg_shardings]
|
||||||
@ -933,6 +950,13 @@ def lower_jaxpr_to_fun(
|
|||||||
if replicated_args is not None:
|
if replicated_args is not None:
|
||||||
token_replicated_args = [False] * (num_dim_vars + num_tokens)
|
token_replicated_args = [False] * (num_dim_vars + num_tokens)
|
||||||
replicated_args = [*token_replicated_args, *replicated_args]
|
replicated_args = [*token_replicated_args, *replicated_args]
|
||||||
|
if arg_memory_kinds is not None:
|
||||||
|
token_memory_kinds = [None] * (num_dim_vars + num_tokens)
|
||||||
|
arg_memory_kinds = [*token_memory_kinds, *arg_memory_kinds]
|
||||||
|
if result_memory_kinds is not None:
|
||||||
|
token_memory_kinds = [None] * (num_tokens + num_output_tokens)
|
||||||
|
result_memory_kinds = [*token_memory_kinds, *result_memory_kinds]
|
||||||
|
|
||||||
flat_input_types = util.flatten(input_types)
|
flat_input_types = util.flatten(input_types)
|
||||||
flat_output_types = util.flatten(output_types)
|
flat_output_types = util.flatten(output_types)
|
||||||
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
|
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
|
||||||
@ -940,6 +964,7 @@ def lower_jaxpr_to_fun(
|
|||||||
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
|
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
|
||||||
"public" if public else "private")
|
"public" if public else "private")
|
||||||
ctx.symbol_table.insert(func_op)
|
ctx.symbol_table.insert(func_op)
|
||||||
|
|
||||||
ir_arg_shardings = None
|
ir_arg_shardings = None
|
||||||
if arg_shardings is not None:
|
if arg_shardings is not None:
|
||||||
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
|
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
|
||||||
@ -947,6 +972,12 @@ def lower_jaxpr_to_fun(
|
|||||||
[[_to_physical_op_sharding(a, s)] * len(types)
|
[[_to_physical_op_sharding(a, s)] * len(types)
|
||||||
for a, s, types in zip(in_avals, arg_shardings, input_types)])
|
for a, s, types in zip(in_avals, arg_shardings, input_types)])
|
||||||
del in_avals
|
del in_avals
|
||||||
|
|
||||||
|
ir_arg_memory_kinds = None
|
||||||
|
if arg_memory_kinds is not None:
|
||||||
|
ir_arg_memory_kinds = util.flatten(
|
||||||
|
[[mk] * len(types) for mk, types in zip(arg_memory_kinds, input_types)])
|
||||||
|
|
||||||
ir_result_shardings = None
|
ir_result_shardings = None
|
||||||
if result_shardings is not None:
|
if result_shardings is not None:
|
||||||
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
|
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
|
||||||
@ -955,6 +986,11 @@ def lower_jaxpr_to_fun(
|
|||||||
for a, s, types in zip(out_avals, result_shardings, output_types)])
|
for a, s, types in zip(out_avals, result_shardings, output_types)])
|
||||||
del out_avals
|
del out_avals
|
||||||
|
|
||||||
|
ir_result_memory_kinds = None
|
||||||
|
if result_memory_kinds is not None:
|
||||||
|
ir_result_memory_kinds = util.flatten(
|
||||||
|
[[mk] * len(types) for mk, types in zip(result_memory_kinds, output_types)])
|
||||||
|
|
||||||
if (
|
if (
|
||||||
replicated_args is not None
|
replicated_args is not None
|
||||||
or ir_arg_shardings is not None
|
or ir_arg_shardings is not None
|
||||||
@ -1043,7 +1079,13 @@ def lower_jaxpr_to_fun(
|
|||||||
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
|
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
|
||||||
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
|
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
|
||||||
|
|
||||||
_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)),
|
if ir_arg_memory_kinds is not None:
|
||||||
|
flat_args = [
|
||||||
|
a if mk is None else wrap_with_memory_kind(a, mk, a_aval, is_input=True)
|
||||||
|
for a, mk, a_aval in zip(flat_args, ir_arg_memory_kinds, input_avals)]
|
||||||
|
|
||||||
|
_, token_args, unflattened_args = util.split_list(
|
||||||
|
util.unflatten(flat_args, map(len, input_types)),
|
||||||
[num_dim_vars, num_tokens])
|
[num_dim_vars, num_tokens])
|
||||||
if create_tokens:
|
if create_tokens:
|
||||||
tokens_in = TokenSet.create(effects)
|
tokens_in = TokenSet.create(effects)
|
||||||
@ -1079,11 +1121,42 @@ def lower_jaxpr_to_fun(
|
|||||||
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
|
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
|
||||||
for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]
|
for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]
|
||||||
|
|
||||||
|
if ir_result_memory_kinds is not None:
|
||||||
|
flat_outputs = [
|
||||||
|
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
|
||||||
|
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
|
||||||
|
|
||||||
func_dialect.ReturnOp(flat_outputs)
|
func_dialect.ReturnOp(flat_outputs)
|
||||||
|
|
||||||
return func_op
|
return func_op
|
||||||
|
|
||||||
|
|
||||||
|
def get_compute_type(memory_kind: str) -> str:
  """Map a memory kind name to the compute-type string used in annotations.

  Args:
    memory_kind: name of the memory kind; `'tpu_hbm'` and `'unpinned_host'`
      are the only values recognized.

  Returns:
    `'dense'` for `'tpu_hbm'` and `'host'` for `'unpinned_host'`.

  Raises:
    ValueError: if `memory_kind` is not one of the recognized names.
  """
  # Kept as an ordered table so that supporting another memory kind is a
  # one-line change.
  known_kinds = (
      ('tpu_hbm', 'dense'),
      ('unpinned_host', 'host'),
  )
  for kind, compute_type in known_kinds:
    if memory_kind == kind:
      return compute_type
  raise ValueError(f'Unknown memory_kind: {memory_kind}')
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_with_memory_kind(
    x: ir.Value, memory_kind: str, aval_out: core.AbstractValue | None,
    is_input: bool = False) -> ir.Value:
  """Wrap `x` in an `annotate_device_placement` custom call.

  The custom call carries an `mhlo.frontend_attributes` dictionary holding
  `_xla_compute_type` (derived from `memory_kind` via `get_compute_type`).
  When `is_input` is true and the compute type is `'host'`, the dictionary
  additionally holds `_xla_buffer_placement` set to `"arg"`.

  Args:
    x: the IR value to annotate.
    memory_kind: memory kind name, translated by `get_compute_type`.
    aval_out: abstract value used to pick the result IR type; when None the
      type of `x` is reused.
    is_input: whether `x` is treated as an input; only affects the
      `_xla_buffer_placement` attribute described above.

  Returns:
    The result value of the created custom call.

  Raises:
    ValueError: propagated from `get_compute_type` for an unknown
      `memory_kind`.
  """
  result_type = x.type if aval_out is None else aval_to_ir_type(aval_out)
  op = custom_call("annotate_device_placement", [result_type], [x],
                   has_side_effect=False,
                   api_version=1)
  mka = get_compute_type(memory_kind)
  dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)}
  if is_input and mka == 'host':
    dict_attr["_xla_buffer_placement"] = ir.StringAttr.get("arg")
  op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
  return op.result
|
||||||
|
|
||||||
|
|
||||||
def _to_physical_op_sharding(
|
def _to_physical_op_sharding(
|
||||||
aval: core.AbstractValue | None, sharding: xc.HloSharding | None
|
aval: core.AbstractValue | None, sharding: xc.HloSharding | None
|
||||||
) -> xc.OpSharding | None:
|
) -> xc.OpSharding | None:
|
||||||
|
@ -24,6 +24,7 @@ from functools import partial, lru_cache, cached_property
|
|||||||
import itertools as it
|
import itertools as it
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
|
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -37,6 +38,7 @@ from jax._src import dispatch
|
|||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src import effects
|
from jax._src import effects
|
||||||
from jax._src import linear_util as lu
|
from jax._src import linear_util as lu
|
||||||
|
from jax._src import config as jax_config
|
||||||
from jax._src import mesh as mesh_lib
|
from jax._src import mesh as mesh_lib
|
||||||
from jax._src import op_shardings
|
from jax._src import op_shardings
|
||||||
from jax._src import sharding_specs
|
from jax._src import sharding_specs
|
||||||
@ -69,6 +71,13 @@ from jax._src.util import (safe_map, safe_zip, partition_list,
|
|||||||
wrap_name, tuple_delete, distributed_debug_log,
|
wrap_name, tuple_delete, distributed_debug_log,
|
||||||
unzip2, HashableFunction, weakref_lru_cache)
|
unzip2, HashableFunction, weakref_lru_cache)
|
||||||
|
|
||||||
|
# TODO(yashkatariya): Remove this flag after the host runtime is linked by
# default and works on cloud TPU.
# NOTE: the environment default is parsed explicitly. `bool(os.getenv(...))`
# is True for every non-empty string, so values such as '0' or 'false' would
# switch the flag on. Unset, empty and false-like values yield False; any other
# value still yields True, as it did before.
_FETCH_MEMORY_KIND_ON_EXECUTABLE = jax_config.DEFINE_bool(
    'jax_fetch_memory_kind_on_executable',
    (os.getenv('JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE', '').strip().lower()
     not in ('', '0', 'false', 'f', 'no', 'n', 'off')),
    help=("If True, will allow fetching memory kinds available on executable "
          "and annotate Shardings with it."))
|
||||||
|
|
||||||
# Built in Python lists don't support weak refs but subclasses of lists do.
|
# Built in Python lists don't support weak refs but subclasses of lists do.
|
||||||
class WeakRefList(list):
|
class WeakRefList(list):
|
||||||
@ -1689,10 +1698,14 @@ class SemanticallyEqualShardings:
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not isinstance(other, SemanticallyEqualShardings):
|
if not isinstance(other, SemanticallyEqualShardings):
|
||||||
return False
|
return False
|
||||||
return all(op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
|
return all(
|
||||||
if (isinstance(s, sharding_impls.GSPMDSharding) and
|
(op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
|
||||||
isinstance(o, sharding_impls.GSPMDSharding))
|
and sharding_impls.are_mem_kind_of_shardings_equal(s, o))
|
||||||
else s == o for s, o in zip(self.shardings, other.shardings))
|
if (isinstance(s, sharding_impls.GSPMDSharding) and
|
||||||
|
isinstance(o, sharding_impls.GSPMDSharding))
|
||||||
|
else s == o
|
||||||
|
for s, o in zip(self.shardings, other.shardings)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@weakref_lru_cache
|
@weakref_lru_cache
|
||||||
@ -2208,34 +2221,40 @@ def _get_input_indices(
|
|||||||
|
|
||||||
def get_gspmd_shardings_from_executable(
    xla_executable, device_assignment: Sequence[xc.Device],
    num_out_avals: int
) -> Sequence[sharding_impls.XLACompatibleSharding]:
  """Build output shardings from a compiled executable.

  Args:
    xla_executable: the compiled executable; queried for its output op
      shardings and, when `_FETCH_MEMORY_KIND_ON_EXECUTABLE` is set, for its
      output memory kinds.
    device_assignment: devices the computation runs on.
    num_out_avals: number of outputs expected from the computation.

  Returns:
    A list of shardings, each carrying the memory kind obtained for that
    output (None when memory kinds are not fetched or could not be fetched).
    With a single-device assignment these are `SingleDeviceSharding`s, one per
    memory-kind entry; otherwise `GSPMDSharding`s, one per output.
  """
  from jax._src import pjit

  if _FETCH_MEMORY_KIND_ON_EXECUTABLE.value:
    try:
      # NOTE(review): index 0 presumably selects the memory kinds of the first
      # (only) computation -- confirm against the executable API.
      omk = xla_executable.get_output_memory_kinds()[0]
    except Exception:
      # Best effort: fall back to "no memory kind" when the executable cannot
      # report them. Deliberately not a bare `except`, so KeyboardInterrupt
      # and SystemExit still propagate.
      omk = [None] * num_out_avals
  else:
    omk = [None] * num_out_avals

  # When the device assignment only has 1 device, SPMD partitioner will not run.
  # Hence the op shardings will not be set on the `hlo_module`. In that case,
  # just return SingleDeviceShardings since we know the computation is running
  # only on 1 device.
  if len(device_assignment) == 1:
    return [sharding_impls.SingleDeviceSharding(device_assignment[0], memory_kind=mk)
            for mk in omk]

  _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)

  # This condition happens when all the elements in the output tuple have the
  # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
  # put the sharding on ROOT instead of the tuple.
  # TODO(b/245667823): Remove this when XLA fixes this.
  if len(out_op_shardings) == 1 and len(out_op_shardings) < num_out_avals:
    out_op_shardings = out_op_shardings * num_out_avals  # type: ignore
  assert len(out_op_shardings) == num_out_avals == len(omk), (
      len(out_op_shardings), num_out_avals, len(omk))

  return [sharding_impls.GSPMDSharding(device_assignment, os, memory_kind=mk)
          for os, mk in safe_zip(out_op_shardings, omk)]
|
||||||
|
|
||||||
|
|
||||||
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
|
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
|
||||||
@ -2258,39 +2277,38 @@ _ShardingT = TypeVar("_ShardingT", bound=sharding_impls.XLACompatibleSharding)
|
|||||||
|
|
||||||
def _register_out_sharding_handler(
|
def _register_out_sharding_handler(
|
||||||
sharding_cls: type[_ShardingT],
|
sharding_cls: type[_ShardingT],
|
||||||
handler: Callable[[xc.OpSharding, _ShardingT], _ShardingT],
|
handler: Callable[[sharding_impls.GSPMDSharding, _ShardingT], _ShardingT],
|
||||||
) -> None:
|
) -> None:
|
||||||
_orig_out_sharding_handlers[sharding_cls] = handler
|
_orig_out_sharding_handlers[sharding_cls] = handler
|
||||||
|
|
||||||
|
|
||||||
def _gspmd_to_named_sharding(
|
def _gspmd_to_named_sharding(
|
||||||
op_sharding: xc.OpSharding,
|
out_s: sharding_impls.GSPMDSharding,
|
||||||
self: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
|
orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
|
||||||
parsed_pspec = sharding_impls.parse_flatten_op_sharding(
|
parsed_pspec = sharding_impls.parse_flatten_op_sharding(
|
||||||
op_sharding, self.mesh)[0]
|
out_s._hlo_sharding, orig_in_s.mesh)[0]
|
||||||
return create_mesh_pspec_sharding(
|
return create_mesh_pspec_sharding(
|
||||||
self.mesh, parsed_pspec.get_partition_spec(), parsed_pspec)
|
orig_in_s.mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
|
||||||
|
out_s.memory_kind)
|
||||||
|
|
||||||
_register_out_sharding_handler(
|
_register_out_sharding_handler(
|
||||||
sharding_impls.NamedSharding, _gspmd_to_named_sharding
|
sharding_impls.NamedSharding, _gspmd_to_named_sharding)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _gspmd_to_positional_sharding(
|
def _gspmd_to_positional_sharding(
|
||||||
op_sharding: xc.OpSharding,
|
out_s: sharding_impls.GSPMDSharding,
|
||||||
self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
|
orig_in_s: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
|
||||||
return sharding_impls._op_sharding_to_pos_sharding(
|
return sharding_impls._op_sharding_to_pos_sharding(
|
||||||
op_sharding, self._device_assignment)
|
out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)
|
||||||
|
|
||||||
_register_out_sharding_handler(
|
_register_out_sharding_handler(
|
||||||
sharding_impls.PositionalSharding, _gspmd_to_positional_sharding
|
sharding_impls.PositionalSharding, _gspmd_to_positional_sharding)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_out_sharding_from_orig_sharding(
|
def _get_out_sharding_from_orig_sharding(
|
||||||
out_shardings, out_avals, orig_s, orig_aval, are_out_sharding_from_xla):
|
out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla):
|
||||||
out = []
|
out = []
|
||||||
orig_handler = _orig_out_sharding_handlers[type(orig_s)]
|
orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
|
||||||
for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
|
for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
|
||||||
are_out_sharding_from_xla):
|
are_out_sharding_from_xla):
|
||||||
if isinstance(o, sharding_impls.GSPMDSharding):
|
if isinstance(o, sharding_impls.GSPMDSharding):
|
||||||
@ -2302,10 +2320,11 @@ def _get_out_sharding_from_orig_sharding(
|
|||||||
if (orig_aval is not None and out_aval is not None and
|
if (orig_aval is not None and out_aval is not None and
|
||||||
out_aval.ndim == orig_aval.ndim and
|
out_aval.ndim == orig_aval.ndim and
|
||||||
sharding_impls.are_op_shardings_equal(
|
sharding_impls.are_op_shardings_equal(
|
||||||
o._hlo_sharding, orig_s._to_xla_hlo_sharding(orig_aval.ndim))):
|
o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim)) and
|
||||||
out.append((orig_s, False))
|
sharding_impls.are_mem_kind_of_shardings_equal(o, orig_in_s)):
|
||||||
|
out.append((orig_in_s, False))
|
||||||
else:
|
else:
|
||||||
out.append((orig_handler(o._hlo_sharding, orig_s), False))
|
out.append((orig_handler(o, orig_in_s), False))
|
||||||
except:
|
except:
|
||||||
out.append((o, from_xla))
|
out.append((o, from_xla))
|
||||||
else:
|
else:
|
||||||
@ -2319,17 +2338,17 @@ def maybe_get_orig_out_sharding(
|
|||||||
return ([o._original_sharding for o in out_shardings],
|
return ([o._original_sharding for o in out_shardings],
|
||||||
(False,) * len(out_shardings))
|
(False,) * len(out_shardings))
|
||||||
|
|
||||||
orig_s = None
|
orig_in_s = None
|
||||||
orig_aval = None
|
orig_aval = None
|
||||||
for i, aval in safe_zip(in_shardings, in_avals):
|
for i, aval in safe_zip(in_shardings, in_avals):
|
||||||
oi = getattr(i, '_original_sharding', None)
|
oi = getattr(i, '_original_sharding', None)
|
||||||
if type(oi) in _orig_out_sharding_handlers:
|
if type(oi) in _orig_out_sharding_handlers:
|
||||||
orig_s = oi
|
orig_in_s = oi
|
||||||
orig_aval = aval
|
orig_aval = aval
|
||||||
break
|
break
|
||||||
if orig_s is not None:
|
if orig_in_s is not None:
|
||||||
return zip(*_get_out_sharding_from_orig_sharding(
|
return zip(*_get_out_sharding_from_orig_sharding(
|
||||||
out_shardings, out_avals, orig_s, orig_aval, are_out_shardings_from_xla))
|
out_shardings, out_avals, orig_in_s, orig_aval, are_out_shardings_from_xla))
|
||||||
|
|
||||||
return out_shardings, are_out_shardings_from_xla
|
return out_shardings, are_out_shardings_from_xla
|
||||||
|
|
||||||
@ -2521,9 +2540,8 @@ class UnloadedMeshExecutable:
|
|||||||
assert mesh is None
|
assert mesh is None
|
||||||
device_assignment = da.device_assignment if isinstance( # type: ignore
|
device_assignment = da.device_assignment if isinstance( # type: ignore
|
||||||
da, _DeviceAssignment) else da
|
da, _DeviceAssignment) else da
|
||||||
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
|
out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
|
||||||
xla_executable, device_assignment, # type: ignore
|
xla_executable, device_assignment, len(global_out_avals)) # type: ignore
|
||||||
len(global_in_avals), len(global_out_avals))
|
|
||||||
orig_out_shardings = out_shardings
|
orig_out_shardings = out_shardings
|
||||||
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
|
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
|
||||||
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
|
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
|
||||||
@ -2532,9 +2550,10 @@ class UnloadedMeshExecutable:
|
|||||||
out_shardings.append(xla_s)
|
out_shardings.append(xla_s)
|
||||||
are_out_shardings_from_xla.append(True)
|
are_out_shardings_from_xla.append(True)
|
||||||
else:
|
else:
|
||||||
if not op_shardings.are_op_shardings_equal(
|
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||||
xla_s._to_xla_hlo_sharding(aval.ndim), # type: ignore
|
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
|
||||||
orig._to_xla_hlo_sharding(aval.ndim)): # type: ignore
|
if (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
|
||||||
|
not sharding_impls.are_mem_kind_of_shardings_equal(xla_s, orig)): # type: ignore
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
|
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
|
||||||
"(User sharding)")
|
"(User sharding)")
|
||||||
@ -2823,11 +2842,12 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
|
|||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def create_mesh_pspec_sharding(
|
def create_mesh_pspec_sharding(
|
||||||
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None
|
mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,
|
||||||
) -> sharding_impls.NamedSharding:
|
memory_kind: Optional[str] = None) -> sharding_impls.NamedSharding:
|
||||||
if pspec is None:
|
if pspec is None:
|
||||||
pspec, parsed_pspec = PartitionSpec(), None
|
pspec, parsed_pspec = PartitionSpec(), None
|
||||||
return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec)
|
return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
|
||||||
|
memory_kind=memory_kind)
|
||||||
|
|
||||||
|
|
||||||
def check_device_backend_on_shardings(shardings) -> bool:
|
def check_device_backend_on_shardings(shardings) -> bool:
|
||||||
@ -2852,6 +2872,12 @@ def check_gda_or_array_xla_sharding_match(
|
|||||||
if not isinstance(arg, ArrayImpl):
|
if not isinstance(arg, ArrayImpl):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Raise memory kind mismatch error even if the arg is uncommitted.
|
||||||
|
if not sharding_impls.are_mem_kind_of_shardings_equal(arg.sharding, xs):
|
||||||
|
errors.append(
|
||||||
|
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
|
||||||
|
f"arg {name} with shape: {arg.aval.str_short()}")
|
||||||
|
|
||||||
# No need to cache this check since MeshExecutable has a C++ fast path
|
# No need to cache this check since MeshExecutable has a C++ fast path
|
||||||
# for AOT compiled call.
|
# for AOT compiled call.
|
||||||
if (not check_device_backend_on_shardings([xs]) and
|
if (not check_device_backend_on_shardings([xs]) and
|
||||||
|
@ -1741,9 +1741,11 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
|
|||||||
args_flat):
|
args_flat):
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
|
def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
|
||||||
if not op_shardings.are_op_shardings_equal(
|
if (not op_shardings.are_op_shardings_equal(
|
||||||
in_sharding._to_xla_hlo_sharding(ndim),
|
in_sharding._to_xla_hlo_sharding(ndim),
|
||||||
xmap_sharding._to_xla_hlo_sharding(ndim)):
|
xmap_sharding._to_xla_hlo_sharding(ndim)) or
|
||||||
|
not sharding_impls.are_mem_kind_of_shardings_equal(
|
||||||
|
in_sharding, xmap_sharding)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Got an input {arr_flavor} to xmap with different partitioning than "
|
f"Got an input {arr_flavor} to xmap with different partitioning than "
|
||||||
"specified in xmap. The partitioning must match. "
|
"specified in xmap. The partitioning must match. "
|
||||||
|
@ -59,7 +59,8 @@ from jax._src.sharding_impls import (
|
|||||||
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
|
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
|
||||||
AUTO, UNSPECIFIED, UnspecifiedValue,
|
AUTO, UNSPECIFIED, UnspecifiedValue,
|
||||||
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
|
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
|
||||||
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
|
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding,
|
||||||
|
are_mem_kind_of_shardings_equal)
|
||||||
from jax._src.traceback_util import api_boundary
|
from jax._src.traceback_util import api_boundary
|
||||||
from jax._src.tree_util import (
|
from jax._src.tree_util import (
|
||||||
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
|
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
|
||||||
@ -261,7 +262,6 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
|||||||
donate_argnums, tree_util.default_registry, # type: ignore
|
donate_argnums, tree_util.default_registry, # type: ignore
|
||||||
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
||||||
cpp_pjitted_f._fun = fun
|
cpp_pjitted_f._fun = fun
|
||||||
type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
|
type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
|
||||||
@ -1092,6 +1092,14 @@ def _resolve_in_shardings(
|
|||||||
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
|
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
|
||||||
f'Got arg shape: {arg.shape}, arg value: {arg}')
|
f'Got arg shape: {arg.shape}, arg value: {arg}')
|
||||||
if not is_unspecified(arg_s):
|
if not is_unspecified(arg_s):
|
||||||
|
# jax.jit does not allow resharding across different memory kinds even
|
||||||
|
# if the argument is uncommitted. Use jax.device_put for those cases,
|
||||||
|
# either outside or inside jax.jit.
|
||||||
|
if not are_mem_kind_of_shardings_equal(pjit_in_s, arg_s): # type: ignore
|
||||||
|
raise ValueError(
|
||||||
|
'Memory kinds passed to jax.jit does not match memory kind on the'
|
||||||
|
f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore
|
||||||
|
f'arg memory kind: {arg_s.memory_kind} for arg shape: {arg.shape}') # type: ignore
|
||||||
if (committed and
|
if (committed and
|
||||||
not isinstance(arg_s, PmapSharding) and
|
not isinstance(arg_s, PmapSharding) and
|
||||||
not op_shardings.are_op_shardings_equal(
|
not op_shardings.are_op_shardings_equal(
|
||||||
@ -1232,8 +1240,9 @@ class SameDeviceAssignmentTuple:
|
|||||||
s = getattr(s, "_original_sharding", s)
|
s = getattr(s, "_original_sharding", s)
|
||||||
o = getattr(o, "_original_sharding", o)
|
o = getattr(o, "_original_sharding", o)
|
||||||
if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding):
|
if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding):
|
||||||
eq.append(op_shardings.are_op_shardings_equal(
|
eq.append(
|
||||||
s._hlo_sharding, o._hlo_sharding))
|
op_shardings.are_op_shardings_equal(s._hlo_sharding, o._hlo_sharding)
|
||||||
|
and are_mem_kind_of_shardings_equal(s, o))
|
||||||
else:
|
else:
|
||||||
eq.append(s == o)
|
eq.append(s == o)
|
||||||
return all(eq) and self.device_assignment == other.device_assignment
|
return all(eq) and self.device_assignment == other.device_assignment
|
||||||
@ -1944,7 +1953,8 @@ def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int,
|
|||||||
device_or_backend_set: bool = False) -> GSPMDSharding:
|
device_or_backend_set: bool = False) -> GSPMDSharding:
|
||||||
if isinstance(s, GSPMDSharding):
|
if isinstance(s, GSPMDSharding):
|
||||||
return s
|
return s
|
||||||
gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim))
|
gs = GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
|
||||||
|
memory_kind=s.memory_kind)
|
||||||
gs._original_sharding = s
|
gs._original_sharding = s
|
||||||
if device_or_backend_set:
|
if device_or_backend_set:
|
||||||
gs._original_sharding._device_backend = device_or_backend_set
|
gs._original_sharding._device_backend = device_or_backend_set
|
||||||
|
@ -93,6 +93,10 @@ class Sharding:
|
|||||||
"""Returns the memory kind of the sharding."""
|
"""Returns the memory kind of the sharding."""
|
||||||
raise NotImplementedError('Subclasses should implement this method.')
|
raise NotImplementedError('Subclasses should implement this method.')
|
||||||
|
|
||||||
|
def with_memory_kind(self, kind: str) -> Sharding:
  """Returns a new Sharding instance with the specified memory kind.

  Args:
    kind: Name of the memory kind the returned sharding should use.

  Raises:
    NotImplementedError: Always. This base implementation only declares the
      interface; subclasses are expected to override it.
  """
  raise NotImplementedError('Subclasses should implement this method.')
|
||||||
|
|
||||||
#############################################################################
|
#############################################################################
|
||||||
# Default implementations below that all subclasses will inherit.
|
# Default implementations below that all subclasses will inherit.
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ import enum
|
|||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
from typing import Any, NamedTuple, Union, cast
|
from typing import Any, NamedTuple, Union, cast, Optional
|
||||||
|
|
||||||
from jax._src import mesh as mesh_lib
|
from jax._src import mesh as mesh_lib
|
||||||
from jax._src.op_shardings import (
|
from jax._src.op_shardings import (
|
||||||
@ -109,7 +109,7 @@ class XLACompatibleSharding(sharding.Sharding):
|
|||||||
return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
|
return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
|
||||||
other._to_xla_hlo_sharding(ndim))
|
other._to_xla_hlo_sharding(ndim))
|
||||||
and self._device_assignment == other._device_assignment and
|
and self._device_assignment == other._device_assignment and
|
||||||
self.memory_kind == other.memory_kind)
|
are_mem_kind_of_shardings_equal(self, other))
|
||||||
# NotImplementedError is raised by PmapSharding because it can't lower
|
# NotImplementedError is raised by PmapSharding because it can't lower
|
||||||
# to OpSharding. So if `other` is a PmapSharding, default to a strict
|
# to OpSharding. So if `other` is a PmapSharding, default to a strict
|
||||||
# equality check.
|
# equality check.
|
||||||
@ -154,6 +154,51 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# This is an optimization to get the memory kinds associated with the local
|
||||||
|
# devices. This is because in McJAX, checking if the memory kind input by user
|
||||||
|
# is correct requires doing `local_devices()[0].memory(inp)` which is expensive
|
||||||
|
# because calculating the local devices is expensive. So cache on xc.Client and
|
||||||
|
# find all the memories associated only once since the client does not change.
|
||||||
|
@functools.lru_cache
|
||||||
|
def _mem_kinds(client: xc.Client) -> set[str]:
|
||||||
|
return set(m.kind for m in client.local_devices()[0].addressable_memories())
|
||||||
|
|
||||||
|
@functools.lru_cache
|
||||||
|
def _default_mem_kind(client: xc.Client) -> str | None:
|
||||||
|
try:
|
||||||
|
return client.local_devices()[0].default_memory().kind
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _check_mem_kind(device: xc.Device, mk):
  """Validates that `mk` is a memory kind available for `device`.

  The set of valid kinds is whatever `_mem_kinds` reports for the device's
  client.

  Args:
    device: Device whose client is used to look up the valid memory kinds.
    mk: The memory kind to validate.

  Raises:
    ValueError: If `mk` is not among the valid memory kinds.
  """
  mem_kinds = _mem_kinds(device.client)
  if mk not in mem_kinds:
    # Sort the kinds so the message is stable across runs; iterating a set of
    # strings has no guaranteed order.
    raise ValueError(
        f'Could not find memory addressable by device {device.device_kind}.'
        f' Device {device.device_kind} can address the following memory kinds:'
        f' {", ".join(sorted(mem_kinds))}. Got memory kind: {mk}')
|
||||||
|
|
||||||
|
|
||||||
|
def get_canonicalized_memory_kind(s: XLACompatibleSharding) -> Optional[str]:
  """Returns the memory kind of `s`, resolving None to the client's default.

  If `s.memory_kind` is set it is returned unchanged. Otherwise the default
  memory kind of the client owning the first device in `s._device_assignment`
  is returned. Any ordinary failure during the lookup yields None.
  """
  # TODO(yashkatariya): Remove try;except when CPU and GPU support memories.
  try:
    client = s._device_assignment[0].client
    if s.memory_kind is None:
      return _default_mem_kind(client)
    return s.memory_kind
  except Exception:
    # Best-effort fallback. `Exception` rather than a bare `except:` so that
    # KeyboardInterrupt / SystemExit are not silently turned into None.
    return None
|
||||||
|
|
||||||
|
# TODO(yashkatariya): Remove this when the canonicalization happens in __init__
|
||||||
|
# of Shardings. That can be done after OSS support is also added for memories.
|
||||||
|
def are_mem_kind_of_shardings_equal(s1: XLACompatibleSharding,
                                    s2: XLACompatibleSharding) -> bool:
  """Returns True if `s1` and `s2` refer to the same memory kind.

  Memory kinds are compared after canonicalization, so an unset (None) memory
  kind compares equal to the explicitly spelled default memory kind.
  """
  # Fast path: both unset. No canonicalization (device/client lookup) needed.
  if s1.memory_kind is None and s2.memory_kind is None:
    return True

  return get_canonicalized_memory_kind(s1) == get_canonicalized_memory_kind(s2)
|
||||||
|
|
||||||
|
|
||||||
@use_cpp_class(xc.NamedSharding)
|
@use_cpp_class(xc.NamedSharding)
|
||||||
class NamedSharding(XLACompatibleSharding):
|
class NamedSharding(XLACompatibleSharding):
|
||||||
r"""A :class:`NamedSharding` expresses sharding using named axes.
|
r"""A :class:`NamedSharding` expresses sharding using named axes.
|
||||||
@ -207,16 +252,13 @@ class NamedSharding(XLACompatibleSharding):
|
|||||||
self._preprocess()
|
self._preprocess()
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (
|
return (type(self), (self.mesh, self.spec),
|
||||||
type(self),
|
{'memory_kind': self.memory_kind})
|
||||||
(self.mesh, self.spec),
|
|
||||||
{'memory_kind': self.memory_kind},
|
|
||||||
)
|
|
||||||
|
|
||||||
def _preprocess(self):
|
def _preprocess(self):
|
||||||
if self.memory_kind is not None:
|
if self.memory_kind is not None:
|
||||||
# Will error if memory_kind does not exist on the device.
|
# Will error if memory_kind does not exist on the device.
|
||||||
self.mesh.devices.flat[0].memory(self.memory_kind)
|
_check_mem_kind(self.mesh.devices.flat[0], self.memory_kind)
|
||||||
|
|
||||||
# This split exists because you can pass `_parsed_pspec` that has been
|
# This split exists because you can pass `_parsed_pspec` that has been
|
||||||
# modified from the original. For example: Adding extra dimension to
|
# modified from the original. For example: Adding extra dimension to
|
||||||
@ -239,7 +281,8 @@ class NamedSharding(XLACompatibleSharding):
|
|||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
if not hasattr(self, '_hash'):
|
if not hasattr(self, '_hash'):
|
||||||
self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
|
self._hash = hash((self.mesh, get_canonicalized_memory_kind(self),
|
||||||
|
self._parsed_pspec))
|
||||||
return self._hash
|
return self._hash
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
@ -248,11 +291,11 @@ class NamedSharding(XLACompatibleSharding):
|
|||||||
if id(self) == id(other):
|
if id(self) == id(other):
|
||||||
return True
|
return True
|
||||||
parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
|
parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
|
||||||
if (id(self.mesh) == id(other.mesh) and
|
mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
|
||||||
self.memory_kind == other.memory_kind and parsed_pspec_equal):
|
if (id(self.mesh) == id(other.mesh) and mem_kind_equal and
|
||||||
|
parsed_pspec_equal):
|
||||||
return True
|
return True
|
||||||
return (self.mesh == other.mesh and self.memory_kind == other.memory_kind
|
return self.mesh == other.mesh and mem_kind_equal and parsed_pspec_equal
|
||||||
and parsed_pspec_equal)
|
|
||||||
|
|
||||||
def is_compatible_aval(self, aval_shape: Shape):
|
def is_compatible_aval(self, aval_shape: Shape):
|
||||||
assert self._parsed_pspec is not None
|
assert self._parsed_pspec is not None
|
||||||
@ -267,7 +310,7 @@ class NamedSharding(XLACompatibleSharding):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
|
def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
|
||||||
return cls(mesh, parsed_pspec.get_partition_spec(),
|
return cls(mesh, parsed_pspec.get_partition_spec(),
|
||||||
memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
|
memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device_set(self) -> set[Device]:
|
def device_set(self) -> set[Device]:
|
||||||
@ -300,6 +343,9 @@ class NamedSharding(XLACompatibleSharding):
|
|||||||
num_partitions *= mesh_shape[name]
|
num_partitions *= mesh_shape[name]
|
||||||
return num_partitions == 1
|
return num_partitions == 1
|
||||||
|
|
||||||
|
def with_memory_kind(self, kind: str) -> NamedSharding:
  """Returns a NamedSharding with this mesh and spec and memory kind `kind`."""
  return NamedSharding(self.mesh, self.spec, memory_kind=kind)
|
||||||
|
|
||||||
def _get_sharding_spec(self, num_dimensions, axis_ctx):
|
def _get_sharding_spec(self, num_dimensions, axis_ctx):
|
||||||
assert self._parsed_pspec is not None
|
assert self._parsed_pspec is not None
|
||||||
array_mapping = get_array_mapping(self._parsed_pspec)
|
array_mapping = get_array_mapping(self._parsed_pspec)
|
||||||
@ -362,7 +408,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
|||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
if not hasattr(self, '_hash'):
|
if not hasattr(self, '_hash'):
|
||||||
self._hash = hash((self._device, self._memory_kind))
|
self._hash = hash((self._device, get_canonicalized_memory_kind(self)))
|
||||||
return self._hash
|
return self._hash
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
@ -371,7 +417,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
|||||||
if id(self) == id(other):
|
if id(self) == id(other):
|
||||||
return True
|
return True
|
||||||
return (self._device == other._device and
|
return (self._device == other._device and
|
||||||
self._memory_kind == other._memory_kind)
|
are_mem_kind_of_shardings_equal(self, other))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device_set(self) -> set[Device]:
|
def device_set(self) -> set[Device]:
|
||||||
@ -381,6 +427,9 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
|||||||
def memory_kind(self) -> str | None:
|
def memory_kind(self) -> str | None:
|
||||||
return self._memory_kind
|
return self._memory_kind
|
||||||
|
|
||||||
|
def with_memory_kind(self, kind: str) -> SingleDeviceSharding:
  """Returns a SingleDeviceSharding on this device with memory kind `kind`."""
  return SingleDeviceSharding(self._device, memory_kind=kind)
|
||||||
|
|
||||||
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: # type: ignore
|
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: # type: ignore
|
||||||
return {self._device: (slice(None),) * len(global_shape)}
|
return {self._device: (slice(None),) * len(global_shape)}
|
||||||
|
|
||||||
@ -493,6 +542,9 @@ class PmapSharding(XLACompatibleSharding):
|
|||||||
def memory_kind(self):
|
def memory_kind(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def with_memory_kind(self, kind: str):
  """Unsupported for pmap shardings.

  Raises:
    NotImplementedError: Always; pmap does not support memories.
  """
  # Raise the exception. Returning the exception instance would hand callers
  # an object that looks like a sharding and only fails later.
  raise NotImplementedError("pmap does not support memories.")
|
||||||
|
|
||||||
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
|
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
|
||||||
raise NotImplementedError("pmap doesn't use OpSharding.")
|
raise NotImplementedError("pmap doesn't use OpSharding.")
|
||||||
|
|
||||||
@ -524,13 +576,15 @@ class PmapSharding(XLACompatibleSharding):
|
|||||||
|
|
||||||
|
|
||||||
def _op_sharding_to_pos_sharding(
|
def _op_sharding_to_pos_sharding(
|
||||||
op_sharding: xc.OpSharding | xc.HloSharding,
|
op_sharding: Union[xc.OpSharding, xc.HloSharding],
|
||||||
device_assignment: Sequence[xc.Device]) -> PositionalSharding:
|
device_assignment: Sequence[xc.Device],
|
||||||
|
memory_kind: Optional[str] = None) -> PositionalSharding:
|
||||||
if isinstance(op_sharding, xc.HloSharding):
|
if isinstance(op_sharding, xc.HloSharding):
|
||||||
op_sharding = op_sharding.to_proto() # type: ignore
|
op_sharding = op_sharding.to_proto() # type: ignore
|
||||||
|
|
||||||
if op_sharding.type == xc.OpSharding.Type.REPLICATED:
|
if op_sharding.type == xc.OpSharding.Type.REPLICATED:
|
||||||
return PositionalSharding(device_assignment).replicate()
|
return PositionalSharding(
|
||||||
|
device_assignment, memory_kind=memory_kind).replicate()
|
||||||
|
|
||||||
if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
|
if op_sharding.last_tile_dims == [xc.OpSharding.Type.REPLICATED]:
|
||||||
replicate_on_last_tile_dim = True
|
replicate_on_last_tile_dim = True
|
||||||
@ -543,7 +597,11 @@ def _op_sharding_to_pos_sharding(
|
|||||||
name = device_assignment[0].platform.upper()
|
name = device_assignment[0].platform.upper()
|
||||||
ids = np.array([DeviceIdSet(name, i)
|
ids = np.array([DeviceIdSet(name, i)
|
||||||
for i in op_sharding.tile_assignment_devices])
|
for i in op_sharding.tile_assignment_devices])
|
||||||
p = PositionalSharding._remake(tuple(device_assignment), ids)
|
if memory_kind is not None:
|
||||||
|
# Will error if memory_kind does not exist on the device.
|
||||||
|
_check_mem_kind(device_assignment[0], memory_kind)
|
||||||
|
p = PositionalSharding._remake(tuple(device_assignment), ids,
|
||||||
|
memory_kind=memory_kind)
|
||||||
p = p.reshape(op_sharding.tile_assignment_dimensions)
|
p = p.reshape(op_sharding.tile_assignment_dimensions)
|
||||||
if replicate_on_last_tile_dim:
|
if replicate_on_last_tile_dim:
|
||||||
p = p.replicate(-1, keepdims=False)
|
p = p.replicate(-1, keepdims=False)
|
||||||
@ -569,7 +627,7 @@ class PositionalSharding(XLACompatibleSharding):
|
|||||||
dtype='object').reshape(devices.shape)
|
dtype='object').reshape(devices.shape)
|
||||||
if self._memory_kind is not None:
|
if self._memory_kind is not None:
|
||||||
# Will error if memory_kind does not exist on the device.
|
# Will error if memory_kind does not exist on the device.
|
||||||
self._devices[0].memory(self._memory_kind)
|
_check_mem_kind(self._devices[0], self._memory_kind)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
@ -615,7 +673,7 @@ class PositionalSharding(XLACompatibleSharding):
|
|||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
if not hasattr(self, '_hash'):
|
if not hasattr(self, '_hash'):
|
||||||
self._hash = hash((self._devices, self._memory_kind))
|
self._hash = hash((self._devices, get_canonicalized_memory_kind(self)))
|
||||||
return self._hash
|
return self._hash
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
@ -624,11 +682,11 @@ class PositionalSharding(XLACompatibleSharding):
|
|||||||
if id(self) == id(other):
|
if id(self) == id(other):
|
||||||
return True
|
return True
|
||||||
all_ids_equal = np.array_equal(self._ids,other._ids)
|
all_ids_equal = np.array_equal(self._ids,other._ids)
|
||||||
if (id(self._devices) == id(other._devices) and
|
mem_kind_equal = are_mem_kind_of_shardings_equal(self, other)
|
||||||
self._memory_kind == other._memory_kind and all_ids_equal):
|
if (id(self._devices) == id(other._devices) and mem_kind_equal and
|
||||||
|
all_ids_equal):
|
||||||
return True
|
return True
|
||||||
return (self._devices == other._devices and
|
return self._devices == other._devices and mem_kind_equal and all_ids_equal
|
||||||
self._memory_kind == other._memory_kind and all_ids_equal)
|
|
||||||
|
|
||||||
# Sharding interface
|
# Sharding interface
|
||||||
|
|
||||||
@ -640,6 +698,9 @@ class PositionalSharding(XLACompatibleSharding):
|
|||||||
def memory_kind(self) -> str | None:
|
def memory_kind(self) -> str | None:
|
||||||
return self._memory_kind
|
return self._memory_kind
|
||||||
|
|
||||||
|
def with_memory_kind(self, kind: str) -> PositionalSharding:
  """Returns a PositionalSharding over this sharding's devices with `kind`."""
  # NOTE(review): only `_devices` is forwarded here; `_ids` is not. Confirm
  # that rebuilding from `_devices` alone preserves the intended layout.
  return PositionalSharding(self._devices, memory_kind=kind)
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def is_fully_replicated(self) -> bool:
|
def is_fully_replicated(self) -> bool:
|
||||||
return self.shape == (1,) * self.ndim
|
return self.shape == (1,) * self.ndim
|
||||||
@ -717,12 +778,14 @@ class GSPMDSharding(XLACompatibleSharding):
|
|||||||
self._hlo_sharding = op_sharding
|
self._hlo_sharding = op_sharding
|
||||||
self._memory_kind = memory_kind
|
self._memory_kind = memory_kind
|
||||||
|
|
||||||
|
def _preprocess(self):
|
||||||
|
if self.memory_kind is not None:
|
||||||
|
# Will error if memory_kind does not exist on the device.
|
||||||
|
_check_mem_kind(self._devices[0], self.memory_kind)
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (
|
return (type(self), (self._devices, self._hlo_sharding.to_proto()),
|
||||||
type(self),
|
{'memory_kind': self._memory_kind})
|
||||||
(self._devices, self._hlo_sharding.to_proto()),
|
|
||||||
{'memory_kind': self._memory_kind},
|
|
||||||
)
|
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def _hlo_sharding_hash(self):
|
def _hlo_sharding_hash(self):
|
||||||
@ -735,12 +798,12 @@ class GSPMDSharding(XLACompatibleSharding):
|
|||||||
return True
|
return True
|
||||||
return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
|
return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
|
||||||
and self._devices == other._devices and
|
and self._devices == other._devices and
|
||||||
self._memory_kind == other._memory_kind)
|
are_mem_kind_of_shardings_equal(self, other))
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
if not hasattr(self, '_hash'):
|
if not hasattr(self, '_hash'):
|
||||||
self._hash = hash((self._devices, self._hlo_sharding_hash,
|
self._hash = hash((self._devices, self._hlo_sharding_hash,
|
||||||
self._memory_kind))
|
get_canonicalized_memory_kind(self)))
|
||||||
return self._hash
|
return self._hash
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@ -763,6 +826,9 @@ class GSPMDSharding(XLACompatibleSharding):
|
|||||||
def memory_kind(self) -> str | None:
|
def memory_kind(self) -> str | None:
|
||||||
return self._memory_kind
|
return self._memory_kind
|
||||||
|
|
||||||
|
def with_memory_kind(self, kind: str) -> GSPMDSharding:
  """Returns a GSPMDSharding with the same devices and HLO sharding and `kind`."""
  return GSPMDSharding(self._devices, self._hlo_sharding, memory_kind=kind)
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=4096)
|
@functools.lru_cache(maxsize=4096)
|
||||||
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
|
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
|
||||||
self.shard_shape(global_shape) # raises a good error message
|
self.shard_shape(global_shape) # raises a good error message
|
||||||
|
Loading…
x
Reference in New Issue
Block a user