Remove the unused return from prepare_axis_resources

PiperOrigin-RevId: 621738698
This commit is contained in:
Yash Katariya 2024-04-03 22:38:45 -07:00 committed by jax authors
parent bc0eff588a
commit 52f7de0969
3 changed files with 6 additions and 6 deletions

View File

@ -3286,7 +3286,7 @@ def check_array_xla_sharding_layout_match(
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
parsed_pspec, _, _ = sharding_impls.prepare_axis_resources(
parsed_pspec = sharding_impls.prepare_axis_resources(
pspec, "pspec to array_mapping")
return _get_array_mapping(parsed_pspec)

View File

@ -378,8 +378,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
# rather than raising an error. https://github.com/google/jax/issues/2367
in_shardings = tuple(in_shardings)
in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings')
in_shardings = prepare_axis_resources(in_shardings, 'in_shardings')
out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')
user_specified_in_shardings = (in_shardings is not None and
not is_unspecified(in_shardings))
@ -2163,7 +2163,7 @@ def with_sharding_constraint(x, shardings):
.. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
"""
x_flat, tree = tree_flatten(x)
user_shardings, _, _ = prepare_axis_resources(
user_shardings = prepare_axis_resources(
shardings, "shardings", allow_unconstrained_dims=True)
del shardings

View File

@ -1115,7 +1115,7 @@ def preprocess(mesh, spec, parsed_pspec):
# TODO(yaskatariya): Remove this and replace this with a normalized
# representation of Parsed Pspec
if parsed_pspec is None:
parsed_pspec, _, _ = prepare_axis_resources(
parsed_pspec = prepare_axis_resources(
PartitionSpec() if spec is None else spec,
"NamedSharding spec", allow_unconstrained_dims=True)
@ -1148,7 +1148,7 @@ def prepare_axis_resources(axis_resources,
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
_check_unique_resources(new_entries, arg_name)
return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef
return tree_util.tree_unflatten(treedef, new_entries)
def _check_unique_resources(axis_resources, arg_name):