try_matching_out_with_in_spec_for_all_auto should only work for NamedSharding

PiperOrigin-RevId: 722854686
This commit is contained in:
Yash Katariya 2025-02-03 17:12:33 -08:00 committed by jax authors
parent 0abd9538ce
commit 5e7e6911f4

View File

@ -2588,7 +2588,7 @@ def try_matching_out_with_in_spec_for_all_auto(
orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
recover_in_s, recover_in_aval = None, None
for in_s, in_aval in safe_zip(in_shardings, in_avals):
if in_s is not None and type(in_s) in _orig_out_sharding_handlers:
if isinstance(in_s, NamedSharding):
recover_in_s, recover_in_aval = in_s, in_aval
break
if recover_in_s is None: