Simplify extended dtype rules, part 1. Start by removing sharding-specific rules from extended dtypes; we always want to replicate the trailing dims that extended dtypes introduce.

PiperOrigin-RevId: 639920049
Yash Katariya 2024-06-03 14:52:08 -07:00 committed by jax authors
parent 8a1445a038
commit 1273028018
12 changed files with 151 additions and 196 deletions
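
For context, a minimal sketch (not taken from the diff) of the behavior this change standardizes, assuming at least two devices and an illustrative mesh axis name 'x': an extended dtype such as a typed PRNG key adds trailing dims to its physical base array, and those trailing dims are always kept replicated.

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:2], ('x',))             # assumes >= 2 devices
keys = jax.random.split(jax.random.key(0), 8)      # logical shape (8,), extended key dtype
keys = jax.device_put(keys, NamedSharding(mesh, P('x')))
# The physical base array carries an extra trailing key-data dim; only the
# leading logical dim is sharded over 'x', the trailing dim stays replicated.
print(jax.random.key_data(keys).shape)             # (8, 2) for the default threefry impl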

View File

@ -789,6 +789,7 @@ pytype_strict_library(
srcs = ["_src/sharding_impls.py"],
deps = [
":config",
":core",
":mesh",
":op_shardings",
":partition_spec",

View File

@ -48,8 +48,8 @@ from jax._src.monitoring import record_event_duration_secs
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
GSPMDSharding, TransferToMemoryKind)
SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
GSPMDSharding, TransferToMemoryKind, is_single_device_sharding)
from jax._src.layout import Layout, DeviceLocalLayout
@ -163,12 +163,6 @@ def wait_for_tokens():
runtime_tokens.block_until_ready()
def is_single_device_sharding(sharding: Sharding) -> bool:
# Special case PmapSharding here because PmapSharding maps away an axis
# and needs to be handled separately.
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
@contextlib.contextmanager
def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
if _on_exit:

View File

@ -20,6 +20,7 @@ from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import tree_util
from jax._src import sharding_impls
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.util import safe_zip, safe_map
@ -80,7 +81,7 @@ class EArray(basearray.Array):
@property
def sharding(self):
phys_sharding = self._data.sharding
return self.aval.dtype._rules.logical_sharding(self.aval, phys_sharding)
return sharding_impls.logical_sharding(self.aval, phys_sharding)
# TODO(mattjj): not implemented below here, need more methods from ArrayImpl
@ -99,7 +100,7 @@ class EArray(basearray.Array):
def _earray_shard_arg_handler(x, sharding):
arr = x._data
phys_sharding = x.aval.dtype._rules.physical_sharding(x.aval, sharding)
phys_sharding = sharding_impls.physical_sharding(x.aval, sharding)
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

View File

@ -818,7 +818,7 @@ def _to_physical_op_sharding(
return _to_physical_op_sharding(aval.inner_aval, sharding)
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
if dtypes.issubdtype(aval.dtype, dtypes.extended):
sharding = aval.dtype._rules.physical_sharding(aval, sharding)
sharding = sharding_impls.physical_sharding(aval, sharding)
aval = core.physical_aval(aval)
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore
@ -1376,7 +1376,7 @@ def lower_jaxpr_to_fun(
if ir_arg_shardings is not None and name == "main":
flat_args = [
a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # pytype: disable=attribute-error
replicate_trailing_dims(entry_lowering_ctx, o, a)
if (a is not core.abstract_token and
dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error
for o, s, a in zip(flat_args, ir_arg_shardings, input_avals)
@ -1417,7 +1417,7 @@ def lower_jaxpr_to_fun(
if ir_result_shardings is not None and name == "main":
flat_outputs = [
a.dtype._rules.replicate_trailing_dims(entry_lowering_ctx, o, a) # pytype: disable=attribute-error
replicate_trailing_dims(entry_lowering_ctx, o, a)
if (a is not core.abstract_token and
dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error
for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals)
@ -1441,6 +1441,19 @@ def wrap_with_memory_kind(
return op.result
def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# Set the sharding of extended dtypes to be UNCONSTRAINED
# (i.e. XLA will choose) on aval.shape.
# For the trailing dims, i.e. the dimensions of key_shape on the base_array,
# the sharding is always set to REPLICATED.
# For example: if key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
# then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
# The custom call below achieves the sharding in the example above.
return wrap_with_sharding_op(
ctx, val, aval, xc.HloSharding.replicate().to_proto(),
unspecified_dims=set(range(aval.ndim)))
def _emit_lowering_rule_as_fun(lowering_rule,
ctx: LoweringRuleContext) -> func_dialect.FuncOp:
"""Emits the contents of a lowering rule as a private function."""

View File

@ -2563,20 +2563,10 @@ def _register_out_sharding_handler(
) -> None:
_orig_out_sharding_handlers[sharding_cls] = handler
def _gspmd_to_named_sharding_via_mesh(
out_s: sharding_impls.GSPMDSharding,
mesh: Mesh) -> sharding_impls.NamedSharding:
parsed_pspec = sharding_impls.parse_flatten_op_sharding(
out_s._hlo_sharding, mesh)[0]
return create_mesh_pspec_sharding(
mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
out_s.memory_kind)
def _gspmd_to_named_sharding(
out_s: sharding_impls.GSPMDSharding,
orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
return _gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
_register_out_sharding_handler(
sharding_impls.NamedSharding, _gspmd_to_named_sharding)
@ -2793,7 +2783,7 @@ def _maybe_get_and_check_in_shardings(
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
xla_s = sharding_impls.logical_sharding(aval, xla_s)
new_in_shardings.append(xla_s)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
@ -2829,7 +2819,7 @@ def _maybe_get_and_check_out_shardings(
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
xla_s = sharding_impls.logical_sharding(aval, xla_s)
new_out_shardings.append(xla_s)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
@ -3120,7 +3110,7 @@ class MeshExecutable(stages.XlaExecutable):
kept_var_bitvec = [i in self._kept_var_idx
for i in range(len(args_flat))]
in_shardings = [
a.dtype._rules.physical_sharding(a, s)
sharding_impls.physical_sharding(a, s)
if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
else s
for s, a in zip(self._in_shardings, self.in_avals)
@ -3186,14 +3176,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
return in_shardings, out_shardings, committed, tuple(local_devices)
@util.cache()
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
memory_kind: str | None = None) -> sharding_impls.NamedSharding:
if pspec is None:
pspec, parsed_pspec = PartitionSpec(), None
return sharding_impls.NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
memory_kind=memory_kind)
create_mesh_pspec_sharding = sharding_impls.create_mesh_pspec_sharding
def check_device_backend_on_shardings(shardings) -> bool:
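
A minimal usage sketch of the moved helper (pxla.create_mesh_pspec_sharding is now an alias for sharding_impls.create_mesh_pspec_sharding); the single-axis mesh and axis name 'x' are illustrative.

import jax
from jax._src import sharding_impls
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], ('x',))
s1 = sharding_impls.create_mesh_pspec_sharding(mesh, P('x'))  # NamedSharding over 'x'
s2 = sharding_impls.create_mesh_pspec_sharding(mesh, None)    # pspec=None -> PartitionSpec()
print(s1.spec, s2.spec)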

View File

@ -5209,14 +5209,6 @@ class BIntRules:
return core.DArray(aval, phys_handler(bufs))
return handler
@staticmethod
def logical_sharding(aval, phys_sharding):
return phys_sharding
@staticmethod
def physical_sharding(aval, sharding):
return sharding
@staticmethod
def convert_from(bint_dtype, other_dtype) -> bool:
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
@ -5225,12 +5217,5 @@ class BIntRules:
def convert_to(other_dtype, bint_dtype) -> bool:
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
@staticmethod
def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
return val
@staticmethod
def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
pass
core.bint._rules = BIntRules

View File

@ -269,7 +269,7 @@ def _get_fastpath_data(
kept_var_bitvec = [i in executable._kept_var_idx
for i in range(len(args_flat))]
in_shardings = [
a.dtype._rules.physical_sharding(a, s)
sharding_impls.physical_sharding(a, s)
if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
else s
for s, a in zip(executable._in_shardings, executable.in_avals)

View File

@ -33,11 +33,9 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import sharding_specs
from jax._src import source_info_util
from jax._src import tree_util as tree_util_internal
from jax._src import typing
from jax._src import op_shardings
from jax._src.api import jit, vmap
from jax._src.dtypes import float0
from jax._src.interpreters import ad
@ -53,9 +51,8 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.array_methods import (
_array_operators, _set_array_base_attributes, _IndexUpdateHelper)
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
NamedSharding, PmapSharding, GSPMDSharding, XLACompatibleSharding)
NamedSharding, PmapSharding, physical_sharding, logical_sharding)
from jax._src.typing import Array
from jax._src.util import safe_map, safe_zip
@ -235,8 +232,7 @@ class PRNGKeyArray(jax.Array):
@property
def sharding(self):
phys_sharding = self._base_array.sharding
return KeyTyRules.logical_sharding(self.aval, phys_sharding)
return logical_sharding(self.aval, self._base_array.sharding)
def _is_scalar(self):
base_ndim = len(self._impl.key_shape)
@ -324,53 +320,6 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape):
base_ndim = len(impl.key_shape)
return base_arr_shape[:-base_ndim]
def make_key_array_phys_sharding(aval, sharding):
if dispatch.is_single_device_sharding(sharding):
return sharding
elif isinstance(sharding, PmapSharding):
key_shape = aval.dtype._impl.key_shape
trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape)
phys_sharding_spec = sharding_specs.ShardingSpec(
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
mesh_mapping=sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=sharding.devices,
sharding_spec=phys_sharding_spec)
elif isinstance(sharding, NamedSharding):
key_shape = aval.dtype._impl.key_shape
trailing_spec = [None] * len(key_shape)
return NamedSharding(
sharding.mesh,
PartitionSpec(*sharding.spec, *trailing_spec))
else:
hlos = sharding._to_xla_hlo_sharding(aval.ndim)
return GSPMDSharding(
sharding._device_assignment, physical_hlo_sharding(aval, hlos))
def get_logical_gspmd_sharding(aval, phys_sharding):
key_shape = aval.dtype._impl.key_shape
phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
aval.ndim + len(key_shape))
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
phys_hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
# Create logical sharding by cutting off the replicated trailing dims.
logical_op_sharding = phys_hlo_sharding.to_proto().clone()
tad = partitions[:-len(key_shape)] + suffix
logical_op_sharding.tile_assignment_dimensions = tad
return GSPMDSharding(phys_sharding._device_assignment,
xc.HloSharding.from_proto(logical_op_sharding))
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
key_shape = aval.dtype._impl.key_shape
new_op_sharding = hlo_sharding.to_proto().clone()
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * len(key_shape) + suffix
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)
class KeyTyRules:
@ -393,32 +342,6 @@ class KeyTyRules:
def physical_const(val) -> Array:
return val._base_array
@staticmethod
def physical_sharding(
aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
return make_key_array_phys_sharding(aval, sharding)
@staticmethod
def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
# The trailing dims should always be replicated.
aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval)
if dispatch.is_single_device_sharding(phys_sharding):
return phys_sharding
elif isinstance(phys_sharding, PmapSharding):
key_shape = aval.dtype._impl.key_shape
logical_sharding_spec = sharding_specs.ShardingSpec(
sharding=phys_sharding.sharding_spec.sharding[:-len(key_shape)],
mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=phys_sharding.devices,
sharding_spec=logical_sharding_spec)
elif isinstance(phys_sharding, NamedSharding):
logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
return pxla._gspmd_to_named_sharding_via_mesh(
logical_gs, phys_sharding.mesh)
else:
return get_logical_gspmd_sharding(aval, phys_sharding)
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
@ -434,7 +357,7 @@ class KeyTyRules:
# set up a grounded sharding (with a grounded sharding spec)
if isinstance(sharding, (PmapSharding, NamedSharding)):
phys_sharding = make_key_array_phys_sharding(aval, sharding)
phys_sharding = physical_sharding(aval, sharding)
else:
assert False, f'impossible sharding {sharding} in local sharded result handler'
@ -456,7 +379,7 @@ class KeyTyRules:
phys_aval = core.physical_aval(aval)
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
phys_sharding = make_key_array_phys_sharding(aval, out_sharding)
phys_sharding = physical_sharding(aval, out_sharding)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
def handler(bufs):
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
@ -468,7 +391,7 @@ class KeyTyRules:
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
phys_arrays = [random_unwrap(arr) for arr in arrays]
phys_sharding = make_key_array_phys_sharding(aval, sharding)
phys_sharding = physical_sharding(aval, sharding)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
phys_result = phys_handler(phys_arrays)
return PRNGKeyArray(aval.dtype._impl, phys_result)
@ -477,8 +400,9 @@ class KeyTyRules:
def device_put_sharded(vals, aval, sharding, devices):
physical_aval = core.physical_aval(aval)
physical_buffers = tree_util.tree_map(random_unwrap, vals)
physical_sharding = make_key_array_phys_sharding(aval, sharding)
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
phys_sharding = physical_sharding(aval, sharding)
physical_result = pxla.batched_device_put(physical_aval, phys_sharding,
physical_buffers, list(devices))
return random_wrap(physical_result, impl=aval.dtype._impl)
@staticmethod
@ -486,37 +410,11 @@ class KeyTyRules:
physical_aval = core.physical_aval(aval)
assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
physical_buf = random_unwrap(val)
physical_sharding = make_key_array_phys_sharding(aval, sharding)
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
phys_sharding = physical_sharding(aval, sharding)
physical_result = pxla.batched_device_put(
physical_aval, phys_sharding, [physical_buf] * len(devices), devices)
return random_wrap(physical_result, impl=aval.dtype._impl)
@staticmethod
def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
if isinstance(sharding, PmapSharding):
return
phys_aval = core.physical_aval(aval)
hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim)
partitions, _ = op_shardings.get_num_ways_dim_sharded(hlo_s)
num_trailing_dims = phys_aval.ndim - aval.ndim
if not all(i == 1 for i in partitions[-num_trailing_dims:]):
raise AssertionError(
"The trailing dims of extended dtypes should be replicated. Got"
f" sharding: {sharding}, partitions: {partitions}, "
f"num_trailing_dims: {num_trailing_dims}")
@staticmethod
def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# Set the sharding of extended dtypes to be UNCONSTRAINED
# (i.e. XLA will choose) on aval.shape.
# For the trailing dims i.e. the dimension of key_shape on the base_array,
# the sharding is set to be REPLICATED always.
# For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
# then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
# The below custom call achieves the sharding like above example.
return mlir.wrap_with_sharding_op(
ctx, val, aval, xc.HloSharding.replicate().to_proto(),
unspecified_dims=set(range(aval.ndim)))
@staticmethod
def tangent_dtype(_):
return dtypes.float0
@ -571,7 +469,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
arr = x._base_array
phys_sharding = make_key_array_phys_sharding(x.aval, sharding)
phys_sharding = physical_sharding(x.aval, sharding)
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
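
A brief sketch of the user-visible effect of this delegation, assuming at least two devices: the .sharding property of a key array reports the logical sharding, while the shard_arg handler applies the physical sharding (trailing key-data dim replicated) when the key is passed through jit.

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:2], ('x',))                         # assumes >= 2 devices
keys = jax.device_put(jax.random.split(jax.random.key(0), 8),
                      NamedSharding(mesh, P('x')))
print(keys.sharding)              # logical sharding, via logical_sharding above
out = jax.jit(lambda k: k)(keys)  # key_array_shard_arg_handler shards the base array
print(out.sharding)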

View File

@ -30,6 +30,7 @@ from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src import core
from jax._src.lib import xla_client as xc
from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded,
is_op_sharding_replicated,
@ -1437,3 +1438,111 @@ def num_addressable_indices(
})
shard_size = tensor_sharding.shard_shape(global_shape)[dim]
return shard_size * num_unique_slices
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
new_op_sharding = hlo_sharding.to_proto().clone()
partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * elt_aval.ndim + suffix
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)
def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
# Special case PmapSharding here because PmapSharding maps away an axis
# and needs to be handled separately.
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
def make_key_array_phys_sharding(aval, sharding):
if is_single_device_sharding(sharding):
return sharding
elif isinstance(sharding, PmapSharding):
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim
phys_sharding_spec = sharding_specs.ShardingSpec(
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
mesh_mapping=sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=sharding.devices,
sharding_spec=phys_sharding_spec)
elif isinstance(sharding, NamedSharding):
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
trailing_spec = [None] * elt_aval.ndim
return NamedSharding(
sharding.mesh,
PartitionSpec(*sharding.spec, *trailing_spec))
else:
hlos = sharding._to_xla_hlo_sharding(aval.ndim)
return GSPMDSharding(
sharding._device_assignment, physical_hlo_sharding(aval, hlos))
def physical_sharding(
aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
return make_key_array_phys_sharding(aval, sharding)
def get_logical_gspmd_sharding(aval, phys_sharding):
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
aval.ndim + elt_aval.ndim)
partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
# Create logical sharding by cutting off the replicated trailing dims.
logical_op_sharding = phys_hlo_sharding.to_proto().clone()
tad = partitions[:-elt_aval.ndim] + suffix
logical_op_sharding.tile_assignment_dimensions = tad
return GSPMDSharding(phys_sharding._device_assignment,
xc.HloSharding.from_proto(logical_op_sharding))
def check_replicated_trailing_dims(sharding: XLACompatibleSharding, aval):
if isinstance(sharding, PmapSharding):
return
phys_aval = core.physical_aval(aval)
hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim)
partitions, _ = get_num_ways_dim_sharded(hlo_s)
num_trailing_dims = phys_aval.ndim - aval.ndim
if not all(i == 1 for i in partitions[-num_trailing_dims:]):
raise AssertionError(
"The trailing dims of extended dtypes should be replicated. Got"
f" sharding: {sharding}, partitions: {partitions}, "
f"num_trailing_dims: {num_trailing_dims}")
def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
# The trailing dims should always be replicated.
check_replicated_trailing_dims(phys_sharding, aval)
if is_single_device_sharding(phys_sharding):
return phys_sharding
elif isinstance(phys_sharding, PmapSharding):
elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype)
logical_sharding_spec = sharding_specs.ShardingSpec(
sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim],
mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=phys_sharding.devices,
sharding_spec=logical_sharding_spec)
elif isinstance(phys_sharding, NamedSharding):
logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
return _gspmd_to_named_sharding_via_mesh(
logical_gs, phys_sharding.mesh)
else:
return get_logical_gspmd_sharding(aval, phys_sharding)
@util.cache()
def create_mesh_pspec_sharding(
mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
memory_kind: str | None = None) -> NamedSharding:
if pspec is None:
pspec, parsed_pspec = PartitionSpec(), None
return NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
memory_kind=memory_kind)
def _gspmd_to_named_sharding_via_mesh(
out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding:
parsed_pspec = parse_flatten_op_sharding(
out_s._hlo_sharding, mesh)[0]
return create_mesh_pspec_sharding(
mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
out_s.memory_kind)
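
A worked example, with illustrative numbers only, of the tile-assignment bookkeeping shared by physical_hlo_sharding and get_logical_gspmd_sharding above: a key-like aval with one trailing physical dim, sharded 4 ways over its first logical dim and unsharded elsewhere.

# Sketch of the physical <-> logical round trip on tile_assignment_dimensions.
elt_ndim = 1                                 # trailing dims added by the extended dtype
partitions, num_replicas = [4, 1], 1         # logical dims sharded 4 x 1, no extra replicas
suffix = [] if num_replicas == 1 else [num_replicas]
physical_tad = partitions + [1] * elt_ndim + suffix   # [4, 1, 1]: trailing dim replicated
logical_tad = physical_tad[:-elt_ndim] + suffix       # [4, 1]: trailing dim dropped again
assert logical_tad == partitions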

View File

@ -653,7 +653,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
axes = {name: i for i, ns in names.items() for name in ns}
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
ns = aval_in.dtype._rules.physical_sharding(aval_in, ns)
ns = sharding_impls.physical_sharding(aval_in, ns)
aval_in = core.physical_aval(aval_in)
shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
unspecified = set(range(aval_in.ndim)) if auto else set()
@ -667,7 +667,7 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
axes = {name: i for i, ns in names.items() for name in ns}
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
ns = aval_out.dtype._rules.physical_sharding(aval_out, ns)
ns = sharding_impls.physical_sharding(aval_out, ns)
aval_out = core.physical_aval(aval_out)
unspecified = set(range(aval_out.ndim)) if auto else set()
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)

View File

@ -578,15 +578,6 @@ class EArrayTest(jtu.JaxTestCase):
def physical_element_aval(foo_dtype):
return core.ShapedArray((), dtypes.dtype('float32'))
@staticmethod
def replicate_trailing_dims(ctx, val, aval):
del ctx, aval
return val
@staticmethod
def logical_sharding(aval, phys_sharding):
return phys_sharding
@staticmethod
def global_sharded_result_handler(aval, out_sharding, committed):
phys_sharding = out_sharding # unlike KeyTyRules, assume same shape
@ -595,10 +586,6 @@ class EArrayTest(jtu.JaxTestCase):
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
return lambda bufs: earray.EArray(aval, phys_handler(bufs))
@staticmethod
def physical_sharding(aval, sharding):
return sharding # unlike KeyTyRules, assume same shape
@dataclasses.dataclass(frozen=True)
class FooTy(dtypes.ExtendedDType):
name: str = 'foo'

View File

@ -3315,14 +3315,6 @@ class FooTyRules:
def physical_element_aval(dtype) -> core.ShapedArray:
return core.ShapedArray((2,), jnp.dtype('uint32'))
@staticmethod
def logical_sharding(aval, phys_sharding):
return phys_sharding
@staticmethod
def physical_sharding(aval, sharding):
return sharding
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
@ -3341,14 +3333,6 @@ class FooTyRules:
return FooArray(aval.shape, buf)
return handler
@staticmethod
def replicate_trailing_dims(ctx, val, aval):
return val
@staticmethod
def check_replicated_trailing_dims(sharding: jax.sharding.GSPMDSharding, aval):
pass
class FooTy(dtypes.ExtendedDType):
type = dtypes.extended