mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Use unique in_shardings and out_shardings when iterating over shardings to perform trivial checks whose result depends only on which distinct shardings are present, not on how many shardings there are
PiperOrigin-RevId: 682032365
This commit is contained in:
parent
dda4712da0
commit
d69a82c140
@ -2173,12 +2173,14 @@ def lower_sharding_computation(
|
|||||||
else context_mesh._flat_devices_tuple)
|
else context_mesh._flat_devices_tuple)
|
||||||
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
||||||
# should be the same.
|
# should be the same.
|
||||||
unique_intermediate_shardings = list(util.stable_unique(
|
unique_intermediate_shardings = util.stable_unique(
|
||||||
dispatch.get_intermediate_shardings(jaxpr)))
|
list(dispatch.get_intermediate_shardings(jaxpr)))
|
||||||
|
unique_in_shardings = util.stable_unique(in_shardings)
|
||||||
|
unique_out_shardings = util.stable_unique(out_shardings)
|
||||||
backend, device_assignment = _get_and_check_device_assignment(
|
backend, device_assignment = _get_and_check_device_assignment(
|
||||||
it.chain(
|
it.chain(
|
||||||
((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)),
|
((i, MismatchType.ARG_SHARDING, None) for i in unique_in_shardings),
|
||||||
((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)),
|
((o, MismatchType.OUT_SHARDING, None) for o in unique_out_shardings),
|
||||||
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
|
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
|
||||||
for js, source_info in unique_intermediate_shardings)),
|
for js, source_info in unique_intermediate_shardings)),
|
||||||
devices_from_context)
|
devices_from_context)
|
||||||
@ -2188,16 +2190,16 @@ def lower_sharding_computation(
|
|||||||
committed = bool(
|
committed = bool(
|
||||||
devices_from_context or
|
devices_from_context or
|
||||||
len(device_assignment) > 1 or
|
len(device_assignment) > 1 or
|
||||||
any(not is_unspecified(i) for i in in_shardings) or
|
any(not is_unspecified(i) for i in unique_in_shardings) or
|
||||||
any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
|
any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
|
||||||
any(not is_unspecified(o) for o in out_shardings))
|
any(not is_unspecified(o) for o in unique_out_shardings))
|
||||||
|
|
||||||
da_object = _create_da_object(tuple(device_assignment))
|
da_object = _create_da_object(tuple(device_assignment))
|
||||||
|
|
||||||
transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
|
transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
|
||||||
all_default_mem_kind = are_all_shardings_default_mem_kind(
|
all_default_mem_kind = are_all_shardings_default_mem_kind(
|
||||||
da_object,
|
da_object,
|
||||||
it.chain(in_shardings, out_shardings,
|
it.chain(unique_in_shardings, unique_out_shardings,
|
||||||
[js for js, _ in unique_intermediate_shardings],
|
[js for js, _ in unique_intermediate_shardings],
|
||||||
transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
|
transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
|
||||||
|
|
||||||
@ -2208,16 +2210,11 @@ def lower_sharding_computation(
|
|||||||
closed_jaxpr, in_shardings)
|
closed_jaxpr, in_shardings)
|
||||||
|
|
||||||
# 2. Build up the HLO
|
# 2. Build up the HLO
|
||||||
semantic_in_shardings = SemanticallyEqualShardings(
|
|
||||||
in_shardings, global_in_avals) # type: ignore
|
|
||||||
semantic_out_shardings = SemanticallyEqualShardings(
|
|
||||||
out_shardings, global_out_avals) # type: ignore
|
|
||||||
|
|
||||||
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
|
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
|
||||||
|
|
||||||
mesh_shape_tuple = None
|
mesh_shape_tuple = None
|
||||||
if config.use_shardy_partitioner.value or prim_requires_devices:
|
if config.use_shardy_partitioner.value or prim_requires_devices:
|
||||||
for sharding in it.chain(in_shardings, out_shardings,
|
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
|
||||||
[js for js, _ in unique_intermediate_shardings]):
|
[js for js, _ in unique_intermediate_shardings]):
|
||||||
if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)):
|
if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)):
|
||||||
if (mesh_shape_tuple is not None and
|
if (mesh_shape_tuple is not None and
|
||||||
@ -2228,6 +2225,11 @@ def lower_sharding_computation(
|
|||||||
f" {sharding.mesh.shape_tuple} for another")
|
f" {sharding.mesh.shape_tuple} for another")
|
||||||
mesh_shape_tuple = sharding.mesh.shape_tuple
|
mesh_shape_tuple = sharding.mesh.shape_tuple
|
||||||
|
|
||||||
|
semantic_in_shardings = SemanticallyEqualShardings(
|
||||||
|
in_shardings, global_in_avals) # type: ignore
|
||||||
|
semantic_out_shardings = SemanticallyEqualShardings(
|
||||||
|
out_shardings, global_out_avals) # type: ignore
|
||||||
|
|
||||||
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
|
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
|
||||||
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
|
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
|
||||||
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
|
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user