diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b9cdc58df..60dd5d073 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3286,7 +3286,7 @@ def check_array_xla_sharding_layout_match( def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: - parsed_pspec, _, _ = sharding_impls.prepare_axis_resources( + parsed_pspec = sharding_impls.prepare_axis_resources( pspec, "pspec to array_mapping") return _get_array_mapping(parsed_pspec) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index fcae50864..69d29c6e0 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -378,8 +378,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, # rather than raising an error. https://github.com/google/jax/issues/2367 in_shardings = tuple(in_shardings) - in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings') - out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings') + in_shardings = prepare_axis_resources(in_shardings, 'in_shardings') + out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') user_specified_in_shardings = (in_shardings is not None and not is_unspecified(in_shardings)) @@ -2163,7 +2163,7 @@ def with_sharding_constraint(x, shardings): .. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ x_flat, tree = tree_flatten(x) - user_shardings, _, _ = prepare_axis_resources( + user_shardings = prepare_axis_resources( shardings, "shardings", allow_unconstrained_dims=True) del shardings diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 8b62e19b4..305fafb47 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1115,7 +1115,7 @@ def preprocess(mesh, spec, parsed_pspec): # TODO(yaskatariya): Remove this and replace this with a normalized # representation of Parsed Pspec if parsed_pspec is None: - parsed_pspec, _, _ = prepare_axis_resources( + parsed_pspec = prepare_axis_resources( PartitionSpec() if spec is None else spec, "NamedSharding spec", allow_unconstrained_dims=True) @@ -1148,7 +1148,7 @@ def prepare_axis_resources(axis_resources, entry, what, allow_unconstrained_dims=allow_unconstrained_dims)) _check_unique_resources(new_entries, arg_name) - return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef + return tree_util.tree_unflatten(treedef, new_entries) def _check_unique_resources(axis_resources, arg_name):