mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Don't apply shardy config during def_partition.
PR #25834 intended to dynamically choose the the partitioner API, but it still applies the configuration value too early (it should only be applied in __call__, not in def_partition and __call__).
This commit is contained in:
parent
e332b94f19
commit
55d891c5bf
@ -454,11 +454,6 @@ class custom_partitioning:
|
||||
def def_partition(self, partition, infer_sharding_from_operands,
|
||||
propagate_user_sharding=None, decode_shardings=True,
|
||||
sharding_rule=None):
|
||||
if config.use_shardy_partitioner.value:
|
||||
infer_sharding_from_operands = None
|
||||
propagate_user_sharding = None
|
||||
else:
|
||||
sharding_rule = None
|
||||
self.partition = partition
|
||||
self.propagate_user_sharding = propagate_user_sharding
|
||||
self.infer_sharding_from_operands = infer_sharding_from_operands
|
||||
|
Loading…
x
Reference in New Issue
Block a user