diff --git a/jax/_src/api.py b/jax/_src/api.py index 2054325a2..a6b0ba3ae 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3290,19 +3290,7 @@ def named_call( if name is None: name = fun.__name__ - _, in_tree = tree_flatten(()) - - if config.jax_experimental_name_stack: - return source_info_util.extend_name_stack(name)(fun) - - @functools.wraps(fun) - def named_call_f(*args, **kwargs): - lu_f = lu.wrap_init(lambda: fun(*args, **kwargs)) - flat_f, out_tree = flatten_fun_nokwargs(lu_f, in_tree) - out_flat = core.named_call_p.bind(flat_f, name=name) - return tree_unflatten(out_tree(), out_flat) - - return named_call_f + return source_info_util.extend_name_stack(name)(fun) @contextmanager def named_scope( diff --git a/jax/_src/config.py b/jax/_src/config.py index e89bc1c07..36c82d504 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1001,11 +1001,6 @@ config.define_bool_state( update_thread_local_hook=lambda val: \ update_thread_local_jit_state(dynamic_shapes=val)) -config.define_bool_state( - name='jax_experimental_name_stack', - default=True, - help='Enable using the context manager-based name stack.') - # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. config.define_bool_state( diff --git a/jax/_src/util.py b/jax/_src/util.py index c05723b1e..fc4c01806 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -333,21 +333,16 @@ def wrap_name(name, transform_name): return transform_name + '(' + name + ')' def new_name_stack(name: str = ''): - if config.jax_experimental_name_stack: - from jax._src import source_info_util - name_stack = source_info_util.NameStack() - if name: - name_stack = name_stack.extend(name) - return name_stack - return name + '/' + from jax._src import source_info_util + name_stack = source_info_util.NameStack() + if name: + name_stack = name_stack.extend(name) + return name_stack def extend_name_stack(stack, name: str): - if config.jax_experimental_name_stack: - from jax._src import source_info_util - assert isinstance(stack, source_info_util.NameStack), stack - return stack.extend(name) - assert isinstance(stack, str) - return stack + name + '/' + from jax._src import source_info_util + assert isinstance(stack, source_info_util.NameStack), stack + return stack.extend(name) def canonicalize_axis(axis, num_dims) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" diff --git a/jax/core.py b/jax/core.py index 617bd69d4..3dde1c257 100644 --- a/jax/core.py +++ b/jax/core.py @@ -2073,9 +2073,6 @@ call_p: CallPrimitive = CallPrimitive('call') call = call_p.bind call_p.def_impl(call_impl) -named_call_p: CallPrimitive = CallPrimitive('named_call') -named_call_p.def_impl(call_impl) - class ClosedCallPrimitive(CallPrimitive): def get_bind_params(self, params): diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index b74924bf4..5788aa961 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1651,16 +1651,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], # compilation to XLA, which does not use those parameters. bwd="illegal param", out_trees="illegal param"))) - elif eqn.primitive is core.named_call_p: - call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"]) - eqns.append( - eqn.replace( - invars=eqn.invars + [input_token_var, input_itoken_var], - outvars=eqn.outvars + [output_token_var, output_itoken_var], - params=dict( - eqn.params, - call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True), - ))) elif eqn.primitive is pjit.pjit_p: jaxpr = cast(core.ClosedJaxpr, eqn.params["jaxpr"]) eqns.append( diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 94faf0650..77a7535b1 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -187,9 +187,7 @@ class _ThreadLocalState(threading.local): _thread_local_state = _ThreadLocalState() def _get_current_name_stack() -> Union[NameStack, str]: - if config.jax_experimental_name_stack: - return source_info_util.current_name_stack() - return _thread_local_state.name_stack + return source_info_util.current_name_stack() @contextlib.contextmanager def inside_call_tf(): @@ -530,24 +528,12 @@ def make_custom_gradient_fn_tf( @contextlib.contextmanager def _extended_name_stack(extra_name_stack: Optional[str]): - if config.jax_experimental_name_stack: - name_ctx = (source_info_util.extend_name_stack(extra_name_stack) - if extra_name_stack - else contextlib.nullcontext()) - with name_ctx: - yield - return - prev_name_stack = _thread_local_state.name_stack - if extra_name_stack: - if not prev_name_stack: - _thread_local_state.name_stack = extra_name_stack - else: - _thread_local_state.name_stack = util.extend_name_stack( - _thread_local_state.name_stack, extra_name_stack) - try: + name_ctx = (source_info_util.extend_name_stack(extra_name_stack) + if extra_name_stack + else contextlib.nullcontext()) + with name_ctx: yield - finally: - _thread_local_state.name_stack = prev_name_stack + return def _interpret_fun_jax( @@ -1050,20 +1036,14 @@ class TensorFlowTrace(core.Trace): return impl(*args_tf, **params) current_name_stack = _get_current_name_stack() - if config.jax_experimental_name_stack: - # We don't use `str(name_stack)` because it uses parentheses for - # transformations, which aren't allowed in `name_scope`. - scope = '/'.join([s.name for s in current_name_stack.stack]) # type: ignore[union-attr] - else: - scope = str(current_name_stack) + # We don't use `str(name_stack)` because it uses parentheses for + # transformations, which aren't allowed in `name_scope`. + scope = '/'.join([s.name for s in current_name_stack.stack]) # type: ignore[union-attr] # We need to add a '/' to the name stack string to force `tf.name_scope` # to interpret it as an absolute scope, not a relative scope. scope = scope + '/' - name_scope = ( - tf.name_scope(_sanitize_scope_name(scope)) if - config.jax_experimental_name_stack else contextlib.nullcontext()) - with name_scope: + with tf.name_scope(_sanitize_scope_name(scope)): if _thread_local_state.include_xla_op_metadata: op_metadata = xla.make_op_metadata(primitive, params, name_stack=current_name_stack, @@ -1109,17 +1089,11 @@ class TensorFlowTrace(core.Trace): avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) interpreted_fun = _interpret_subtrace(fun, self.main, avals) extra_name_stack = None - if call_primitive == core.named_call_p: - extra_name_stack = util.wrap_name(params["name"], "named") - elif call_primitive == xla.xla_call_p: + if call_primitive == xla.xla_call_p: extra_name_stack = util.wrap_name(params["name"], "jit") with _extended_name_stack(extra_name_stack): with core.new_sublevel(): - if call_primitive == core.named_call_p: - with tf.name_scope(_sanitize_scope_name(params["name"])): - vals_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \ - interpreted_fun.call_wrapped(*vals) - elif call_primitive == xla.xla_call_p: + if call_primitive == xla.xla_call_p: if _WRAP_JAX_JIT_WITH_TF_FUNCTION: # Make a nested tf.function(jit_compile=True) store_tf_res_avals: Sequence[core.ShapedArray] = [] @@ -1208,7 +1182,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): # Call primitives are inlined -for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p, maps.xmap_p]: +for unexpected in [core.call_p, xla.xla_call_p, maps.xmap_p]: tf_impl[unexpected] = partial(_unexpected_primitive, unexpected) # Primitives that are not yet implemented must be explicitly declared here. @@ -2615,11 +2589,10 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr], for jaxpr in branches for i, jaxpr in enumerate(branches) ] - if config.jax_experimental_name_stack: - # Same name stack as XLA translation of cond_p - # Note: extend_name_stack is a contextmanager, which is callable as a decorator. - branches_tf = list(map(source_info_util.extend_name_stack("cond"), # type: ignore[arg-type] - branches_tf)) + # Same name stack as XLA translation of cond_p + # Note: extend_name_stack is a contextmanager, which is callable as a decorator. + branches_tf = list(map(source_info_util.extend_name_stack("cond"), # type: ignore[arg-type] + branches_tf)) return tf.switch_case(index, branches_tf) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 3a83d16f7..5412eee66 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -771,12 +771,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): f_tf_hlo = self.TfToHlo(f_tf, arg) if jax.config.jax_remat_opt_barrier: self.assertRegex(f_tf_hlo, r"opt-barrier") - elif config.jax_experimental_name_stack: - self.assertRegex(f_tf_hlo, - r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin') else: self.assertRegex(f_tf_hlo, - r'switch_case/indexed_case/Sin') + r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin') def test_remat_free_var(self): def f(x): diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 799333c14..1f90c8e22 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -319,8 +319,6 @@ class JVPTrace(Trace): which_nz = [ type(t) is not Zero for t in tangents] tangents = [t if type(t) is not Zero else None for t in tangents] args, in_tree = tree_flatten((primals, tangents)) - if 'name' in params and not config.jax_experimental_name_stack: - params = dict(params, name=wrap_name(params['name'], 'jvp')) f_jvp = jvp_subtrace(f, self.main) f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp) if isinstance(call_primitive, core.MapPrimitive): @@ -609,8 +607,6 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) - if 'name' in params and not config.jax_experimental_name_stack: - params = dict(params, name=wrap_name(params['name'], 'transpose')) update_params = call_transpose_param_updaters.get(primitive) if update_params: params = update_params(params, map(is_undefined_primal, args), @@ -627,8 +623,6 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): out_flat = primitive.bind(fun, *all_args, **params) return tree_unflatten(out_tree(), out_flat) primitive_transposes[core.call_p] = partial(call_transpose, call_p) -primitive_transposes[core.named_call_p] = \ - partial(call_transpose, core.named_call_p) def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes): diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 9d8572f23..ed98c19f0 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -31,10 +31,9 @@ from jax._src.tree_util import (tree_unflatten, tree_flatten, from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p, Zero) from jax import linear_util as lu -from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name, - split_list, canonicalize_axis, moveaxis, - as_hashable_function, curry, memoize, - weakref_lru_cache) +from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, + canonicalize_axis, moveaxis, as_hashable_function, + curry, memoize, weakref_lru_cache) from jax.interpreters import partial_eval as pe Array = Any @@ -352,10 +351,7 @@ class BatchTrace(Trace): def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results - if config.jax_experimental_name_stack: - params = dict(params, name=params.get('name', f.__name__)) - else: - params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap')) + params = dict(params, name=params.get('name', f.__name__)) vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) if all(bdim is not_mapped for bdim in dims): return call_primitive.bind(f, *vals, **params) diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 2b3c67cc6..90cedfc2a 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -309,13 +309,9 @@ register_constant_handler( def _source_info_to_location( primitive: core.Primitive, params: Dict, source_info: source_info_util.SourceInfo, - name_stack: Union[str, source_info_util.NameStack] = "") -> ir.Location: - if config.jax_experimental_name_stack: - eqn_str = (f'{str(source_info.name_stack)}/' - f'{core.str_eqn_compact(primitive.name, params)}') - else: - assert isinstance(name_stack, str) - eqn_str = name_stack + core.str_eqn_compact(primitive.name, params) + name_stack: source_info_util.NameStack) -> ir.Location: + eqn_str = (f'{str(source_info.name_stack)}/' + f'{core.str_eqn_compact(primitive.name, params)}') frame = source_info_util.user_frame(source_info) if frame is None: loc = ir.Location.unknown() @@ -328,8 +324,6 @@ def _source_info_to_location( # Translation rules -NameStack = Union[str, source_info_util.NameStack] - def make_ir_context() -> ir.Context: """Creates an MLIR context suitable for JAX IR.""" context = ir.Context() @@ -414,7 +408,7 @@ class ModuleContext: backend_or_name: Optional[Union[str, xb.XlaBackend]] platform: str axis_context: AxisContext - name_stack: NameStack + name_stack: source_info_util.NameStack keepalives: List[Any] channel_iterator: Iterator[int] host_callbacks: List[Any] @@ -432,7 +426,7 @@ class ModuleContext: backend_or_name: Optional[Union[str, xb.XlaBackend]], platform: str, axis_context: AxisContext, - name_stack: NameStack, + name_stack: source_info_util.NameStack, keepalives: List[Any], channel_iterator: Iterator[int], host_callbacks: List[Any], @@ -583,7 +577,7 @@ def lower_jaxpr_to_module( backend_or_name: Optional[Union[str, xb.XlaBackend]], platform: str, axis_context: AxisContext, - name_stack: NameStack, + name_stack: source_info_util.NameStack, donated_args: Sequence[bool], replicated_args: Optional[Sequence[bool]] = None, arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None, @@ -1007,14 +1001,11 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, map(write, jaxpr.invars, args) for eqn in jaxpr.eqns: in_nodes = map(read, eqn.invars) - if config.jax_experimental_name_stack: - assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack) - source_info = eqn.source_info.replace( - name_stack=ctx.name_stack + eqn.source_info.name_stack) - else: - source_info = eqn.source_info + assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack) + source_info = eqn.source_info.replace( + name_stack=ctx.name_stack + eqn.source_info.name_stack) loc = _source_info_to_location(eqn.primitive, eqn.params, source_info, - name_stack=ctx.name_stack) + ctx.name_stack) with source_info_util.user_context(eqn.source_info.traceback), loc: if eqn.primitive in _platform_specific_lowerings[ctx.platform]: rule = _platform_specific_lowerings[ctx.platform][eqn.primitive] @@ -1029,8 +1020,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, f"MLIR translation rule for primitive '{eqn.primitive.name}' not " f"found for platform {ctx.platform}") - eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if - config.jax_experimental_name_stack else ctx) + eqn_ctx = ctx.replace(name_stack=source_info.name_stack) effects = [eff for eff in eqn.effects if eff in core.ordered_effects] tokens_in = tokens.subset(effects) avals_in = map(aval, eqn.invars) @@ -1160,21 +1150,16 @@ def _xla_call_lower(ctx, *args, register_lowering(xla.xla_call_p, _xla_call_lower) -def _named_call_lowering(ctx, *args, name, backend=None, - call_jaxpr): +def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr): out_nodes, tokens = _call_lowering( name, name, call_jaxpr, backend, ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args) ctx.set_tokens_out(tokens) return out_nodes -register_lowering(core.named_call_p, _named_call_lowering) -register_lowering(core.call_p, partial(_named_call_lowering, name="core_call")) +register_lowering(core.call_p, partial(_core_call_lowering, name="core_call")) register_lowering(core.closed_call_p, - partial(_named_call_lowering, name="core_closed_call")) - -register_lowering(core.closed_call_p, - partial(_named_call_lowering, name="core_closed_call")) + partial(_core_call_lowering, name="core_closed_call")) def full_like_aval(value, aval: core.ShapedArray) -> ir.Value: diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 5854b9c0e..d51f76638 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1291,9 +1291,6 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \ lambda _, __, ___, ____, _____, x, y: (x, y)) partial_eval_jaxpr_custom_rules[core.closed_call_p] = \ partial(closed_call_partial_eval_custom_rule, 'call_jaxpr') -partial_eval_jaxpr_custom_rules[core.named_call_p] = \ - partial(call_partial_eval_custom_rule, 'call_jaxpr', - lambda _, __, ___, ____, _____, x, y: (x, y)) def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]: @@ -1394,7 +1391,6 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) return used_inputs, new_eqn dce_rules[core.call_p] = dce_jaxpr_call_rule -dce_rules[core.named_call_p] = dce_jaxpr_call_rule def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index ffe04fa2d..aefc882a0 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -104,11 +104,8 @@ def make_op_metadata(primitive: core.Primitive, source_info: source_info_util.SourceInfo, name_stack: Union[str, source_info_util.NameStack] = "", ) -> xc.OpMetadata: - if config.jax_experimental_name_stack: - eqn_str = str(source_info.name_stack) + '/' + str_eqn_compact(primitive.name, params) - else: - assert isinstance(name_stack, str) - eqn_str = name_stack + str_eqn_compact(primitive.name, params) + eqn_str = (str(source_info.name_stack) + '/' + + str_eqn_compact(primitive.name, params)) tracebacks[eqn_str] = source_info.traceback frame = source_info_util.user_frame(source_info) return xc.OpMetadata( diff --git a/tests/api_test.py b/tests/api_test.py index 934b83c73..44f426122 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -9286,12 +9286,9 @@ class NamedCallTest(jtu.JaxTestCase): return my_test_function(x) c = jax.xla_computation(f)(2) - if config.jax_experimental_name_stack: - print_opts = xla_client._xla.HloPrintOptions.short_parsable() - print_opts.print_metadata = True - hlo_text = c.as_hlo_module().to_string(print_opts) - else: - hlo_text = c.as_hlo_text() + print_opts = xla_client._xla.HloPrintOptions.short_parsable() + print_opts.print_metadata = True + hlo_text = c.as_hlo_module().to_string(print_opts) self.assertIn("my_test_function", hlo_text) def test_non_jaxtype_arg(self): diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py index 5dbbb35d1..05e2ce40b 100644 --- a/tests/jax_to_ir_test.py +++ b/tests/jax_to_ir_test.py @@ -18,7 +18,6 @@ from absl.testing import absltest import jax.numpy as jnp from jax.tools import jax_to_ir from jax._src import test_util as jtu -from jax.config import config try: import tensorflow as tf @@ -90,14 +89,9 @@ class JaxToIRTest(absltest.TestCase): ]) # Check that tf debug txt contains a broadcast, add, and multiply. - if config.jax_experimental_name_stack: - self.assertIn('BroadcastTo', tf_text) - self.assertIn('AddV2', tf_text) - self.assertIn('Mul', tf_text) - else: - self.assertIn('name: "BroadcastTo"', tf_text) - self.assertIn('name: "AddV2"', tf_text) - self.assertIn('name: "Mul"', tf_text) + self.assertIn('BroadcastTo', tf_text) + self.assertIn('AddV2', tf_text) + self.assertIn('Mul', tf_text) # Check that we can re-import our graphdef. gdef = tf.compat.v1.GraphDef() diff --git a/tests/metadata_test.py b/tests/metadata_test.py index f55bf82de..bc2632c09 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -63,14 +63,9 @@ class MetadataTest(jtu.JaxTestCase): def foo(x): return jnp.sin(x) hlo = jax.xla_computation(jax.grad(foo))(1.).get_hlo_module().to_string() - if config.jax_experimental_name_stack: - self.assertRegex(hlo, 'op_name=".*jvp\\(jit\\(foo\\)\\)/sin"') - self.assertRegex(hlo, 'op_name=".*jvp\\(jit\\(foo\\)\\)/cos"') - self.assertRegex(hlo, 'op_name=".*transpose\\(jvp\\(jit\\(foo\\)\\)\\)/mul"') - else: - self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"') - self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"') - self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\(jvp\\(foo\\)\\)\\)/mul"') + self.assertRegex(hlo, 'op_name=".*jvp\\(jit\\(foo\\)\\)/sin"') + self.assertRegex(hlo, 'op_name=".*jvp\\(jit\\(foo\\)\\)/cos"') + self.assertRegex(hlo, 'op_name=".*transpose\\(jvp\\(jit\\(foo\\)\\)\\)/mul"') def test_cond_metadata(self): def true_fun(x): diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index 912e334ac..acdcefc92 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -33,17 +33,8 @@ def _get_hlo(f): return c.as_hlo_module().to_string(print_opts) return wrapped -class _EnableNameStackTestCase(jtu.JaxTestCase): - def setUp(self): - self.cfg = config._read("jax_experimental_name_stack") - config.update("jax_experimental_name_stack", True) - - def tearDown(self): - config.update("jax_experimental_name_stack", self.cfg) - - -class NameStackTest(_EnableNameStackTestCase): +class NameStackTest(jtu.JaxTestCase): def test_trivial_name_stack(self): @@ -135,7 +126,7 @@ class NameStackTest(_EnableNameStackTestCase): self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar') -class NameStackTransformationTest(_EnableNameStackTestCase): +class NameStackTransformationTest(jtu.JaxTestCase): def test_vmap_should_transform_name_stack(self): @jax.vmap @@ -238,7 +229,7 @@ class NameStackTransformationTest(_EnableNameStackTestCase): self.assertIn('transpose(jvp(foo))/jit(f)/bar/mul', hlo_text) -class NameStackControlFlowTest(_EnableNameStackTestCase): +class NameStackControlFlowTest(jtu.JaxTestCase): def test_while_loop_body_should_not_have_name_stack(self):