Use set equality operators instead of intersection because I didn't know set had equality operators.

PiperOrigin-RevId: 530688786
This commit is contained in:
Yash Katariya 2023-05-09 12:55:10 -07:00 committed by jax authors
parent 68ba54241c
commit 2694bf6207

View File

@ -597,8 +597,7 @@ def _mcjax_reshard(x, target_sharding):
if inp_sharding._device_assignment == target_sharding._device_assignment:
return api.jit(_identity_fn, out_shardings=target_sharding)(x)
if len(inp_sharding.device_set.intersection(
target_sharding.device_set)) != len(target_sharding.device_set):
if inp_sharding.device_set != target_sharding.device_set:
inp_ids = [d.id for d in inp_sharding._device_assignment]
inp_plat = inp_sharding._device_assignment[0].platform.upper()
target_ids = [d.id for d in target_sharding._device_assignment]