mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
DCE as early as possible so that committed
is not dependent on DCE's vars
PiperOrigin-RevId: 521879918
This commit is contained in:
parent
9095faaeb0
commit
ffa9d018d6
@ -666,7 +666,6 @@ def make_sharded_device_array(
|
||||
aval.shape, sharding, device_buffers) # type: ignore
|
||||
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ShardedDeviceArray = Any
|
||||
else:
|
||||
@ -2365,6 +2364,22 @@ def lower_sharding_computation(
|
||||
global_out_avals = fun_or_jaxpr.out_avals
|
||||
consts = fun_or_jaxpr.consts
|
||||
|
||||
if (keep_unused or
|
||||
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
|
||||
for a in global_in_avals)):
|
||||
kept_var_idx = set(range(len(global_in_avals)))
|
||||
else:
|
||||
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
|
||||
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
|
||||
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
|
||||
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
|
||||
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
|
||||
del kept_const_idx
|
||||
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
|
||||
kept_outputs = [True] * len(global_out_avals)
|
||||
|
||||
if _is_unspecified(out_shardings):
|
||||
@ -2383,9 +2398,6 @@ def lower_sharding_computation(
|
||||
for js, source_info in jaxpr_sharding]),
|
||||
devices_from_context)
|
||||
|
||||
# TODO(yashkatariya): Make this logic work after DCE because there can be
|
||||
# equations inside the jaxpr that don't affect the output so whether the
|
||||
# output(s) are committed or not should not depend on it.
|
||||
committed = bool(
|
||||
devices_from_context or
|
||||
len(device_assignment) > 1 or
|
||||
@ -2402,17 +2414,6 @@ def lower_sharding_computation(
|
||||
"Argument mapping: %s.",
|
||||
fun_name, global_in_avals, in_shardings)
|
||||
|
||||
if keep_unused or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
|
||||
for a in global_in_avals):
|
||||
kept_var_idx = set(range(len(global_in_avals)))
|
||||
else:
|
||||
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
|
||||
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
|
||||
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
|
||||
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
|
||||
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
|
||||
del kept_const_idx
|
||||
|
||||
local_device_assignment = [d for d in device_assignment
|
||||
if d.process_index == d.client.process_index()]
|
||||
if len(device_assignment) != len(local_device_assignment):
|
||||
@ -2438,7 +2439,6 @@ def lower_sharding_computation(
|
||||
"`with jax.spmd_mode('allow_all'):` context manager.")
|
||||
|
||||
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
# Computations that only produce constants and/or only rearrange their inputs,
|
||||
# which are often produced from partial evaluation, don't need compilation,
|
||||
@ -2498,7 +2498,6 @@ def lower_sharding_computation(
|
||||
axis_env = xla.AxisEnv(nreps, (), ())
|
||||
axis_ctx = mlir.ReplicaAxisContext(axis_env)
|
||||
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
|
||||
if len(device_assignment) > 1:
|
||||
|
Loading…
x
Reference in New Issue
Block a user