mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Remove unordered_effects
from lower_jaxpr_to_module
since it is unused
PiperOrigin-RevId: 524139972
This commit is contained in:
parent
468d7720bf
commit
10c4766f6c
@ -542,14 +542,11 @@ def xla_computation(fun: Callable,
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(jaxpr.effects))
|
||||
ordered_effects = list(
|
||||
effects.ordered_effects.filter_in(jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
f"xla_computation_{fun_name}",
|
||||
core.ClosedJaxpr(jaxpr, consts),
|
||||
unordered_effects=unordered_effects,
|
||||
ordered_effects=ordered_effects,
|
||||
backend_or_name=backend,
|
||||
platform=platform,
|
||||
|
@ -540,7 +540,6 @@ _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
|
||||
def lower_jaxpr_to_module(
|
||||
module_name: str,
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
unordered_effects: List[core.Effect],
|
||||
ordered_effects: List[core.Effect],
|
||||
backend_or_name: Optional[Union[str, xb.XlaBackend]],
|
||||
platform: str,
|
||||
@ -1773,7 +1772,7 @@ def build_xla_computation_helper(
|
||||
if closed_jaxpr.effects:
|
||||
raise NotImplementedError
|
||||
lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
|
||||
backend_or_name=backend_or_name, unordered_effects=[], ordered_effects=[],
|
||||
backend_or_name=backend_or_name, ordered_effects=[],
|
||||
name_stack=source_info_util.NameStack(),
|
||||
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
|
||||
axis_context=axis_context, platform=platform)
|
||||
|
@ -878,7 +878,6 @@ def lower_parallel_callable(
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
unordered_effects,
|
||||
ordered_effects,
|
||||
backend,
|
||||
lowering_platform or backend.platform,
|
||||
@ -1918,13 +1917,10 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in closed_jaxpr.effects):
|
||||
raise ValueError("Ordered effects are not supported for more than 1 device.")
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
unordered_effects,
|
||||
ordered_effects,
|
||||
backend,
|
||||
# Optionally, override the lowering platform
|
||||
@ -1943,6 +1939,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
lowering_result.module, lowering_result.keepalive,
|
||||
lowering_result.host_callbacks)
|
||||
tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
|
||||
return (module, keepalive, host_callbacks, unordered_effects,
|
||||
ordered_effects, nreps, tuple_args)
|
||||
|
||||
@ -2256,7 +2254,6 @@ def lower_mesh_computation(
|
||||
lowering_result = mlir.lower_jaxpr_to_module(
|
||||
module_name,
|
||||
closed_jaxpr,
|
||||
unordered_effects,
|
||||
ordered_effects,
|
||||
backend,
|
||||
lowering_platform or backend.platform,
|
||||
|
Loading…
x
Reference in New Issue
Block a user