Use the proto serialization of OpShardings if there are many devices.

Protocol buffers are faster to parse than HLO text.

PiperOrigin-RevId: 522643659
This commit is contained in:
Peter Hawkins 2023-04-07 11:27:36 -07:00 committed by jax authors
parent 830d41d5f8
commit 27c9dcf461

View File

@@ -1450,7 +1450,13 @@ def set_sharding(op, sharding_proto: xc.OpSharding):
def get_sharding_attr(sharding_proto: xc.OpSharding):
  """Build the MLIR string attribute that encodes `sharding_proto`.

  Args:
    sharding_proto: the `xc.OpSharding` proto to encode.

  Returns:
    An `ir.StringAttr`. When the proto lists more than 100 entries in
    `tile_assignment_devices`, the attribute holds the serialized proto;
    otherwise it holds the `repr` of the `xc.HloSharding` built from it.
  """
  # If there are very large numbers of devices, use the proto representation.
  # The MHLO to HLO conversion supports both, and the proto representation is
  # more compact.
  if len(sharding_proto.tile_assignment_devices) > 100:
    return ir.StringAttr.get(sharding_proto.SerializeToString())
  # Few devices: keep the textual HLO sharding form.
  return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
# MLIR lowerings for lax primitives