diff --git a/jax/_src/api.py b/jax/_src/api.py
index 3f1069390..26860d85a 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -542,14 +542,11 @@ def xla_computation(fun: Callable,
     jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
     jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
     axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
-    unordered_effects = list(
-        effects.ordered_effects.filter_not_in(jaxpr.effects))
     ordered_effects = list(
         effects.ordered_effects.filter_in(jaxpr.effects))
     lowering_result = mlir.lower_jaxpr_to_module(
         f"xla_computation_{fun_name}",
         core.ClosedJaxpr(jaxpr, consts),
-        unordered_effects=unordered_effects,
         ordered_effects=ordered_effects,
         backend_or_name=backend,
         platform=platform,
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 009d13a0c..670807b65 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -540,7 +540,6 @@ _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,
-    unordered_effects: List[core.Effect],
     ordered_effects: List[core.Effect],
     backend_or_name: Optional[Union[str, xb.XlaBackend]],
     platform: str,
@@ -1773,7 +1772,7 @@ def build_xla_computation_helper(
   if closed_jaxpr.effects:
     raise NotImplementedError
   lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
-      backend_or_name=backend_or_name, unordered_effects=[], ordered_effects=[],
+      backend_or_name=backend_or_name, ordered_effects=[],
       name_stack=source_info_util.NameStack(),
       donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
       axis_context=axis_context, platform=platform)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 795b50f20..d0c7cd0ae 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -878,7 +878,6 @@ def lower_parallel_callable(
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
-      unordered_effects,
       ordered_effects,
       backend,
       lowering_platform or backend.platform,
@@ -1918,13 +1917,10 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     if any(effects.ordered_effects.contains(eff) for eff
            in closed_jaxpr.effects):
       raise ValueError("Ordered effects are not supported for more than 1 device.")
-  unordered_effects = list(
-      effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
   ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
-      unordered_effects,
       ordered_effects,
       backend,
       # Optionally, override the lowering platform
@@ -1943,6 +1939,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
+  unordered_effects = list(
+      effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
   return (module, keepalive, host_callbacks, unordered_effects,
           ordered_effects, nreps, tuple_args)

@@ -2256,7 +2254,6 @@ def lower_mesh_computation(
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
-      unordered_effects,
       ordered_effects,
       backend,
       lowering_platform or backend.platform,
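
Note (not part of the patch): after this change, callers no longer hand `unordered_effects` to `mlir.lower_jaxpr_to_module`; only `ordered_effects` is passed, and unordered effects are derived locally where the caller still needs them, as `_cached_lowering_to_hlo` now does after lowering. A minimal sketch of that effect-splitting pattern, assuming JAX-internal imports (`jax._src.core`, `jax._src.effects`) and a hypothetical helper name `_split_effects`:

    from jax._src import core, effects

    def _split_effects(closed_jaxpr: core.ClosedJaxpr):
      # Ordered effects are still forwarded to
      # mlir.lower_jaxpr_to_module(..., ordered_effects=ordered, ...);
      # unordered effects are only computed for the caller's own bookkeeping.
      ordered = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
      unordered = list(effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
      return ordered, unordered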