From d69a82c140d25ab87600a0d200ac37bbb28323b8 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 3 Oct 2024 14:29:00 -0700
Subject: [PATCH] Use unique in_shardings and out_shardings when iterating over
 shardings to check for trivial things that don't depend on the number of
 shardings but only on their uniqueness

PiperOrigin-RevId: 682032365
---
 jax/_src/interpreters/pxla.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 4c134f266..650f3fb6c 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2173,12 +2173,14 @@ def lower_sharding_computation(
                           else context_mesh._flat_devices_tuple)
   # Device assignment across all inputs, outputs and shardings inside jaxpr
   # should be the same.
-  unique_intermediate_shardings = list(util.stable_unique(
-      dispatch.get_intermediate_shardings(jaxpr)))
+  unique_intermediate_shardings = util.stable_unique(
+      list(dispatch.get_intermediate_shardings(jaxpr)))
+  unique_in_shardings = util.stable_unique(in_shardings)
+  unique_out_shardings = util.stable_unique(out_shardings)
   backend, device_assignment = _get_and_check_device_assignment(
       it.chain(
-          ((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)),
-          ((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)),
+          ((i, MismatchType.ARG_SHARDING, None) for i in unique_in_shardings),
+          ((o, MismatchType.OUT_SHARDING, None) for o in unique_out_shardings),
           ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
            for js, source_info in unique_intermediate_shardings)),
       devices_from_context)
@@ -2188,16 +2190,16 @@ def lower_sharding_computation(
   committed = bool(
       devices_from_context or
       len(device_assignment) > 1 or
-      any(not is_unspecified(i) for i in in_shardings) or
+      any(not is_unspecified(i) for i in unique_in_shardings) or
       any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
-      any(not is_unspecified(o) for o in out_shardings))
+      any(not is_unspecified(o) for o in unique_out_shardings))
 
   da_object = _create_da_object(tuple(device_assignment))
 
   transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
 
   all_default_mem_kind = are_all_shardings_default_mem_kind(
       da_object,
-      it.chain(in_shardings, out_shardings,
+      it.chain(unique_in_shardings, unique_out_shardings,
                [js for js, _ in unique_intermediate_shardings],
                transfer_mem_kind_in_jaxpr))  # pytype: disable=wrong-arg-types
@@ -2208,16 +2210,11 @@ def lower_sharding_computation(
       closed_jaxpr, in_shardings)
 
   # 2. Build up the HLO
-  semantic_in_shardings = SemanticallyEqualShardings(
-      in_shardings, global_in_avals)  # type: ignore
-  semantic_out_shardings = SemanticallyEqualShardings(
-      out_shardings, global_out_avals)  # type: ignore
-
   prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
 
   mesh_shape_tuple = None
   if config.use_shardy_partitioner.value or prim_requires_devices:
-    for sharding in it.chain(in_shardings, out_shardings,
+    for sharding in it.chain(unique_in_shardings, unique_out_shardings,
                              [js for js, _ in unique_intermediate_shardings]):
       if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)):
         if (mesh_shape_tuple is not None and
@@ -2228,6 +2225,11 @@ def lower_sharding_computation(
               f" {sharding.mesh.shape_tuple} for another")
         mesh_shape_tuple = sharding.mesh.shape_tuple
 
+  semantic_in_shardings = SemanticallyEqualShardings(
+      in_shardings, global_in_avals)  # type: ignore
+  semantic_out_shardings = SemanticallyEqualShardings(
+      out_shardings, global_out_avals)  # type: ignore
+
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
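Editor's note: the checks touched by this patch (the `committed` flag, the default-memory-kind check, and the mesh-shape consistency loop) only ask whether any sharding with a given property exists, so their result is unchanged when they scan a deduplicated list instead of every per-argument sharding. The sketch below is a minimal standalone illustration of that invariant; `stable_unique_sketch`, `UNSPECIFIED`, and `is_unspecified_sketch` are stand-ins for the real JAX internals (`util.stable_unique`, the unspecified-sharding singleton, and `is_unspecified`), not the actual implementations.

# Stand-in for jax._src.util.stable_unique: order-preserving deduplication.
def stable_unique_sketch(xs):
    seen = set()
    out = []
    for x in xs:
        if x not in seen:
            seen.add(x)
            out.append(x)
    return out

# Stand-ins for JAX's "unspecified sharding" marker and its predicate.
UNSPECIFIED = object()

def is_unspecified_sketch(s):
    return s is UNSPECIFIED

# Many arguments typically share the same sharding object, so the deduplicated
# list is much shorter than the per-argument list.
in_shardings = [UNSPECIFIED, "NamedSharding_x", "NamedSharding_x", UNSPECIFIED]
unique_in_shardings = stable_unique_sketch(in_shardings)

# An existence check gives the same answer over the full list and the
# deduplicated one; only the number of iterations differs.
assert (any(not is_unspecified_sketch(s) for s in in_shardings) ==
        any(not is_unspecified_sketch(s) for s in unique_in_shardings))

The same reasoning applies to the mesh-shape loop: it only compares the distinct `sharding.mesh.shape_tuple` values that appear, so duplicate shardings add iterations without changing the outcome.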