mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove implicit sharding annotation for tpu custom call.
PiperOrigin-RevId: 691876343
This commit is contained in:
parent
8296f6e0ba
commit
c758373b9c
@ -204,9 +204,6 @@ class CustomCallBackendConfig:
|
||||
if i + 1 != len(self.flags):
|
||||
config.write(b",")
|
||||
config.write(b"]")
|
||||
# Prevent the compiler from sharding the custom call beyond what Mosaic does
|
||||
# based on user annotations
|
||||
config.write(b', "implicit_sharding": {"type": "MANUAL"}')
|
||||
config.write(b"}")
|
||||
return config.getvalue()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user