Use contextlib.nullcontext instead of trivial_ctx

I removed `trivial_ctx` from the public `jax.interpreters.partial_eval`
submodule without going through a deprecation cycle, because it is highly
unlikely anyone is using it.

PiperOrigin-RevId: 744645764
This commit is contained in:
Sergei Lebedev 2025-04-07 02:40:01 -07:00 committed by jax authors
parent 90cfa99a68
commit 245194ffa1
2 changed files with 2 additions and 5 deletions

View File

@ -15,7 +15,7 @@ from __future__ import annotations
from collections import namedtuple
from collections.abc import Callable, Sequence, Hashable
from contextlib import contextmanager
import contextlib
from functools import partial
import itertools as it
import operator as op
@ -1236,14 +1236,12 @@ def _default_res_aval_updater(
params: dict[str, Any], aval: AbstractValue) -> AbstractValue:
return aval
@contextmanager
def trivial_ctx(_): yield
def call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
ctx = trivial_ctx,
ctx = contextlib.nullcontext,
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
jaxpr = eqn.params[jaxpr_param_name]
with ctx(eqn.params):

View File

@ -81,7 +81,6 @@ from jax._src.interpreters.partial_eval import (
trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits,
trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd,
tracers_to_jaxpr as tracers_to_jaxpr,
trivial_ctx as trivial_ctx,
)