Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Simplify extended dtypes rules, part 1. Start by removing the sharding-specific rules from extended dtypes, since we always want to replicate the trailing dims introduced by extended dtypes.
PiperOrigin-RevId: 639920049
This commit is contained in: parent 8a1445a038, commit 1273028018
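Before the diff, a quick illustration of the "trailing dims" in question, using the public typed-PRNG-key API (a minimal sketch; the trailing key-data dim of size 2 assumes the default threefry2x32 implementation):

import jax

keys = jax.random.split(jax.random.key(0), 8)   # logical array of 8 keys
print(keys.shape)                               # (8,)   -- logical shape
print(jax.random.key_data(keys).shape)          # (8, 2) -- physical shape, one trailing dim

The point of the change is that the sharding of such trailing key-data dims is no longer decided per extended dtype: they are always replicated, so the per-dtype sharding rules can be dropped in favor of shared helpers added to jax._src.sharding_impls below.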
@@ -789,6 +789,7 @@ pytype_strict_library(
     srcs = ["_src/sharding_impls.py"],
     deps = [
         ":config",
+        ":core",
         ":mesh",
         ":op_shardings",
         ":partition_spec",
@@ -48,8 +48,8 @@ from jax._src.monitoring import record_event_duration_secs
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
-    PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
-    GSPMDSharding, TransferToMemoryKind)
+    SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
+    GSPMDSharding, TransferToMemoryKind, is_single_device_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout
@@ -163,12 +163,6 @@ def wait_for_tokens():
   runtime_tokens.block_until_ready()


-def is_single_device_sharding(sharding: Sharding) -> bool:
-  # Special case PmapSharding here because PmapSharding maps away an axis
-  # and needs to be handled separately.test_pjit_single_device_sharding_add
-  return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
-
-
 @contextlib.contextmanager
 def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
   if _on_exit:
@@ -20,6 +20,7 @@ from jax._src import api_util
 from jax._src import basearray
 from jax._src import core
 from jax._src import tree_util
+from jax._src import sharding_impls
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
 from jax._src.util import safe_zip, safe_map
@@ -80,7 +81,7 @@ class EArray(basearray.Array):
   @property
   def sharding(self):
     phys_sharding = self._data.sharding
-    return self.aval.dtype._rules.logical_sharding(self.aval, phys_sharding)
+    return sharding_impls.logical_sharding(self.aval, phys_sharding)

   # TODO(mattjj): not implemented below here, need more methods from ArrayImpl
@@ -99,7 +100,7 @@ class EArray(basearray.Array):

 def _earray_shard_arg_handler(x, sharding):
   arr = x._data
-  phys_sharding = x.aval.dtype._rules.physical_sharding(x.aval, sharding)
+  phys_sharding = sharding_impls.physical_sharding(x.aval, sharding)
   return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
 pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
@@ -818,7 +818,7 @@ def _to_physical_op_sharding(
     return _to_physical_op_sharding(aval.inner_aval, sharding)
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
   if dtypes.issubdtype(aval.dtype, dtypes.extended):
-    sharding = aval.dtype._rules.physical_sharding(aval, sharding)
+    sharding = sharding_impls.physical_sharding(aval, sharding)
     aval = core.physical_aval(aval)
   return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore
@@ -1376,7 +1376,7 @@ def lower_jaxpr_to_fun(

     if ir_arg_shardings is not None and name == "main":
       flat_args = [
-          a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a)  # pytype: disable=attribute-error
+          replicate_trailing_dims(entry_lowering_ctx, o, a)
           if (a is not core.abstract_token and
               dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o  # pytype: disable=attribute-error
           for o, s, a in zip(flat_args, ir_arg_shardings, input_avals)
@@ -1417,7 +1417,7 @@ def lower_jaxpr_to_fun(

     if ir_result_shardings is not None and name == "main":
       flat_outputs = [
-          a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a)  # pytype: disable=attribute-error
+          replicate_trailing_dims(entry_lowering_ctx, o, a)
           if (a is not core.abstract_token and
               dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o  # pytype: disable=attribute-error
           for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals)
@@ -1441,6 +1441,19 @@ def wrap_with_memory_kind(
   return op.result


+def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
+  # Set the sharding of extended dtypes to be UNCONSTRAINED
+  # (i.e. XLA will choose) on aval.shape.
+  # For the trailing dims i.e. the dimension of key_shape on the base_array,
+  # the sharding is set to be REPLICATED always.
+  # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
+  # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
+  # The below custom call achieves the sharding like above example.
+  return wrap_with_sharding_op(
+      ctx, val, aval, xc.HloSharding.replicate().to_proto(),
+      unspecified_dims=set(range(aval.ndim)))
+
+
 def _emit_lowering_rule_as_fun(lowering_rule,
                                ctx: LoweringRuleContext) -> func_dialect.FuncOp:
   """Emits the contents of a lowering rule as a private function."""
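In PartitionSpec terms, the annotation attached by the new replicate_trailing_dims above (for the (8, 2) key example in its comment) amounts to leaving the logical dims unconstrained and pinning the trailing key-data dim to replicated; a minimal sketch using the public API:

from jax.sharding import PartitionSpec as P

# Logical dims: XLA may choose (UNCONSTRAINED); trailing key-data dim: replicated (None).
spec = P(P.UNCONSTRAINED, P.UNCONSTRAINED, None)
print(spec)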
@@ -2563,20 +2563,10 @@ def _register_out_sharding_handler(
 ) -> None:
   _orig_out_sharding_handlers[sharding_cls] = handler


-def _gspmd_to_named_sharding_via_mesh(
-    out_s: sharding_impls.GSPMDSharding,
-    mesh: Mesh) -> sharding_impls.NamedSharding:
-  parsed_pspec = sharding_impls.parse_flatten_op_sharding(
-      out_s._hlo_sharding, mesh)[0]
-  return create_mesh_pspec_sharding(
-      mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
-      out_s.memory_kind)
-
 def _gspmd_to_named_sharding(
     out_s: sharding_impls.GSPMDSharding,
     orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
-  return _gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
+  return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)

 _register_out_sharding_handler(
     sharding_impls.NamedSharding, _gspmd_to_named_sharding)
@@ -2793,7 +2783,7 @@ def _maybe_get_and_check_in_shardings(
     if is_unspecified(orig):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
-        xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
+        xla_s = sharding_impls.logical_sharding(aval, xla_s)
       new_in_shardings.append(xla_s)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
@@ -2829,7 +2819,7 @@ def _maybe_get_and_check_out_shardings(
     if is_unspecified(orig):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
-        xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
+        xla_s = sharding_impls.logical_sharding(aval, xla_s)
       new_out_shardings.append(xla_s)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
@@ -3120,7 +3110,7 @@ class MeshExecutable(stages.XlaExecutable):
     kept_var_bitvec = [i in self._kept_var_idx
                        for i in range(len(args_flat))]
     in_shardings = [
-        a.dtype._rules.physical_sharding(a, s)
+        sharding_impls.physical_sharding(a, s)
         if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
         else s
         for s, a in zip(self._in_shardings, self.in_avals)
@@ -3186,14 +3176,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
   return in_shardings, out_shardings, committed, tuple(local_devices)


-@util.cache()
-def create_mesh_pspec_sharding(
-    mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
-    memory_kind: str | None = None) -> sharding_impls.NamedSharding:
-  if pspec is None:
-    pspec, parsed_pspec = PartitionSpec(), None
-  return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
-                                      memory_kind=memory_kind)
+create_mesh_pspec_sharding = sharding_impls.create_mesh_pspec_sharding


 def check_device_backend_on_shardings(shardings) -> bool:
@@ -5209,14 +5209,6 @@ class BIntRules:
       return core.DArray(aval, phys_handler(bufs))
     return handler

-  @staticmethod
-  def logical_sharding(aval, phys_sharding):
-    return phys_sharding
-
-  @staticmethod
-  def physical_sharding(aval, sharding):
-    return sharding
-
   @staticmethod
   def convert_from(bint_dtype, other_dtype) -> bool:
     return other_dtype in (np.dtype('int32'), np.dtype('int64'))
@@ -5225,12 +5217,5 @@ class BIntRules:
   def convert_to(other_dtype, bint_dtype) -> bool:
     return other_dtype in (np.dtype('int32'), np.dtype('int64'))

-  @staticmethod
-  def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
-    return val
-
-  @staticmethod
-  def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
-    pass
-
 core.bint._rules = BIntRules
@@ -269,7 +269,7 @@ def _get_fastpath_data(
   kept_var_bitvec = [i in executable._kept_var_idx
                      for i in range(len(args_flat))]
   in_shardings = [
-      a.dtype._rules.physical_sharding(a, s)
+      sharding_impls.physical_sharding(a, s)
       if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
      else s
      for s, a in zip(executable._in_shardings, executable.in_avals)
jax/_src/prng.py (126 changed lines)
@@ -33,11 +33,9 @@ from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import pretty_printer as pp
-from jax._src import sharding_specs
 from jax._src import source_info_util
 from jax._src import tree_util as tree_util_internal
 from jax._src import typing
-from jax._src import op_shardings
 from jax._src.api import jit, vmap
 from jax._src.dtypes import float0
 from jax._src.interpreters import ad
@@ -53,9 +51,8 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.array_methods import (
     _array_operators, _set_array_base_attributes, _IndexUpdateHelper)
-from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
-    NamedSharding, PmapSharding, GSPMDSharding, XLACompatibleSharding)
+    NamedSharding, PmapSharding, physical_sharding, logical_sharding)
 from jax._src.typing import Array
 from jax._src.util import safe_map, safe_zip
@@ -235,8 +232,7 @@ class PRNGKeyArray(jax.Array):

   @property
   def sharding(self):
-    phys_sharding = self._base_array.sharding
-    return KeyTyRules.logical_sharding(self.aval, phys_sharding)
+    return logical_sharding(self.aval, self._base_array.sharding)

   def _is_scalar(self):
     base_ndim = len(self._impl.key_shape)
@@ -324,53 +320,6 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape):
   base_ndim = len(impl.key_shape)
   return base_arr_shape[:-base_ndim]

-def make_key_array_phys_sharding(aval, sharding):
-  if dispatch.is_single_device_sharding(sharding):
-    return sharding
-  elif isinstance(sharding, PmapSharding):
-    key_shape = aval.dtype._impl.key_shape
-    trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape)
-    phys_sharding_spec = sharding_specs.ShardingSpec(
-        sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
-        mesh_mapping=sharding.sharding_spec.mesh_mapping)
-    return PmapSharding(devices=sharding.devices,
-                        sharding_spec=phys_sharding_spec)
-  elif isinstance(sharding, NamedSharding):
-    key_shape = aval.dtype._impl.key_shape
-    trailing_spec = [None] * len(key_shape)
-    return NamedSharding(
-        sharding.mesh,
-        PartitionSpec(*sharding.spec, *trailing_spec))
-  else:
-    hlos = sharding._to_xla_hlo_sharding(aval.ndim)
-    return GSPMDSharding(
-        sharding._device_assignment, physical_hlo_sharding(aval, hlos))
-
-
-def get_logical_gspmd_sharding(aval, phys_sharding):
-  key_shape = aval.dtype._impl.key_shape
-  phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
-      aval.ndim + len(key_shape))
-  partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
-      phys_hlo_sharding)
-  suffix = [] if num_replicas == 1 else [num_replicas]
-  # Create logical sharding by cutting off the replicated trailing dims.
-  logical_op_sharding = phys_hlo_sharding.to_proto().clone()
-  tad = partitions[:-len(key_shape)] + suffix
-  logical_op_sharding.tile_assignment_dimensions = tad
-  return GSPMDSharding(phys_sharding._device_assignment,
-                       xc.HloSharding.from_proto(logical_op_sharding))
-
-
-def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
-  key_shape = aval.dtype._impl.key_shape
-  new_op_sharding = hlo_sharding.to_proto().clone()
-  partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(hlo_sharding)
-  suffix = [] if num_replicas == 1 else [num_replicas]
-  tad = partitions + [1] * len(key_shape) + suffix
-  new_op_sharding.tile_assignment_dimensions = tad
-  return xc.HloSharding.from_proto(new_op_sharding)
-
-
 class KeyTyRules:
@@ -393,32 +342,6 @@ class KeyTyRules:
   def physical_const(val) -> Array:
     return val._base_array

-  @staticmethod
-  def physical_sharding(
-      aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
-    return make_key_array_phys_sharding(aval, sharding)
-
-  @staticmethod
-  def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
-    # The trailing dims should always be replicated.
-    aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval)
-
-    if dispatch.is_single_device_sharding(phys_sharding):
-      return phys_sharding
-    elif isinstance(phys_sharding, PmapSharding):
-      key_shape = aval.dtype._impl.key_shape
-      logical_sharding_spec = sharding_specs.ShardingSpec(
-          sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)],
-          mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
-      return PmapSharding(devices=phys_sharding.devices,
-                          sharding_spec=logical_sharding_spec)
-    elif isinstance(phys_sharding, NamedSharding):
-      logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
-      return pxla._gspmd_to_named_sharding_via_mesh(
-          logical_gs, phys_sharding.mesh)
-    else:
-      return get_logical_gspmd_sharding(aval, phys_sharding)
-
   @staticmethod
   def result_handler(sticky_device, aval):
     def handler(_, buf):
@@ -434,7 +357,7 @@ class KeyTyRules:

     # set up a grounded sharding (with a grounded sharding spec)
     if isinstance(sharding, (PmapSharding, NamedSharding)):
-      phys_sharding = make_key_array_phys_sharding(aval, sharding)
+      phys_sharding = physical_sharding(aval, sharding)
     else:
       assert False, f'impossible sharding {sharding} in local sharded result handler'
@@ -456,7 +379,7 @@ class KeyTyRules:
     phys_aval = core.physical_aval(aval)
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]

-    phys_sharding = make_key_array_phys_sharding(aval, out_sharding)
+    phys_sharding = physical_sharding(aval, out_sharding)
     phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
     def handler(bufs):
       return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
@@ -468,7 +391,7 @@ class KeyTyRules:
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
     phys_arrays = [random_unwrap(arr) for arr in arrays]

-    phys_sharding = make_key_array_phys_sharding(aval, sharding)
+    phys_sharding = physical_sharding(aval, sharding)
     phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
     phys_result = phys_handler(phys_arrays)
     return PRNGKeyArray(aval.dtype._impl, phys_result)
@@ -477,8 +400,9 @@ class KeyTyRules:
   def device_put_sharded(vals, aval, sharding, devices):
     physical_aval = core.physical_aval(aval)
     physical_buffers = tree_util.tree_map(random_unwrap, vals)
-    physical_sharding = make_key_array_phys_sharding(aval, sharding)
-    physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
+    phys_sharding = physical_sharding(aval, sharding)
+    physical_result = pxla.batched_device_put(physical_aval, phys_sharding,
+                                              physical_buffers, list(devices))
     return random_wrap(physical_result, impl=aval.dtype._impl)

   @staticmethod
@@ -486,37 +410,11 @@ class KeyTyRules:
     physical_aval = core.physical_aval(aval)
     assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
     physical_buf = random_unwrap(val)
-    physical_sharding = make_key_array_phys_sharding(aval, sharding)
-    physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
+    phys_sharding = physical_sharding(aval, sharding)
+    physical_result = pxla.batched_device_put(
+        physical_aval, phys_sharding, [physical_buf] * len(devices), devices)
     return random_wrap(physical_result, impl=aval.dtype._impl)

-  @staticmethod
-  def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
-    if isinstance(sharding, PmapSharding):
-      return
-    phys_aval = core.physical_aval(aval)
-    hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim)
-    partitions, _ = op_shardings.get_num_ways_dim_sharded(hlo_s)
-    num_trailing_dims = phys_aval.ndim - aval.ndim
-    if not all(i == 1 for i in partitions[-num_trailing_dims:]):
-      raise AssertionError(
-          "The trailing dims of extended dtypes should be replicated. Got"
-          f" sharding: {sharding}, partitions: {partitions}, "
-          f"num_trailing_dims: {num_trailing_dims}")
-
-  @staticmethod
-  def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
-    # Set the sharding of extended dtypes to be UNCONSTRAINED
-    # (i.e. XLA will choose) on aval.shape.
-    # For the trailing dims i.e. the dimension of key_shape on the base_array,
-    # the sharding is set to be REPLICATED always.
-    # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
-    # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
-    # The below custom call achieves the sharding like above example.
-    return mlir.wrap_with_sharding_op(
-        ctx, val, aval, xc.HloSharding.replicate().to_proto(),
-        unspecified_dims=set(range(aval.ndim)))
-
   @staticmethod
   def tangent_dtype(_):
     return dtypes.float0
@@ -571,7 +469,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

 def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
   arr = x._base_array
-  phys_sharding = make_key_array_phys_sharding(x.aval, sharding)
+  phys_sharding = physical_sharding(x.aval, sharding)
   return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
@@ -30,6 +30,7 @@ from jax._src import sharding_specs
 from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge
+from jax._src import core
 from jax._src.lib import xla_client as xc
 from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded,
                                     is_op_sharding_replicated,
@@ -1437,3 +1438,111 @@ def num_addressable_indices(
   })
   shard_size = tensor_sharding.shard_shape(global_shape)[dim]
   return shard_size * num_unique_slices


+def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
+  elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
+  new_op_sharding = hlo_sharding.to_proto().clone()
+  partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding)
+  suffix = [] if num_replicas == 1 else [num_replicas]
+  tad = partitions + [1] * elt_aval.ndim + suffix
+  new_op_sharding.tile_assignment_dimensions = tad
+  return xc.HloSharding.from_proto(new_op_sharding)
+
+def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
+  # Special case PmapSharding here because PmapSharding maps away an axis
+  # and needs to be handled separately.test_pjit_single_device_sharding_add
+  return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
+
+def make_key_array_phys_sharding(aval, sharding):
+  if is_single_device_sharding(sharding):
+    return sharding
+  elif isinstance(sharding, PmapSharding):
+    elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
+    trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim
+    phys_sharding_spec = sharding_specs.ShardingSpec(
+        sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
+        mesh_mapping=sharding.sharding_spec.mesh_mapping)
+    return PmapSharding(devices=sharding.devices,
+                        sharding_spec=phys_sharding_spec)
+  elif isinstance(sharding, NamedSharding):
+    elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
+    trailing_spec = [None] * elt_aval.ndim
+    return NamedSharding(
+        sharding.mesh,
+        PartitionSpec(*sharding.spec, *trailing_spec))
+  else:
+    hlos = sharding._to_xla_hlo_sharding(aval.ndim)
+    return GSPMDSharding(
+        sharding._device_assignment, physical_hlo_sharding(aval, hlos))
+
+
+def physical_sharding(
+    aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
+  return make_key_array_phys_sharding(aval, sharding)
+
+
+def get_logical_gspmd_sharding(aval, phys_sharding):
+  elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
+  phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
+      aval.ndim + elt_aval.ndim)
+  partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding)
+  suffix = [] if num_replicas == 1 else [num_replicas]
+  # Create logical sharding by cutting off the replicated trailing dims.
+  logical_op_sharding = phys_hlo_sharding.to_proto().clone()
+  tad = partitions[:-elt_aval.ndim] + suffix
+  logical_op_sharding.tile_assignment_dimensions = tad
+  return GSPMDSharding(phys_sharding._device_assignment,
+                       xc.HloSharding.from_proto(logical_op_sharding))
+
+def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
+  if isinstance(sharding, PmapSharding):
+    return
+  phys_aval = core.physical_aval(aval)
+  hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim)
+  partitions, _ = get_num_ways_dim_sharded(hlo_s)
+  num_trailing_dims = phys_aval.ndim - aval.ndim
+  if not all(i == 1 for i in partitions[-num_trailing_dims:]):
+    raise AssertionError(
+        "The trailing dims of extended dtypes should be replicated. Got"
+        f" sharding: {sharding}, partitions: {partitions}, "
+        f"num_trailing_dims: {num_trailing_dims}")
+
+def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
+  # The trailing dims should always be replicated.
+  check_replicated_trailing_dims(phys_sharding, aval)
+
+  if is_single_device_sharding(phys_sharding):
+    return phys_sharding
+  elif isinstance(phys_sharding, PmapSharding):
+    elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
+    logical_sharding_spec = sharding_specs.ShardingSpec(
+        sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim],
+        mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
+    return PmapSharding(devices=phys_sharding.devices,
+                        sharding_spec=logical_sharding_spec)
+  elif isinstance(phys_sharding, NamedSharding):
+    logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
+    return _gspmd_to_named_sharding_via_mesh(
+        logical_gs, phys_sharding.mesh)
+  else:
+    return get_logical_gspmd_sharding(aval, phys_sharding)
+
+
+@util.cache()
+def create_mesh_pspec_sharding(
+    mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
+    memory_kind: str | None = None) -> NamedSharding:
+  if pspec is None:
+    pspec, parsed_pspec = PartitionSpec(), None
+  return NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
+                       memory_kind=memory_kind)
+
+
+def _gspmd_to_named_sharding_via_mesh(
+    out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding:
+  parsed_pspec = parse_flatten_op_sharding(
+      out_s._hlo_sharding, mesh)[0]
+  return create_mesh_pspec_sharding(
+      mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
+      out_s.memory_kind)
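To make the physical/logical round trip above concrete for the NamedSharding case, here is a hedged sketch in plain PartitionSpec terms (the helper names are hypothetical; the real physical_sharding/logical_sharding also handle single-device, Pmap and GSPMD shardings and read the number of trailing dims from the dtype's physical_element_aval):

from jax.sharding import PartitionSpec as P

def _physical_spec(logical_spec, n_trailing=1):
  # Append replicated (None) entries for the trailing key-data dims,
  # mirroring make_key_array_phys_sharding for NamedSharding.
  return P(*logical_spec, *([None] * n_trailing))

def _logical_spec(phys_spec, n_trailing=1):
  # Drop the replicated trailing dims again, mirroring logical_sharding.
  return P(*tuple(phys_spec)[:-n_trailing])

assert _physical_spec(P('x')) == P('x', None)
assert _logical_spec(P('x', None)) == P('x')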
@@ -653,7 +653,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   axes = {name: i for i, ns in names.items() for name in ns}
   ns = _make_scoped_manual_sharding(ctx, mesh, axes)
   if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
-    ns = aval_in.dtype._rules.physical_sharding(aval_in, ns)
+    ns = sharding_impls.physical_sharding(aval_in, ns)
     aval_in = core.physical_aval(aval_in)
   shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
   unspecified = set(range(aval_in.ndim)) if auto else set()
@@ -667,7 +667,7 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   axes = {name: i for i, ns in names.items() for name in ns}
   ns = _make_scoped_manual_sharding(ctx, mesh, axes)
   if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
-    ns = aval_out.dtype._rules.physical_sharding(aval_out, ns)
+    ns = sharding_impls.physical_sharding(aval_out, ns)
     aval_out = core.physical_aval(aval_out)
   unspecified = set(range(aval_out.ndim)) if auto else set()
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
@@ -578,15 +578,6 @@ class EArrayTest(jtu.JaxTestCase):
       def physical_element_aval(foo_dtype):
         return core.ShapedArray((), dtypes.dtype('float32'))

-      @staticmethod
-      def replicate_trailing_dims(ctx, val, aval):
-        del ctx, aval
-        return val
-
-      @staticmethod
-      def logical_sharding(aval, phys_sharding):
-        return phys_sharding
-
       @staticmethod
       def global_sharded_result_handler(aval, out_sharding, committed):
         phys_sharding = out_sharding  # unlike KeyTyRules, assume same shape
@@ -595,10 +586,6 @@ class EArrayTest(jtu.JaxTestCase):
         phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
         return lambda bufs: earray.EArray(aval, phys_handler(bufs))

-      @staticmethod
-      def physical_sharding(aval, sharding):
-        return sharding  # unlike KeyTyRules, assume same shape
-
     @dataclasses.dataclass(frozen=True)
     class FooTy(dtypes.ExtendedDType):
       name: str = 'foo'
@@ -3315,14 +3315,6 @@ class FooTyRules:
   def physical_element_aval(dtype) -> core.ShapedArray:
     return core.ShapedArray((2,), jnp.dtype('uint32'))

-  @staticmethod
-  def logical_sharding(aval, phys_sharding):
-    return phys_sharding
-
-  @staticmethod
-  def physical_sharding(aval, sharding):
-    return sharding
-
   @staticmethod
   def result_handler(sticky_device, aval):
     def handler(_, buf):
@@ -3341,14 +3333,6 @@ class FooTyRules:
       return FooArray(aval.shape, buf)
     return handler

-  @staticmethod
-  def replicate_trailing_dims(ctx, val, aval):
-    return val
-
-  @staticmethod
-  def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
-    pass
-

 class FooTy(dtypes.ExtendedDType):
   type = dtypes.extended