#sdy dynamically choose which custom_partitioning API to use based on the current

value of the `use_shardy_partitioner` feature flag.

Before the way the API works depends on the value of the flag when the partitioning is defined. But we should allow this to be dynamically swapped in and out when the function is actually called. This change allows for that.

PiperOrigin-RevId: 715293018
This commit is contained in:
Bart Chrzaszcz 2025-01-14 02:11:19 -08:00 committed by jax authors
parent 4f2f5fa53a
commit 74e912c3c0

View File

@ -495,15 +495,25 @@ class custom_partitioning:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
propagate_user_sharding = None
infer_sharding_from_operands = None
sharding_rule = None
if config.use_shardy_partitioner.value:
sharding_rule = self.sharding_rule
else:
propagate_user_sharding = self.propagate_user_sharding
infer_sharding_from_operands = self.infer_sharding_from_operands
out_flat = custom_partitioning_p.bind(
*consts,
*args_flat,
call=closed_call,
partition=self.partition,
propagate_user_sharding=self.propagate_user_sharding,
infer_sharding_from_operands=self.infer_sharding_from_operands,
propagate_user_sharding=propagate_user_sharding,
infer_sharding_from_operands=infer_sharding_from_operands,
decode_shardings=self.decode_shardings,
sharding_rule=self.sharding_rule,
sharding_rule=sharding_rule,
in_tree=in_tree,
out_tree=out_tree(),
static_args=static_args