Allow sharding propagation to inputs for PRNG keys whose sharding is not specified.

For extended dtypes, convert the shardings returned by XLA (when propagation is enabled for inputs and outputs) into user-facing shardings, which allows `are_out_shardings_from_xla` to be removed.

PiperOrigin-RevId: 611246986
Yash Katariya 2024-02-28 15:21:50 -08:00 committed by jax authors
parent 2f7c36c763
commit 217f08236e
8 changed files with 73 additions and 99 deletions
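To make the intent concrete, here is a minimal user-level sketch (not part of this commit; the mesh, axis name, and function are illustrative assumptions): a typed PRNG key array passed to jit without an explicit sharding can now have its input sharding chosen by XLA's sharding propagation instead of being pinned before compilation.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
n = jax.device_count()
keys = jax.random.split(jax.random.key(0), n)   # key array with an extended dtype
x = jax.device_put(jnp.arange(n, dtype=jnp.float32), NamedSharding(mesh, P('x')))

@jax.jit
def f(keys, x):
  return jax.vmap(jax.random.normal)(keys) + x  # one sample per key

# `keys` carries no user-specified sharding; with this change XLA may propagate
# a sharding to it (e.g. following x along 'x') rather than it being forced to
# a replicated GSPMD sharding up front.
out = f(keys, x)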

View File

@ -926,13 +926,12 @@ def _array_shard_arg(x, sharding):
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
def _array_global_result_handler(global_aval, out_sharding, committed,
is_out_sharding_from_xla):
def _array_global_result_handler(global_aval, out_sharding, committed):
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed, is_out_sharding_from_xla)
global_aval, out_sharding, committed)
return xc.array_result_handler(
global_aval, out_sharding, committed=committed, _skip_checks=True
)

View File

@ -322,7 +322,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
result_handler = pxla.global_aval_to_result_handler(aval, s, committed, False)
result_handler = pxla.global_aval_to_result_handler(aval, s, committed)
return result_handler(pxla.shard_arg(x, s))
def _override_get_device_assignment(sharding, *args, **kwargs):

View File

@ -224,8 +224,7 @@ local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
def global_aval_to_result_handler(
aval: core.AbstractValue, out_sharding, committed: bool,
is_out_sharding_from_xla: bool
aval: core.AbstractValue, out_sharding, committed: bool
) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
@ -235,8 +234,6 @@ def global_aval_to_result_handler(
Used for creating GSDAs.
global_mesh: The global device mesh that generated this output. Used
for creating GSDAs.
is_out_sharding_from_xla: True, if the out_sharding comes from XLA i.e.
the sharding is extracted from the HLO.
Returns:
A function for handling the Buffers that will eventually be produced
@ -244,8 +241,7 @@ def global_aval_to_result_handler(
to the user, e.g. an Array.
"""
try:
return global_result_handlers[type(aval)](
aval, out_sharding, committed, is_out_sharding_from_xla)
return global_result_handlers[type(aval)](aval, out_sharding, committed)
except KeyError as err:
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err
@ -1139,12 +1135,10 @@ def local_avals_to_results_handler(
def global_avals_to_results_handler(
global_out_avals: Sequence[ShapedArray],
shardings: Sequence[sharding_impls.XLACompatibleSharding],
committed: bool,
are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
committed: bool) -> ResultsHandler:
handlers = [
global_aval_to_result_handler(global_aval, s, committed, x)
for global_aval, s, x in safe_zip(global_out_avals, shardings,
are_out_shardings_from_xla)
global_aval_to_result_handler(global_aval, s, committed)
for global_aval, s in safe_zip(global_out_avals, shardings)
]
return ResultsHandler(handlers, shardings, global_out_avals)
@ -2010,12 +2004,6 @@ def lower_sharding_computation(
if xla_extension_version < 240 or hasattr(backend, "compile_replicated"):
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
# TODO(yashkatariya): Allow prng sharding inference by XLA. Enable this after
# output sharding of XLA is partially constrained on the trailing dimensions.
in_shardings = tuple(
gs if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
else i for i, a in safe_zip(in_shardings, global_in_avals))
da_object = _create_da_object(tuple(device_assignment))
all_default_mem_kind = are_all_shardings_default_mem_kind(
@ -2466,11 +2454,10 @@ _register_out_sharding_handler(
def _get_out_sharding_from_orig_sharding(
out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla):
out_shardings, out_avals, orig_in_s, orig_aval):
out = []
orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
are_out_sharding_from_xla):
for o, out_aval in safe_zip(out_shardings, out_avals):
if isinstance(o, sharding_impls.GSPMDSharding):
try:
# Only return the same input sharding object if the OpShardings and
@ -2482,21 +2469,19 @@ def _get_out_sharding_from_orig_sharding(
and sharding_impls.are_op_shardings_equal(
o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
and o.memory_kind == orig_in_s.memory_kind):
out.append((orig_in_s, False))
out.append(orig_in_s)
else:
out.append((orig_handler(o, orig_in_s), False))
out.append(orig_handler(o, orig_in_s))
except:
out.append((o, from_xla))
out.append(o)
else:
out.append((o, from_xla))
out.append(o)
return out
def maybe_get_orig_out_sharding(
in_shardings, out_shardings, are_out_shardings_from_xla, in_avals,
out_avals):
in_shardings, out_shardings, in_avals, out_avals):
if all(hasattr(o, '_original_sharding') for o in out_shardings):
return ([o._original_sharding for o in out_shardings],
(False,) * len(out_shardings))
return [o._original_sharding for o in out_shardings]
orig_in_s = None
orig_aval = None
@ -2507,10 +2492,10 @@ def maybe_get_orig_out_sharding(
orig_aval = aval
break
if orig_in_s is not None:
return zip(*_get_out_sharding_from_orig_sharding(
out_shardings, out_avals, orig_in_s, orig_aval, are_out_shardings_from_xla))
return _get_out_sharding_from_orig_sharding(
out_shardings, out_avals, orig_in_s, orig_aval)
return out_shardings, are_out_shardings_from_xla
return out_shardings
def _get_layouts_from_executable(
@ -2653,6 +2638,7 @@ def _maybe_get_and_check_in_shardings(
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
new_in_shardings.append(xla_s)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
@ -2677,17 +2663,17 @@ def _get_out_shardings_from_executable(
xla_executable, device_assignment, len(global_out_avals),
num_ordered_effects, all_default_mem_kind) # type: ignore
if out_shardings_xla is None:
return out_shardings, (False,) * len(global_out_avals)
return out_shardings
new_out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
new_out_shardings = []
for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
global_out_avals):
if is_unspecified(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
new_out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
@ -2700,17 +2686,14 @@ def _get_out_shardings_from_executable(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
new_out_shardings.append(orig)
are_out_shardings_from_xla.append(False)
return new_out_shardings, are_out_shardings_from_xla
return new_out_shardings
def finalize_out_shardings(out_shardings, are_out_shardings_from_xla,
device_assignment):
def finalize_out_shardings(out_shardings, device_assignment):
if len(device_assignment) == 1:
return ([SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
if isinstance(o, GSPMDSharding) else o for o in out_shardings],
are_out_shardings_from_xla)
return out_shardings, are_out_shardings_from_xla
return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
if isinstance(o, GSPMDSharding) else o for o in out_shardings]
return out_shardings
@dataclasses.dataclass
@ -2723,7 +2706,6 @@ class UnloadedMeshExecutable:
output_avals: Sequence[ShapedArray]
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
committed: bool
are_out_shardings_from_xla: Sequence[bool]
name: str
unordered_effects: list[core.Effect]
ordered_effects: list[core.Effect]
@ -2744,8 +2726,7 @@ class UnloadedMeshExecutable:
handle_args = InputsHandler(
self.input_shardings, self.xla_executable.local_devices(), input_indices)
handle_outs = global_avals_to_results_handler(
self.output_avals, self.output_shardings, self.committed,
self.are_out_shardings_from_xla) # type: ignore # arg-type
self.output_avals, self.output_shardings, self.committed) # type: ignore # arg-type
unsafe_call = ExecuteReplicated( # type: ignore # assignment
self.xla_executable, self.name, self.backend, handle_args,
@ -2833,11 +2814,8 @@ class UnloadedMeshExecutable:
xla_executable, mesh)
in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i) # type: ignore
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings_tuple = [
(x, True) if is_auto(o) else (o, False)
for x, o in safe_zip(out_shardings_xla, out_shardings)
]
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
out_shardings = [x if is_auto(o) else o
for x, o in safe_zip(out_shardings_xla, out_shardings)]
else:
if pmap_nreps == 1:
assert mesh is None
@ -2845,13 +2823,12 @@ class UnloadedMeshExecutable:
in_shardings = _maybe_get_and_check_in_shardings(
xla_executable, in_shardings, tuple(da), global_in_avals,
len(ordered_effects))
out_shardings, are_out_shardings_from_xla = _get_out_shardings_from_executable(
out_shardings = _get_out_shardings_from_executable(
xla_executable, out_shardings, tuple(da), global_out_avals,
len(ordered_effects), all_default_mem_kind)
else:
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))
are_out_shardings_from_xla = (False,) * len(global_out_avals)
if xla_extension_version >= 217:
in_layouts, out_layouts = _get_layouts_from_executable(
@ -2860,12 +2837,10 @@ class UnloadedMeshExecutable:
assert all(i is None for i in in_layouts)
assert all(o is None for o in out_layouts)
out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
in_shardings, out_shardings, are_out_shardings_from_xla,
global_in_avals, global_out_avals)
out_shardings = maybe_get_orig_out_sharding(
in_shardings, out_shardings, global_in_avals, global_out_avals)
out_shardings, are_out_shardings_from_xla = finalize_out_shardings(
out_shardings, are_out_shardings_from_xla, da)
out_shardings = finalize_out_shardings(out_shardings, da)
return UnloadedMeshExecutable(
xla_executable=xla_executable,
@ -2876,7 +2851,6 @@ class UnloadedMeshExecutable:
output_avals=global_out_avals,
output_shardings=out_shardings, # type: ignore # arg-type
committed=committed,
are_out_shardings_from_xla=are_out_shardings_from_xla,
name=name,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,

View File

@ -5082,8 +5082,7 @@ class BIntRules:
return handler
@staticmethod
def global_sharded_result_handler(aval, out_sharding, committed,
is_out_sharding_from_xla):
def global_sharded_result_handler(aval, out_sharding, committed):
phys_aval = core.physical_aval(aval)
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
@ -5091,8 +5090,7 @@ class BIntRules:
raise NotImplementedError # TODO(mattjj)
else:
phys_sharding = out_sharding
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
is_out_sharding_from_xla)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
def handler(bufs):
return core.DArray(aval, phys_handler(bufs))
@ -5102,6 +5100,10 @@ class BIntRules:
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
return hlo_sharding
@staticmethod
def logical_op_sharding(aval, phys_sharding):
return phys_sharding
@staticmethod
def convert_from(bint_dtype, other_dtype) -> bool:
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
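Together with the KeyTyRules and FooTyRules hunks below, BIntRules shows the per-dtype hooks this commit converges on. A hedged sketch of the minimal shape of such a rules class (the dtype is hypothetical; hook names and signatures follow the diff):

from jax._src import core
from jax._src.interpreters import pxla

class MyDTypeRules:
  @staticmethod
  def physical_hlo_sharding(aval, hlo_sharding):
    # logical -> physical: extend the HLO sharding over any trailing physical
    # dims (BIntRules has none, so it is a no-op here).
    return hlo_sharding

  @staticmethod
  def logical_op_sharding(aval, phys_sharding):
    # physical -> user/logical: the hook added in this change, applied when
    # shardings are read back out of the compiled executable.
    return phys_sharding

  @staticmethod
  def global_sharded_result_handler(aval, out_sharding, committed):
    # Note: no is_out_sharding_from_xla parameter any more.
    phys_aval = core.physical_aval(aval)
    phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
    # A real rules class would first map out_sharding to its physical
    # counterpart and wrap the result in its own array type.
    phys_handler = phys_handler_maker(phys_aval, out_sharding, committed)
    def handler(bufs):
      return phys_handler(bufs)
    return handler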

View File

@ -25,7 +25,7 @@ from jax._src.lib import xla_client as xc
def get_num_ways_dim_sharded(
hlo_sharding: xc.HloSharding) -> tuple[Sequence[int], int]:
hlo_sharding: xc.HloSharding) -> tuple[list[int], int]:
if hlo_sharding.is_replicated(): # type: ignore
return [], 1
partitions = hlo_sharding.tile_assignment_dimensions()
@ -42,7 +42,7 @@ def get_num_ways_dim_sharded(
if replicate_on_last_tile_dim:
num_replicas = partitions[-1]
partitions = partitions[:-1]
return partitions, num_replicas
return list(partitions), num_replicas
def is_op_sharding_replicated(op: xc.OpSharding | xc.HloSharding) -> bool:
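A worked example may help (values are illustrative): for an HloSharding whose tile_assignment_dimensions are [2, 4] with replicate_on_last_tile_dim set, the trailing tile dim counts replicas, so the result is ([2], 4); without that flag it is ([2, 4], 1). The return type is now a concrete list so callers such as KeyTyRules below can concatenate it with extra dims. A standalone sketch of the same logic on plain data (ignoring the fully-replicated case):

def toy_num_ways_dim_sharded(tile_dims, replicate_on_last_tile_dim):
  partitions = list(tile_dims)
  num_replicas = 1
  if replicate_on_last_tile_dim:
    num_replicas = partitions[-1]
    partitions = partitions[:-1]
  return partitions, num_replicas

assert toy_num_ways_dim_sharded([2, 4], True) == ([2], 4)
assert toy_num_ways_dim_sharded([2, 4], False) == ([2, 4], 1)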

View File

@ -1701,8 +1701,8 @@ def _pjit_batcher_for_sharding(
else:
new_gs = GSPMDSharding(s._device_assignment, new_op) # type: ignore
if hasattr(s, '_original_sharding'):
vmapped_s, _ = pxla._get_out_sharding_from_orig_sharding(
[new_gs], [None], s._original_sharding, None, [False])[0] # type: ignore
vmapped_s = pxla._get_out_sharding_from_orig_sharding(
[new_gs], [None], s._original_sharding, None)[0] # type: ignore
new_gs = to_gspmd_sharding(vmapped_s, ndim)
return new_gs
else:

View File

@ -318,7 +318,7 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape):
base_ndim = len(impl.key_shape)
return base_arr_shape[:-base_ndim]
def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
def make_key_array_phys_sharding(aval, sharding):
if dispatch.is_single_device_sharding(sharding):
return sharding
elif isinstance(sharding, PmapSharding):
@ -335,8 +335,6 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
return NamedSharding(
sharding.mesh,
PartitionSpec(*sharding.spec, *trailing_spec))
elif is_sharding_from_xla:
return sharding
else:
hlos = sharding._to_xla_hlo_sharding(aval.ndim)
return GSPMDSharding(
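For the NamedSharding branch above, a worked example (axis names and shapes are illustrative, and trailing_spec is assumed to be one None per key dimension): a key array aval of logical shape (8,) whose impl has key_shape (2,) is physically a uint32 array of shape (8, 2), so a logical spec P('x') becomes the physical spec P('x', None), leaving the key dimension unsharded.

from jax.sharding import NamedSharding, PartitionSpec as P

def toy_key_phys_named_sharding(named_sharding, key_ndim):
  # Mirrors the NamedSharding branch of make_key_array_phys_sharding: keep the
  # logical spec and add an unsharded trailing entry per key dimension.
  trailing_spec = [None] * key_ndim
  return NamedSharding(named_sharding.mesh, P(*named_sharding.spec, *trailing_spec))

This is also what makes e.g. jax.device_put(keys, NamedSharding(mesh, P('x'))) work for typed key arrays: the shard_arg handler at the end of this file unwraps the key array to its base array and shards it with the derived physical sharding.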
@ -367,11 +365,11 @@ class KeyTyRules:
@staticmethod
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
key_shape = aval.dtype._impl.key_shape
op_sharding_proto = hlo_sharding.to_proto() # type: ignore
new_op_sharding = op_sharding_proto.clone()
tad = list(new_op_sharding.tile_assignment_dimensions)
suffix = [tad.pop()] if op_sharding_proto.replicate_on_last_tile_dim else []
tad.extend([1] * len(key_shape) + suffix)
new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * len(key_shape) + suffix
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)
@ -393,11 +391,14 @@ class KeyTyRules:
PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
else:
key_shape = aval.dtype._impl.key_shape
phys_op_sharding = phys_sharding._to_xla_hlo_sharding(
aval.ndim + len(key_shape)).to_proto()
logical_op_sharding = phys_op_sharding.clone()
tad = list(logical_op_sharding.tile_assignment_dimensions)
tad = tad[:-len(key_shape)]
phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
aval.ndim + len(key_shape))
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
phys_hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
# Create logical sharding by cutting off the replicated trailing dims.
logical_op_sharding = phys_hlo_sharding.to_proto().clone()
tad = partitions[:-len(key_shape)] + suffix
logical_op_sharding.tile_assignment_dimensions = tad
return GSPMDSharding(phys_sharding._device_assignment,
xc.HloSharding.from_proto(logical_op_sharding))
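A worked example of the two conversions above (numbers are illustrative): with a rank-1 key_shape, a logical sharding with tile dims [4] and 2 replicas on the last tile dim becomes the physical tile assignment [4, 1, 2] (key dim unsharded, replica count kept as the trailing suffix); logical_op_sharding reverses this by dropping the [1] entry for the key dim and re-appending the replica suffix, recovering [4, 2]. A toy sketch of just the tile-dim arithmetic:

def toy_physical_tile_dims(logical_partitions, num_replicas, key_ndim):
  suffix = [] if num_replicas == 1 else [num_replicas]
  return logical_partitions + [1] * key_ndim + suffix

def toy_logical_tile_dims(physical_partitions, num_replicas, key_ndim):
  suffix = [] if num_replicas == 1 else [num_replicas]
  return physical_partitions[:-key_ndim] + suffix

assert toy_physical_tile_dims([4], 2, 1) == [4, 1, 2]
assert toy_logical_tile_dims([4, 1], 2, 1) == [4, 2]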
@ -417,8 +418,7 @@ class KeyTyRules:
# set up a grounded sharding (with a grounded sharding spec)
if isinstance(sharding, (PmapSharding, NamedSharding)):
phys_sharding = make_key_array_phys_sharding(
aval, sharding, is_sharding_from_xla=False)
phys_sharding = make_key_array_phys_sharding(aval, sharding)
else:
assert False, f'impossible sharding {sharding} in local sharded result handler'
@ -436,15 +436,12 @@ class KeyTyRules:
return handler
@staticmethod
def global_sharded_result_handler(aval, out_sharding, committed,
is_out_sharding_from_xla):
def global_sharded_result_handler(aval, out_sharding, committed):
phys_aval = core.physical_aval(aval)
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
phys_sharding = make_key_array_phys_sharding(
aval, out_sharding, is_out_sharding_from_xla)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
is_out_sharding_from_xla)
phys_sharding = make_key_array_phys_sharding(aval, out_sharding)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
def handler(bufs):
return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
return handler
@ -455,8 +452,8 @@ class KeyTyRules:
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
phys_arrays = [random_unwrap(arr) for arr in arrays]
phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
phys_sharding = make_key_array_phys_sharding(aval, sharding)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
phys_result = phys_handler(phys_arrays)
return PRNGKeyArray(aval.dtype._impl, phys_result)
@ -464,7 +461,7 @@ class KeyTyRules:
def device_put_sharded(vals, aval, sharding, devices):
physical_aval = core.physical_aval(aval)
physical_buffers = tree_util.tree_map(random_unwrap, vals)
physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
physical_sharding = make_key_array_phys_sharding(aval, sharding)
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
return random_wrap(physical_result, impl=aval.dtype._impl)
@ -473,7 +470,7 @@ class KeyTyRules:
physical_aval = core.physical_aval(aval)
assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
physical_buf = random_unwrap(val)
physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
physical_sharding = make_key_array_phys_sharding(aval, sharding)
physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
return random_wrap(physical_result, impl=aval.dtype._impl)
@ -554,8 +551,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
arr = x._base_array
phys_sharding = make_key_array_phys_sharding(
x.aval, sharding, is_sharding_from_xla=False)
phys_sharding = make_key_array_phys_sharding(x.aval, sharding)
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)

View File

@ -2977,6 +2977,10 @@ class FooTyRules:
new_op_sharding.tile_assignment_dimensions = [*tad, 1]
return xc.HloSharding.from_proto(new_op_sharding)
@staticmethod
def logical_op_sharding(aval, phys_sharding):
return phys_sharding
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
@ -2985,8 +2989,7 @@ class FooTyRules:
return handler
@staticmethod
def global_sharded_result_handler(aval, out_sharding, committed,
is_out_sharding_from_xla):
def global_sharded_result_handler(aval, out_sharding, committed):
def handler(arr):
from jax._src.array import ArrayImpl
if isinstance(arr, ArrayImpl):