mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #26150 from jreiffers:main
PiperOrigin-RevId: 721400896
This commit is contained in:
commit
1003ba93c3
@ -455,11 +455,6 @@ class custom_partitioning:
|
||||
def def_partition(self, partition, infer_sharding_from_operands=None,
|
||||
propagate_user_sharding=None, decode_shardings=True,
|
||||
sharding_rule=None):
|
||||
if config.use_shardy_partitioner.value:
|
||||
infer_sharding_from_operands = None
|
||||
propagate_user_sharding = None
|
||||
else:
|
||||
sharding_rule = None
|
||||
self.partition = partition
|
||||
self.propagate_user_sharding = propagate_user_sharding
|
||||
self.infer_sharding_from_operands = infer_sharding_from_operands
|
||||
|
Loading…
x
Reference in New Issue
Block a user