mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Use unique in_shardings and out_shardings when iterating of shardings to check for trivial things that don't depend on the number of shardings but only the uniqueness
PiperOrigin-RevId: 682032365
This commit is contained in:
parent
dda4712da0
commit
d69a82c140
@ -2173,12 +2173,14 @@ def lower_sharding_computation(
|
||||
else context_mesh._flat_devices_tuple)
|
||||
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
||||
# should be the same.
|
||||
unique_intermediate_shardings = list(util.stable_unique(
|
||||
dispatch.get_intermediate_shardings(jaxpr)))
|
||||
unique_intermediate_shardings = util.stable_unique(
|
||||
list(dispatch.get_intermediate_shardings(jaxpr)))
|
||||
unique_in_shardings = util.stable_unique(in_shardings)
|
||||
unique_out_shardings = util.stable_unique(out_shardings)
|
||||
backend, device_assignment = _get_and_check_device_assignment(
|
||||
it.chain(
|
||||
((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)),
|
||||
((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)),
|
||||
((i, MismatchType.ARG_SHARDING, None) for i in unique_in_shardings),
|
||||
((o, MismatchType.OUT_SHARDING, None) for o in unique_out_shardings),
|
||||
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
|
||||
for js, source_info in unique_intermediate_shardings)),
|
||||
devices_from_context)
|
||||
@ -2188,16 +2190,16 @@ def lower_sharding_computation(
|
||||
committed = bool(
|
||||
devices_from_context or
|
||||
len(device_assignment) > 1 or
|
||||
any(not is_unspecified(i) for i in in_shardings) or
|
||||
any(not is_unspecified(i) for i in unique_in_shardings) or
|
||||
any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
|
||||
any(not is_unspecified(o) for o in out_shardings))
|
||||
any(not is_unspecified(o) for o in unique_out_shardings))
|
||||
|
||||
da_object = _create_da_object(tuple(device_assignment))
|
||||
|
||||
transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
|
||||
all_default_mem_kind = are_all_shardings_default_mem_kind(
|
||||
da_object,
|
||||
it.chain(in_shardings, out_shardings,
|
||||
it.chain(unique_in_shardings, unique_out_shardings,
|
||||
[js for js, _ in unique_intermediate_shardings],
|
||||
transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
|
||||
|
||||
@ -2208,16 +2210,11 @@ def lower_sharding_computation(
|
||||
closed_jaxpr, in_shardings)
|
||||
|
||||
# 2. Build up the HLO
|
||||
semantic_in_shardings = SemanticallyEqualShardings(
|
||||
in_shardings, global_in_avals) # type: ignore
|
||||
semantic_out_shardings = SemanticallyEqualShardings(
|
||||
out_shardings, global_out_avals) # type: ignore
|
||||
|
||||
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
|
||||
|
||||
mesh_shape_tuple = None
|
||||
if config.use_shardy_partitioner.value or prim_requires_devices:
|
||||
for sharding in it.chain(in_shardings, out_shardings,
|
||||
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
|
||||
[js for js, _ in unique_intermediate_shardings]):
|
||||
if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)):
|
||||
if (mesh_shape_tuple is not None and
|
||||
@ -2228,6 +2225,11 @@ def lower_sharding_computation(
|
||||
f" {sharding.mesh.shape_tuple} for another")
|
||||
mesh_shape_tuple = sharding.mesh.shape_tuple
|
||||
|
||||
semantic_in_shardings = SemanticallyEqualShardings(
|
||||
in_shardings, global_in_avals) # type: ignore
|
||||
semantic_out_shardings = SemanticallyEqualShardings(
|
||||
out_shardings, global_out_avals) # type: ignore
|
||||
|
||||
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
|
||||
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
|
||||
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
|
||||
|
Loading…
x
Reference in New Issue
Block a user