mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Use set equality operators instead of intersection because I didn't know set had equality operators.
PiperOrigin-RevId: 530688786
This commit is contained in:
parent
68ba54241c
commit
2694bf6207
@ -597,8 +597,7 @@ def _mcjax_reshard(x, target_sharding):
|
||||
if inp_sharding._device_assignment == target_sharding._device_assignment:
|
||||
return api.jit(_identity_fn, out_shardings=target_sharding)(x)
|
||||
|
||||
if len(inp_sharding.device_set.intersection(
|
||||
target_sharding.device_set)) != len(target_sharding.device_set):
|
||||
if inp_sharding.device_set != target_sharding.device_set:
|
||||
inp_ids = [d.id for d in inp_sharding._device_assignment]
|
||||
inp_plat = inp_sharding._device_assignment[0].platform.upper()
|
||||
target_ids = [d.id for d in target_sharding._device_assignment]
|
||||
|
Loading…
x
Reference in New Issue
Block a user