DCE as early as possible so that committed is not dependent on DCE's vars

PiperOrigin-RevId: 521879918
This commit is contained in:
Yash Katariya 2023-04-04 15:20:32 -07:00 committed by jax authors
parent 9095faaeb0
commit ffa9d018d6

View File

@ -666,7 +666,6 @@ def make_sharded_device_array(
aval.shape, sharding, device_buffers) # type: ignore
if TYPE_CHECKING:
ShardedDeviceArray = Any
else:
@ -2365,6 +2364,22 @@ def lower_sharding_computation(
global_out_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts
if (keep_unused or
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
for a in global_in_avals)):
kept_var_idx = set(range(len(global_in_avals)))
else:
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
jaxpr = closed_jaxpr.jaxpr
kept_outputs = [True] * len(global_out_avals)
if _is_unspecified(out_shardings):
@ -2383,9 +2398,6 @@ def lower_sharding_computation(
for js, source_info in jaxpr_sharding]),
devices_from_context)
# TODO(yashkatariya): Make this logic work after DCE because there can be
# equations inside the jaxpr that don't affect the output so whether the
# output(s) are committed or not should not depend on it.
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
@ -2402,17 +2414,6 @@ def lower_sharding_computation(
"Argument mapping: %s.",
fun_name, global_in_avals, in_shardings)
if keep_unused or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
for a in global_in_avals):
kept_var_idx = set(range(len(global_in_avals)))
else:
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx
local_device_assignment = [d for d in device_assignment
if d.process_index == d.client.process_index()]
if len(device_assignment) != len(local_device_assignment):
@ -2438,7 +2439,6 @@ def lower_sharding_computation(
"`with jax.spmd_mode('allow_all'):` context manager.")
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
@ -2498,7 +2498,6 @@ def lower_sharding_computation(
axis_env = xla.AxisEnv(nreps, (), ())
axis_ctx = mlir.ReplicaAxisContext(axis_env)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module_name = f"{api_name}_{fun_name}"
if len(device_assignment) > 1: