mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use the proto serialization of OpShardings if there are many devices.
Protocol buffers are faster to parse than HLO text. PiperOrigin-RevId: 522643659
This commit is contained in:
parent
830d41d5f8
commit
27c9dcf461
@ -1450,7 +1450,13 @@ def set_sharding(op, sharding_proto: xc.OpSharding):
|
||||
|
||||
|
||||
def get_sharding_attr(sharding_proto: xc.OpSharding):
|
||||
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
|
||||
# If there are very large numbers of devices, use the proto representation.
|
||||
# The MHLO to HLO conversion supports both, and the proto representation is
|
||||
# more compact.
|
||||
if len(sharding_proto.tile_assignment_devices) > 100:
|
||||
return ir.StringAttr.get(sharding_proto.SerializeToString())
|
||||
else:
|
||||
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
|
||||
|
||||
|
||||
# MLIR lowerings for lax primitives
|
||||
|
Loading…
x
Reference in New Issue
Block a user