mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
#sdy dynamically choose which custom_partitioning
API to use based on the current
value of the `use_shardy_partitioner` feature flag. Before the way the API works depends on the value of the flag when the partitioning is defined. But we should allow this to be dynamically swapped in and out when the function is actually called. This change allows for that. PiperOrigin-RevId: 715293018
This commit is contained in:
parent
4f2f5fa53a
commit
74e912c3c0
@ -495,15 +495,25 @@ class custom_partitioning:
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
assert not len(consts)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
|
||||
propagate_user_sharding = None
|
||||
infer_sharding_from_operands = None
|
||||
sharding_rule = None
|
||||
if config.use_shardy_partitioner.value:
|
||||
sharding_rule = self.sharding_rule
|
||||
else:
|
||||
propagate_user_sharding = self.propagate_user_sharding
|
||||
infer_sharding_from_operands = self.infer_sharding_from_operands
|
||||
|
||||
out_flat = custom_partitioning_p.bind(
|
||||
*consts,
|
||||
*args_flat,
|
||||
call=closed_call,
|
||||
partition=self.partition,
|
||||
propagate_user_sharding=self.propagate_user_sharding,
|
||||
infer_sharding_from_operands=self.infer_sharding_from_operands,
|
||||
propagate_user_sharding=propagate_user_sharding,
|
||||
infer_sharding_from_operands=infer_sharding_from_operands,
|
||||
decode_shardings=self.decode_shardings,
|
||||
sharding_rule=self.sharding_rule,
|
||||
sharding_rule=sharding_rule,
|
||||
in_tree=in_tree,
|
||||
out_tree=out_tree(),
|
||||
static_args=static_args
|
||||
|
Loading…
x
Reference in New Issue
Block a user