mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Convert tuple to DeviceAssignment on the replicated compilation path.
PiperOrigin-RevId: 533258935
This commit is contained in:
parent
08169291a4
commit
6034e87ddf
@ -2918,7 +2918,6 @@ def _compile_replicated_mesh_executable_from_hlo(
|
||||
host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
|
||||
backend, da, committed, pmap_nreps, jaxpr_debug_info):
|
||||
assert not auto_spmd_lowering
|
||||
assert isinstance(da, _DeviceAssignment)
|
||||
in_shardings = semantics_in_shardings.shardings
|
||||
out_shardings = semantics_out_shardings.shardings
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user