mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove the unused return from prepare_axis_resources
PiperOrigin-RevId: 621738698
This commit is contained in:
parent
bc0eff588a
commit
52f7de0969
@ -3286,7 +3286,7 @@ def check_array_xla_sharding_layout_match(
|
||||
|
||||
|
||||
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
|
||||
parsed_pspec, _, _ = sharding_impls.prepare_axis_resources(
|
||||
parsed_pspec = sharding_impls.prepare_axis_resources(
|
||||
pspec, "pspec to array_mapping")
|
||||
return _get_array_mapping(parsed_pspec)
|
||||
|
||||
|
@ -378,8 +378,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
# rather than raising an error. https://github.com/google/jax/issues/2367
|
||||
in_shardings = tuple(in_shardings)
|
||||
|
||||
in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
|
||||
out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings')
|
||||
in_shardings = prepare_axis_resources(in_shardings, 'in_shardings')
|
||||
out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')
|
||||
|
||||
user_specified_in_shardings = (in_shardings is not None and
|
||||
not is_unspecified(in_shardings))
|
||||
@ -2163,7 +2163,7 @@ def with_sharding_constraint(x, shardings):
|
||||
.. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
|
||||
"""
|
||||
x_flat, tree = tree_flatten(x)
|
||||
user_shardings, _, _ = prepare_axis_resources(
|
||||
user_shardings = prepare_axis_resources(
|
||||
shardings, "shardings", allow_unconstrained_dims=True)
|
||||
del shardings
|
||||
|
||||
|
@ -1115,7 +1115,7 @@ def preprocess(mesh, spec, parsed_pspec):
|
||||
# TODO(yaskatariya): Remove this and replace this with a normalized
|
||||
# representation of Parsed Pspec
|
||||
if parsed_pspec is None:
|
||||
parsed_pspec, _, _ = prepare_axis_resources(
|
||||
parsed_pspec = prepare_axis_resources(
|
||||
PartitionSpec() if spec is None else spec,
|
||||
"NamedSharding spec", allow_unconstrained_dims=True)
|
||||
|
||||
@ -1148,7 +1148,7 @@ def prepare_axis_resources(axis_resources,
|
||||
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
|
||||
|
||||
_check_unique_resources(new_entries, arg_name)
|
||||
return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef
|
||||
return tree_util.tree_unflatten(treedef, new_entries)
|
||||
|
||||
|
||||
def _check_unique_resources(axis_resources, arg_name):
|
||||
|
Loading…
x
Reference in New Issue
Block a user