diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 160bae300..8c581bf27 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -697,8 +697,11 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, def _maybe_put(x): if isinstance(x, np.ndarray): return dispatch._put_x( - x, jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0]), - shaped_abstractify(x), False) + x, + jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]), + shaped_abstractify(x), + False, + ) else: return x