Use unique in_shardings and out_shardings when iterating over shardings to check for trivial things that don't depend on the number of shardings but only on their uniqueness

PiperOrigin-RevId: 682032365
This commit is contained in:
Yash Katariya 2024-10-03 14:29:00 -07:00 committed by jax authors
parent dda4712da0
commit d69a82c140

View File

@@ -2173,12 +2173,14 @@ def lower_sharding_computation(
else context_mesh._flat_devices_tuple) else context_mesh._flat_devices_tuple)
# Device assignment across all inputs, outputs and shardings inside jaxpr # Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same. # should be the same.
unique_intermediate_shardings = list(util.stable_unique( unique_intermediate_shardings = util.stable_unique(
dispatch.get_intermediate_shardings(jaxpr))) list(dispatch.get_intermediate_shardings(jaxpr)))
unique_in_shardings = util.stable_unique(in_shardings)
unique_out_shardings = util.stable_unique(out_shardings)
backend, device_assignment = _get_and_check_device_assignment( backend, device_assignment = _get_and_check_device_assignment(
it.chain( it.chain(
((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)), ((i, MismatchType.ARG_SHARDING, None) for i in unique_in_shardings),
((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)), ((o, MismatchType.OUT_SHARDING, None) for o in unique_out_shardings),
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in unique_intermediate_shardings)), for js, source_info in unique_intermediate_shardings)),
devices_from_context) devices_from_context)
@@ -2188,16 +2190,16 @@ def lower_sharding_computation(
committed = bool( committed = bool(
devices_from_context or devices_from_context or
len(device_assignment) > 1 or len(device_assignment) > 1 or
any(not is_unspecified(i) for i in in_shardings) or any(not is_unspecified(i) for i in unique_in_shardings) or
any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
any(not is_unspecified(o) for o in out_shardings)) any(not is_unspecified(o) for o in unique_out_shardings))
da_object = _create_da_object(tuple(device_assignment)) da_object = _create_da_object(tuple(device_assignment))
transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
all_default_mem_kind = are_all_shardings_default_mem_kind( all_default_mem_kind = are_all_shardings_default_mem_kind(
da_object, da_object,
it.chain(in_shardings, out_shardings, it.chain(unique_in_shardings, unique_out_shardings,
[js for js, _ in unique_intermediate_shardings], [js for js, _ in unique_intermediate_shardings],
transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
@@ -2208,16 +2210,11 @@ def lower_sharding_computation(
closed_jaxpr, in_shardings) closed_jaxpr, in_shardings)
# 2. Build up the HLO # 2. Build up the HLO
semantic_in_shardings = SemanticallyEqualShardings(
in_shardings, global_in_avals) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(
out_shardings, global_out_avals) # type: ignore
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr) prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
mesh_shape_tuple = None mesh_shape_tuple = None
if config.use_shardy_partitioner.value or prim_requires_devices: if config.use_shardy_partitioner.value or prim_requires_devices:
for sharding in it.chain(in_shardings, out_shardings, for sharding in it.chain(unique_in_shardings, unique_out_shardings,
[js for js, _ in unique_intermediate_shardings]): [js for js, _ in unique_intermediate_shardings]):
if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)): if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)):
if (mesh_shape_tuple is not None and if (mesh_shape_tuple is not None and
@@ -2228,6 +2225,11 @@ def lower_sharding_computation(
f" {sharding.mesh.shape_tuple} for another") f" {sharding.mesh.shape_tuple} for another")
mesh_shape_tuple = sharding.mesh.shape_tuple mesh_shape_tuple = sharding.mesh.shape_tuple
semantic_in_shardings = SemanticallyEqualShardings(
in_shardings, global_in_avals) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(
out_shardings, global_out_avals) # type: ignore
(module, keepalive, host_callbacks, unordered_effects, ordered_effects, (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,