diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0a8e3b782..532eb0f80 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -15,7 +15,7 @@ from __future__ import annotations from collections import namedtuple from collections.abc import Callable, Sequence, Hashable -from contextlib import contextmanager +import contextlib from functools import partial import itertools as it import operator as op @@ -1236,14 +1236,12 @@ def _default_res_aval_updater( params: dict[str, Any], aval: AbstractValue) -> AbstractValue: return aval -@contextmanager -def trivial_ctx(_): yield def call_partial_eval_custom_rule( jaxpr_param_name: str, params_updater: ParamsUpdater, saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx = trivial_ctx, + ctx = contextlib.nullcontext, ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: jaxpr = eqn.params[jaxpr_param_name] with ctx(eqn.params): diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b546d774a..a2d988f6b 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -81,7 +81,6 @@ from jax._src.interpreters.partial_eval import ( trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, - trivial_ctx as trivial_ctx, )