mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Use contextlib.nullcontext
instead of trivial_ctx
I removed `trivial_ctx` from the public `jax.interpreters.partial_eval` submodule without going through a deprecation cycle, because it is highly unlikely anyone is using it. PiperOrigin-RevId: 744645764
This commit is contained in:
parent
90cfa99a68
commit
245194ffa1
@ -15,7 +15,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable, Sequence, Hashable
|
||||
from contextlib import contextmanager
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import operator as op
|
||||
@ -1236,14 +1236,12 @@ def _default_res_aval_updater(
|
||||
params: dict[str, Any], aval: AbstractValue) -> AbstractValue:
|
||||
return aval
|
||||
|
||||
@contextmanager
|
||||
def trivial_ctx(_): yield
|
||||
|
||||
def call_partial_eval_custom_rule(
|
||||
jaxpr_param_name: str, params_updater: ParamsUpdater,
|
||||
saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool],
|
||||
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
|
||||
ctx = trivial_ctx,
|
||||
ctx = contextlib.nullcontext,
|
||||
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
|
||||
jaxpr = eqn.params[jaxpr_param_name]
|
||||
with ctx(eqn.params):
|
||||
|
@ -81,7 +81,6 @@ from jax._src.interpreters.partial_eval import (
|
||||
trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits,
|
||||
trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd,
|
||||
tracers_to_jaxpr as tracers_to_jaxpr,
|
||||
trivial_ctx as trivial_ctx,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user