Merge pull request #26150 from jreiffers:main

PiperOrigin-RevId: 721400896
This commit is contained in:
jax authors 2025-01-30 08:32:20 -08:00
commit 1003ba93c3

View File

@ -455,11 +455,6 @@ class custom_partitioning:
def def_partition(self, partition, infer_sharding_from_operands=None,
propagate_user_sharding=None, decode_shardings=True,
sharding_rule=None):
if config.use_shardy_partitioner.value:
infer_sharding_from_operands = None
propagate_user_sharding = None
else:
sharding_rule = None
self.partition = partition
self.propagate_user_sharding = propagate_user_sharding
self.infer_sharding_from_operands = infer_sharding_from_operands