Remove implicit sharding annotation for tpu custom call.

PiperOrigin-RevId: 691876343
This commit is contained in:
jax authors 2024-10-31 11:29:27 -07:00
parent 8296f6e0ba
commit c758373b9c

View File

@ -204,9 +204,6 @@ class CustomCallBackendConfig:
if i + 1 != len(self.flags):
config.write(b",")
config.write(b"]")
# Prevent the compiler from sharding the custom call beyond what Mosaic does
# based on user annotations
config.write(b', "implicit_sharding": {"type": "MANUAL"}')
config.write(b"}")
return config.getvalue()