Convert tuple to DeviceAssignment on the replicated compilation path.

PiperOrigin-RevId: 533258935
This commit is contained in:
jax authors 2023-05-18 14:52:37 -07:00
parent 08169291a4
commit 6034e87ddf

View File

@ -2918,7 +2918,6 @@ def _compile_replicated_mesh_executable_from_hlo(
host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
backend, da, committed, pmap_nreps, jaxpr_debug_info):
assert not auto_spmd_lowering
assert isinstance(da, _DeviceAssignment)
in_shardings = semantics_in_shardings.shardings
out_shardings = semantics_out_shardings.shardings