Remove unordered_effects from lower_jaxpr_to_module since it is unused

PiperOrigin-RevId: 524139972
This commit is contained in:
Yash Katariya 2023-04-13 16:57:03 -07:00 committed by jax authors
parent 468d7720bf
commit 10c4766f6c
3 changed files with 3 additions and 10 deletions

View File

@ -542,14 +542,11 @@ def xla_computation(fun: Callable,
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
unordered_effects = list(
effects.ordered_effects.filter_not_in(jaxpr.effects))
ordered_effects = list(
effects.ordered_effects.filter_in(jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
f"xla_computation_{fun_name}",
core.ClosedJaxpr(jaxpr, consts),
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
backend_or_name=backend,
platform=platform,

View File

@ -540,7 +540,6 @@ _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
@ -1773,7 +1772,7 @@ def build_xla_computation_helper(
if closed_jaxpr.effects:
raise NotImplementedError
lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
backend_or_name=backend_or_name, unordered_effects=[], ordered_effects=[],
backend_or_name=backend_or_name, ordered_effects=[],
name_stack=source_info_util.NameStack(),
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
axis_context=axis_context, platform=platform)

View File

@ -878,7 +878,6 @@ def lower_parallel_callable(
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
unordered_effects,
ordered_effects,
backend,
lowering_platform or backend.platform,
@ -1918,13 +1917,10 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
if any(effects.ordered_effects.contains(eff) for eff
in closed_jaxpr.effects):
raise ValueError("Ordered effects are not supported for more than 1 device.")
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
unordered_effects,
ordered_effects,
backend,
# Optionally, override the lowering platform
@ -1943,6 +1939,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
unordered_effects = list(
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
return (module, keepalive, host_callbacks, unordered_effects,
ordered_effects, nreps, tuple_args)
@ -2256,7 +2254,6 @@ def lower_mesh_computation(
lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
unordered_effects,
ordered_effects,
backend,
lowering_platform or backend.platform,