mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
try_matching_out_with_in_spec_for_all_auto
should only work for NamedSharding
PiperOrigin-RevId: 722854686
This commit is contained in:
parent
0abd9538ce
commit
5e7e6911f4
@ -2588,7 +2588,7 @@ def try_matching_out_with_in_spec_for_all_auto(
|
||||
orig_out_shardings, new_out_shardings, out_avals, in_shardings, in_avals):
|
||||
recover_in_s, recover_in_aval = None, None
|
||||
for in_s, in_aval in safe_zip(in_shardings, in_avals):
|
||||
if in_s is not None and type(in_s) in _orig_out_sharding_handlers:
|
||||
if isinstance(in_s, NamedSharding):
|
||||
recover_in_s, recover_in_aval = in_s, in_aval
|
||||
break
|
||||
if recover_in_s is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user