Allow sharding propagation to the input for prng keys whose sharding is not specified.

Convert shardings returned by XLA (when propagation is enabled for inputs and outputs) for extended dtypes to user shardings, which allows removing `are_out_shardings_from_xla`.

PiperOrigin-RevId: 611246986
This commit is contained in:
parent
2f7c36c763
commit
217f08236e
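In effect (a minimal sketch, not part of this change or its tests, assuming a runtime with at least two devices): a jitted function can take a typed prng key array whose input sharding is left unspecified, and sharding propagation is now free to choose that sharding rather than extended-dtype inputs being pinned to a default sharding as before.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical two-device mesh; adjust to the devices actually available.
mesh = Mesh(jax.devices()[:2], ('x',))

@jax.jit
def fold_keys(keys, data):
  # `keys` is an extended-dtype prng key array; its input sharding is left
  # unspecified, so the compiler may propagate one from `data`.
  return jax.vmap(jax.random.fold_in)(keys, data)

keys = jax.random.split(jax.random.key(0), 8)
data = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))
out = fold_keys(keys, data)
print(out.sharding)  # may now follow the propagated sharding instead of a forced default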
@@ -926,13 +926,12 @@ def _array_shard_arg(x, sharding):
 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
 
 
-def _array_global_result_handler(global_aval, out_sharding, committed,
-                                 is_out_sharding_from_xla):
+def _array_global_result_handler(global_aval, out_sharding, committed):
   if global_aval.dtype == dtypes.float0:
     return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
   if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
     return global_aval.dtype._rules.global_sharded_result_handler(
-        global_aval, out_sharding, committed, is_out_sharding_from_xla)
+        global_aval, out_sharding, committed)
   return xc.array_result_handler(
       global_aval, out_sharding, committed=committed, _skip_checks=True
   )
@@ -322,7 +322,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
 
 
 def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
-  result_handler = pxla.global_aval_to_result_handler(aval, s, committed, False)
+  result_handler = pxla.global_aval_to_result_handler(aval, s, committed)
   return result_handler(pxla.shard_arg(x, s))
 
 def _override_get_device_assignment(sharding, *args, **kwargs):
@@ -224,8 +224,7 @@ local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
 
 
 def global_aval_to_result_handler(
-    aval: core.AbstractValue, out_sharding, committed: bool,
-    is_out_sharding_from_xla: bool
+    aval: core.AbstractValue, out_sharding, committed: bool
 ) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
 
@@ -235,8 +234,6 @@ def global_aval_to_result_handler(
       Used for creating GSDAs.
     global_mesh: The global device mesh that generated this output. Used
      for creating GSDAs.
-    is_out_sharding_from_xla: True, if the out_sharding comes from XLA i.e.
-      the sharding is extracted from the HLO.
 
   Returns:
     A function for handling the Buffers that will eventually be produced
@@ -244,8 +241,7 @@ def global_aval_to_result_handler(
     to the user, e.g. an Array.
   """
   try:
-    return global_result_handlers[type(aval)](
-        aval, out_sharding, committed, is_out_sharding_from_xla)
+    return global_result_handlers[type(aval)](aval, out_sharding, committed)
   except KeyError as err:
     raise TypeError(
         f"No pxla_result_handler for type: {type(aval)}") from err
@@ -1139,12 +1135,10 @@ def local_avals_to_results_handler(
 def global_avals_to_results_handler(
     global_out_avals: Sequence[ShapedArray],
     shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    committed: bool,
-    are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
+    committed: bool) -> ResultsHandler:
   handlers = [
-      global_aval_to_result_handler(global_aval, s, committed, x)
-      for global_aval, s, x in safe_zip(global_out_avals, shardings,
-                                        are_out_shardings_from_xla)
+      global_aval_to_result_handler(global_aval, s, committed)
+      for global_aval, s in safe_zip(global_out_avals, shardings)
   ]
   return ResultsHandler(handlers, shardings, global_out_avals)
 
@@ -2010,12 +2004,6 @@ def lower_sharding_computation(
   if xla_extension_version < 240 or hasattr(backend, "compile_replicated"):
     in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
 
-  # TODO(yashkatariya): Allow prng sharding inference by XLA. Enable this after
-  # output sharding of XLA is partially constrained on the trailing dimensions.
-  in_shardings = tuple(
-      gs if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
-      else i for i, a in safe_zip(in_shardings, global_in_avals))
-
   da_object = _create_da_object(tuple(device_assignment))
 
   all_default_mem_kind = are_all_shardings_default_mem_kind(
@@ -2466,11 +2454,10 @@ _register_out_sharding_handler(
 
 
 def _get_out_sharding_from_orig_sharding(
-    out_shardings, out_avals, orig_in_s, orig_aval, are_out_sharding_from_xla):
+    out_shardings, out_avals, orig_in_s, orig_aval):
   out = []
   orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
-  for o, out_aval, from_xla in safe_zip(out_shardings, out_avals,
-                                        are_out_sharding_from_xla):
+  for o, out_aval in safe_zip(out_shardings, out_avals):
     if isinstance(o, sharding_impls.GSPMDSharding):
       try:
         # Only return the same input sharding object if the OpShardings and
@@ -2482,21 +2469,19 @@ def _get_out_sharding_from_orig_sharding(
             and sharding_impls.are_op_shardings_equal(
                 o._hlo_sharding, orig_in_s._to_xla_hlo_sharding(orig_aval.ndim))
             and o.memory_kind == orig_in_s.memory_kind):
-          out.append((orig_in_s, False))
+          out.append(orig_in_s)
         else:
-          out.append((orig_handler(o, orig_in_s), False))
+          out.append(orig_handler(o, orig_in_s))
       except:
-        out.append((o, from_xla))
+        out.append(o)
     else:
-      out.append((o, from_xla))
+      out.append(o)
   return out
 
 def maybe_get_orig_out_sharding(
-    in_shardings, out_shardings, are_out_shardings_from_xla, in_avals,
-    out_avals):
+    in_shardings, out_shardings, in_avals, out_avals):
   if all(hasattr(o, '_original_sharding') for o in out_shardings):
-    return ([o._original_sharding for o in out_shardings],
-            (False,) * len(out_shardings))
+    return [o._original_sharding for o in out_shardings]
 
   orig_in_s = None
   orig_aval = None
@@ -2507,10 +2492,10 @@ def maybe_get_orig_out_sharding(
       orig_aval = aval
       break
   if orig_in_s is not None:
-    return zip(*_get_out_sharding_from_orig_sharding(
-        out_shardings, out_avals, orig_in_s, orig_aval, are_out_shardings_from_xla))
+    return _get_out_sharding_from_orig_sharding(
+        out_shardings, out_avals, orig_in_s, orig_aval)
 
-  return out_shardings, are_out_shardings_from_xla
+  return out_shardings
 
 
 def _get_layouts_from_executable(
@@ -2653,6 +2638,7 @@ def _maybe_get_and_check_in_shardings(
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
         xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
       new_in_shardings.append(xla_s)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
@@ -2677,17 +2663,17 @@ def _get_out_shardings_from_executable(
       xla_executable, device_assignment, len(global_out_avals),
       num_ordered_effects, all_default_mem_kind)  # type: ignore
   if out_shardings_xla is None:
-    return out_shardings, (False,) * len(global_out_avals)
+    return out_shardings
 
-  new_out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
+  new_out_shardings = []
   for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                     global_out_avals):
     if is_unspecified(orig):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        aval.dtype._rules.check_replicated_trailing_dims(xla_s, aval)
         xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
       new_out_shardings.append(xla_s)
-      are_out_shardings_from_xla.append(True)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
       orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
@@ -2700,17 +2686,14 @@ def _get_out_shardings_from_executable(
             f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
             "(User sharding)")
       new_out_shardings.append(orig)
-      are_out_shardings_from_xla.append(False)
-  return new_out_shardings, are_out_shardings_from_xla
+  return new_out_shardings
 
 
-def finalize_out_shardings(out_shardings, are_out_shardings_from_xla,
-                           device_assignment):
+def finalize_out_shardings(out_shardings, device_assignment):
   if len(device_assignment) == 1:
-    return ([SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
-             if isinstance(o, GSPMDSharding) else o for o in out_shardings],
-            are_out_shardings_from_xla)
-  return out_shardings, are_out_shardings_from_xla
+    return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind)
+            if isinstance(o, GSPMDSharding) else o for o in out_shardings]
+  return out_shardings
 
 
 @dataclasses.dataclass
@@ -2723,7 +2706,6 @@ class UnloadedMeshExecutable:
   output_avals: Sequence[ShapedArray]
   output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
   committed: bool
-  are_out_shardings_from_xla: Sequence[bool]
   name: str
   unordered_effects: list[core.Effect]
   ordered_effects: list[core.Effect]
@@ -2744,8 +2726,7 @@ class UnloadedMeshExecutable:
     handle_args = InputsHandler(
         self.input_shardings, self.xla_executable.local_devices(), input_indices)
     handle_outs = global_avals_to_results_handler(
-        self.output_avals, self.output_shardings, self.committed,
-        self.are_out_shardings_from_xla)  # type: ignore  # arg-type
+        self.output_avals, self.output_shardings, self.committed)  # type: ignore  # arg-type
 
     unsafe_call = ExecuteReplicated(  # type: ignore  # assignment
         self.xla_executable, self.name, self.backend, handle_args,
@@ -2833,11 +2814,8 @@ class UnloadedMeshExecutable:
           xla_executable, mesh)
       in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i)  # type: ignore
                       for x, i in safe_zip(in_shardings_xla, in_shardings)]
-      out_shardings_tuple = [
-          (x, True) if is_auto(o) else (o, False)
-          for x, o in safe_zip(out_shardings_xla, out_shardings)
-      ]
-      out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
+      out_shardings = [x if is_auto(o) else o
+                       for x, o in safe_zip(out_shardings_xla, out_shardings)]
     else:
       if pmap_nreps == 1:
         assert mesh is None
@@ -2845,13 +2823,12 @@ class UnloadedMeshExecutable:
         in_shardings = _maybe_get_and_check_in_shardings(
            xla_executable, in_shardings, tuple(da), global_in_avals,
            len(ordered_effects))
-        out_shardings, are_out_shardings_from_xla = _get_out_shardings_from_executable(
+        out_shardings = _get_out_shardings_from_executable(
            xla_executable, out_shardings, tuple(da), global_out_avals,
            len(ordered_effects), all_default_mem_kind)
       else:
        in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
            xla_executable.local_devices(), len(in_shardings), len(out_shardings))
-        are_out_shardings_from_xla = (False,) * len(global_out_avals)
 
     if xla_extension_version >= 217:
       in_layouts, out_layouts = _get_layouts_from_executable(
@@ -2860,12 +2837,10 @@ class UnloadedMeshExecutable:
       assert all(i is None for i in in_layouts)
       assert all(o is None for o in out_layouts)
 
-    out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
-        in_shardings, out_shardings, are_out_shardings_from_xla,
-        global_in_avals, global_out_avals)
+    out_shardings = maybe_get_orig_out_sharding(
+        in_shardings, out_shardings, global_in_avals, global_out_avals)
 
-    out_shardings, are_out_shardings_from_xla = finalize_out_shardings(
-        out_shardings, are_out_shardings_from_xla, da)
+    out_shardings = finalize_out_shardings(out_shardings, da)
 
     return UnloadedMeshExecutable(
         xla_executable=xla_executable,
@@ -2876,7 +2851,6 @@ class UnloadedMeshExecutable:
         output_avals=global_out_avals,
         output_shardings=out_shardings,  # type: ignore  # arg-type
         committed=committed,
-        are_out_shardings_from_xla=are_out_shardings_from_xla,
         name=name,
         unordered_effects=unordered_effects,
         ordered_effects=ordered_effects,
@@ -5082,8 +5082,7 @@ class BIntRules:
     return handler
 
   @staticmethod
-  def global_sharded_result_handler(aval, out_sharding, committed,
-                                    is_out_sharding_from_xla):
+  def global_sharded_result_handler(aval, out_sharding, committed):
     phys_aval = core.physical_aval(aval)
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
 
@@ -5091,8 +5090,7 @@ class BIntRules:
       raise NotImplementedError  # TODO(mattjj)
     else:
       phys_sharding = out_sharding
-    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
-                                      is_out_sharding_from_xla)
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
 
     def handler(bufs):
       return core.DArray(aval, phys_handler(bufs))
@@ -5102,6 +5100,10 @@ class BIntRules:
   def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
     return hlo_sharding
 
+  @staticmethod
+  def logical_op_sharding(aval, phys_sharding):
+    return phys_sharding
+
   @staticmethod
   def convert_from(bint_dtype, other_dtype) -> bool:
     return other_dtype in (np.dtype('int32'), np.dtype('int64'))
@@ -25,7 +25,7 @@ from jax._src.lib import xla_client as xc
 
 
 def get_num_ways_dim_sharded(
-    hlo_sharding: xc.HloSharding) -> tuple[Sequence[int], int]:
+    hlo_sharding: xc.HloSharding) -> tuple[list[int], int]:
   if hlo_sharding.is_replicated():  # type: ignore
     return [], 1
   partitions = hlo_sharding.tile_assignment_dimensions()
@@ -42,7 +42,7 @@ def get_num_ways_dim_sharded(
   if replicate_on_last_tile_dim:
     num_replicas = partitions[-1]
     partitions = partitions[:-1]
-  return partitions, num_replicas
+  return list(partitions), num_replicas
 
 
 def is_op_sharding_replicated(op: xc.OpSharding | xc.HloSharding) -> bool:
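For reference, a standalone sketch (not part of the diff) mirroring what `get_num_ways_dim_sharded` computes from a tile assignment; the change above only tightens its return type to a concrete list, which the `KeyTyRules` code further down relies on for list concatenation.

# Hypothetical mirror of the helper's core logic, for illustration only.
def num_ways_dim_sharded(tile_assignment_dimensions, replicate_on_last_tile_dim):
  partitions = list(tile_assignment_dimensions)
  num_replicas = 1
  if replicate_on_last_tile_dim:
    # The last tile dimension counts replicas rather than a sharded array dim.
    num_replicas = partitions[-1]
    partitions = partitions[:-1]
  return partitions, num_replicas

assert num_ways_dim_sharded([2, 4], True) == ([2], 4)    # 2-way sharded, 4-way replicated
assert num_ways_dim_sharded([2, 4], False) == ([2, 4], 1)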
@@ -1701,8 +1701,8 @@ def _pjit_batcher_for_sharding(
     else:
       new_gs = GSPMDSharding(s._device_assignment, new_op)  # type: ignore
     if hasattr(s, '_original_sharding'):
-      vmapped_s, _ = pxla._get_out_sharding_from_orig_sharding(
-          [new_gs], [None], s._original_sharding, None, [False])[0]  # type: ignore
+      vmapped_s = pxla._get_out_sharding_from_orig_sharding(
+          [new_gs], [None], s._original_sharding, None)[0]  # type: ignore
       new_gs = to_gspmd_sharding(vmapped_s, ndim)
     return new_gs
   else:
@@ -318,7 +318,7 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape):
   base_ndim = len(impl.key_shape)
   return base_arr_shape[:-base_ndim]
 
-def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
+def make_key_array_phys_sharding(aval, sharding):
   if dispatch.is_single_device_sharding(sharding):
     return sharding
   elif isinstance(sharding, PmapSharding):
@@ -335,8 +335,6 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
     return NamedSharding(
         sharding.mesh,
         PartitionSpec(*sharding.spec, *trailing_spec))
-  elif is_sharding_from_xla:
-    return sharding
   else:
     hlos = sharding._to_xla_hlo_sharding(aval.ndim)
     return GSPMDSharding(
@@ -367,11 +365,11 @@ class KeyTyRules:
   @staticmethod
   def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
     key_shape = aval.dtype._impl.key_shape
-    op_sharding_proto = hlo_sharding.to_proto()  # type: ignore
-    new_op_sharding = op_sharding_proto.clone()
-    tad = list(new_op_sharding.tile_assignment_dimensions)
-    suffix = [tad.pop()] if op_sharding_proto.replicate_on_last_tile_dim else []
-    tad.extend([1] * len(key_shape) + suffix)
+    new_op_sharding = hlo_sharding.to_proto().clone()  # type: ignore
+    partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+        hlo_sharding)
+    suffix = [] if num_replicas == 1 else [num_replicas]
+    tad = partitions + [1] * len(key_shape) + suffix
     new_op_sharding.tile_assignment_dimensions = tad
     return xc.HloSharding.from_proto(new_op_sharding)
 
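The rewritten `KeyTyRules.physical_hlo_sharding` assembles the physical tile assignment from the logical partitioning: the logical partitions, then a 1 for each trailing key dimension, then the replica count when the sharding is partially replicated. A small worked sketch with illustrative numbers (not from the commit):

key_shape = (2,)                    # e.g. a key made of two uint32 words, one trailing dim
partitions, num_replicas = [4], 2   # logical dim split 4 ways, 2-way replicated
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * len(key_shape) + suffix
assert tad == [4, 1, 2]             # physical tile_assignment_dimensions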
@@ -393,11 +391,14 @@ class KeyTyRules:
           PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
     else:
       key_shape = aval.dtype._impl.key_shape
-      phys_op_sharding = phys_sharding._to_xla_hlo_sharding(
-          aval.ndim + len(key_shape)).to_proto()
-      logical_op_sharding = phys_op_sharding.clone()
-      tad = list(logical_op_sharding.tile_assignment_dimensions)
-      tad = tad[:-len(key_shape)]
+      phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
+          aval.ndim + len(key_shape))
+      partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+          phys_hlo_sharding)
+      suffix = [] if num_replicas == 1 else [num_replicas]
+      # Create logical sharding by cutting off the replicated trailing dims.
+      logical_op_sharding = phys_hlo_sharding.to_proto().clone()
+      tad = partitions[:-len(key_shape)] + suffix
       logical_op_sharding.tile_assignment_dimensions = tad
       return GSPMDSharding(phys_sharding._device_assignment,
                            xc.HloSharding.from_proto(logical_op_sharding))
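And the inverse direction used by the new logical-sharding path above: drop the trailing key dimensions from the physical partitions and keep the replica suffix. Continuing the illustrative numbers from the previous sketch (again, not from the commit):

key_shape = (2,)
partitions, num_replicas = [4, 1], 2   # from the physical key-array sharding above
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions[:-len(key_shape)] + suffix
assert tad == [4, 2]                   # logical tile_assignment_dimensions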
@@ -417,8 +418,7 @@ class KeyTyRules:
 
     # set up a grounded sharding (with a grounded sharding spec)
     if isinstance(sharding, (PmapSharding, NamedSharding)):
-      phys_sharding = make_key_array_phys_sharding(
-          aval, sharding, is_sharding_from_xla=False)
+      phys_sharding = make_key_array_phys_sharding(aval, sharding)
     else:
       assert False, f'impossible sharding {sharding} in local sharded result handler'
 
@@ -436,15 +436,12 @@ class KeyTyRules:
     return handler
 
   @staticmethod
-  def global_sharded_result_handler(aval, out_sharding, committed,
-                                    is_out_sharding_from_xla):
+  def global_sharded_result_handler(aval, out_sharding, committed):
     phys_aval = core.physical_aval(aval)
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
 
-    phys_sharding = make_key_array_phys_sharding(
-        aval, out_sharding, is_out_sharding_from_xla)
-    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
-                                      is_out_sharding_from_xla)
+    phys_sharding = make_key_array_phys_sharding(aval, out_sharding)
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
     def handler(bufs):
       return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
     return handler
@@ -455,8 +452,8 @@ class KeyTyRules:
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
     phys_arrays = [random_unwrap(arr) for arr in arrays]
 
-    phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
-    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
+    phys_sharding = make_key_array_phys_sharding(aval, sharding)
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
     phys_result = phys_handler(phys_arrays)
     return PRNGKeyArray(aval.dtype._impl, phys_result)
 
@@ -464,7 +461,7 @@ class KeyTyRules:
   def device_put_sharded(vals, aval, sharding, devices):
     physical_aval = core.physical_aval(aval)
     physical_buffers = tree_util.tree_map(random_unwrap, vals)
-    physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
+    physical_sharding = make_key_array_phys_sharding(aval, sharding)
     physical_result = pxla.batched_device_put(physical_aval, physical_sharding, physical_buffers, list(devices))
     return random_wrap(physical_result, impl=aval.dtype._impl)
 
@@ -473,7 +470,7 @@ class KeyTyRules:
     physical_aval = core.physical_aval(aval)
     assert len(xla.aval_to_xla_shapes(physical_aval)) == 1
     physical_buf = random_unwrap(val)
-    physical_sharding = make_key_array_phys_sharding(aval, sharding, False)
+    physical_sharding = make_key_array_phys_sharding(aval, sharding)
     physical_result = pxla.batched_device_put(physical_aval, physical_sharding, [physical_buf] * len(devices), devices)
     return random_wrap(physical_result, impl=aval.dtype._impl)
 
@@ -554,8 +551,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
 
 def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
   arr = x._base_array
-  phys_sharding = make_key_array_phys_sharding(
-      x.aval, sharding, is_sharding_from_xla=False)
+  phys_sharding = make_key_array_phys_sharding(x.aval, sharding)
   return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
 
 
@@ -2977,6 +2977,10 @@ class FooTyRules:
     new_op_sharding.tile_assignment_dimensions = [*tad, 1]
     return xc.HloSharding.from_proto(new_op_sharding)
 
+  @staticmethod
+  def logical_op_sharding(aval, phys_sharding):
+    return phys_sharding
+
   @staticmethod
   def result_handler(sticky_device, aval):
     def handler(_, buf):
@@ -2985,8 +2989,7 @@ class FooTyRules:
     return handler
 
   @staticmethod
-  def global_sharded_result_handler(aval, out_sharding, committed,
-                                    is_out_sharding_from_xla):
+  def global_sharded_result_handler(aval, out_sharding, committed):
     def handler(arr):
       from jax._src.array import ArrayImpl
       if isinstance(arr, ArrayImpl):