mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Delete jax_experimental_name_stack
flag
PiperOrigin-RevId: 487601864
This commit is contained in:
parent
0ebb6b4215
commit
74b136e62c
@ -3290,20 +3290,8 @@ def named_call(
|
||||
if name is None:
|
||||
name = fun.__name__
|
||||
|
||||
_, in_tree = tree_flatten(())
|
||||
|
||||
if config.jax_experimental_name_stack:
|
||||
return source_info_util.extend_name_stack(name)(fun)
|
||||
|
||||
@functools.wraps(fun)
|
||||
def named_call_f(*args, **kwargs):
|
||||
lu_f = lu.wrap_init(lambda: fun(*args, **kwargs))
|
||||
flat_f, out_tree = flatten_fun_nokwargs(lu_f, in_tree)
|
||||
out_flat = core.named_call_p.bind(flat_f, name=name)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
return named_call_f
|
||||
|
||||
@contextmanager
|
||||
def named_scope(
|
||||
name: str,
|
||||
|
@ -1001,11 +1001,6 @@ config.define_bool_state(
|
||||
update_thread_local_hook=lambda val: \
|
||||
update_thread_local_jit_state(dynamic_shapes=val))
|
||||
|
||||
config.define_bool_state(
|
||||
name='jax_experimental_name_stack',
|
||||
default=True,
|
||||
help='Enable using the context manager-based name stack.')
|
||||
|
||||
# This flag is temporary during rollout of the remat barrier.
|
||||
# TODO(parkers): Remove if there are no complaints.
|
||||
config.define_bool_state(
|
||||
|
@ -333,21 +333,16 @@ def wrap_name(name, transform_name):
|
||||
return transform_name + '(' + name + ')'
|
||||
|
||||
def new_name_stack(name: str = ''):
|
||||
if config.jax_experimental_name_stack:
|
||||
from jax._src import source_info_util
|
||||
name_stack = source_info_util.NameStack()
|
||||
if name:
|
||||
name_stack = name_stack.extend(name)
|
||||
return name_stack
|
||||
return name + '/'
|
||||
|
||||
def extend_name_stack(stack, name: str):
|
||||
if config.jax_experimental_name_stack:
|
||||
from jax._src import source_info_util
|
||||
assert isinstance(stack, source_info_util.NameStack), stack
|
||||
return stack.extend(name)
|
||||
assert isinstance(stack, str)
|
||||
return stack + name + '/'
|
||||
|
||||
def canonicalize_axis(axis, num_dims) -> int:
|
||||
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
||||
|
@ -2073,9 +2073,6 @@ call_p: CallPrimitive = CallPrimitive('call')
|
||||
call = call_p.bind
|
||||
call_p.def_impl(call_impl)
|
||||
|
||||
named_call_p: CallPrimitive = CallPrimitive('named_call')
|
||||
named_call_p.def_impl(call_impl)
|
||||
|
||||
|
||||
class ClosedCallPrimitive(CallPrimitive):
|
||||
def get_bind_params(self, params):
|
||||
|
@ -1651,16 +1651,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
# compilation to XLA, which does not use those parameters.
|
||||
bwd="illegal param",
|
||||
out_trees="illegal param")))
|
||||
elif eqn.primitive is core.named_call_p:
|
||||
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
|
||||
eqns.append(
|
||||
eqn.replace(
|
||||
invars=eqn.invars + [input_token_var, input_itoken_var],
|
||||
outvars=eqn.outvars + [output_token_var, output_itoken_var],
|
||||
params=dict(
|
||||
eqn.params,
|
||||
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
|
||||
)))
|
||||
elif eqn.primitive is pjit.pjit_p:
|
||||
jaxpr = cast(core.ClosedJaxpr, eqn.params["jaxpr"])
|
||||
eqns.append(
|
||||
|
@ -187,9 +187,7 @@ class _ThreadLocalState(threading.local):
|
||||
_thread_local_state = _ThreadLocalState()
|
||||
|
||||
def _get_current_name_stack() -> Union[NameStack, str]:
|
||||
if config.jax_experimental_name_stack:
|
||||
return source_info_util.current_name_stack()
|
||||
return _thread_local_state.name_stack
|
||||
|
||||
@contextlib.contextmanager
|
||||
def inside_call_tf():
|
||||
@ -530,24 +528,12 @@ def make_custom_gradient_fn_tf(
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _extended_name_stack(extra_name_stack: Optional[str]):
|
||||
if config.jax_experimental_name_stack:
|
||||
name_ctx = (source_info_util.extend_name_stack(extra_name_stack)
|
||||
if extra_name_stack
|
||||
else contextlib.nullcontext())
|
||||
with name_ctx:
|
||||
yield
|
||||
return
|
||||
prev_name_stack = _thread_local_state.name_stack
|
||||
if extra_name_stack:
|
||||
if not prev_name_stack:
|
||||
_thread_local_state.name_stack = extra_name_stack
|
||||
else:
|
||||
_thread_local_state.name_stack = util.extend_name_stack(
|
||||
_thread_local_state.name_stack, extra_name_stack)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_thread_local_state.name_stack = prev_name_stack
|
||||
|
||||
|
||||
def _interpret_fun_jax(
|
||||
@ -1050,20 +1036,14 @@ class TensorFlowTrace(core.Trace):
|
||||
return impl(*args_tf, **params)
|
||||
|
||||
current_name_stack = _get_current_name_stack()
|
||||
if config.jax_experimental_name_stack:
|
||||
# We don't use `str(name_stack)` because it uses parentheses for
|
||||
# transformations, which aren't allowed in `name_scope`.
|
||||
scope = '/'.join([s.name for s in current_name_stack.stack]) # type: ignore[union-attr]
|
||||
else:
|
||||
scope = str(current_name_stack)
|
||||
# We need to add a '/' to the name stack string to force `tf.name_scope`
|
||||
# to interpret it as an absolute scope, not a relative scope.
|
||||
scope = scope + '/'
|
||||
name_scope = (
|
||||
tf.name_scope(_sanitize_scope_name(scope)) if
|
||||
config.jax_experimental_name_stack else contextlib.nullcontext())
|
||||
|
||||
with name_scope:
|
||||
with tf.name_scope(_sanitize_scope_name(scope)):
|
||||
if _thread_local_state.include_xla_op_metadata:
|
||||
op_metadata = xla.make_op_metadata(primitive, params,
|
||||
name_stack=current_name_stack,
|
||||
@ -1109,17 +1089,11 @@ class TensorFlowTrace(core.Trace):
|
||||
avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
|
||||
interpreted_fun = _interpret_subtrace(fun, self.main, avals)
|
||||
extra_name_stack = None
|
||||
if call_primitive == core.named_call_p:
|
||||
extra_name_stack = util.wrap_name(params["name"], "named")
|
||||
elif call_primitive == xla.xla_call_p:
|
||||
if call_primitive == xla.xla_call_p:
|
||||
extra_name_stack = util.wrap_name(params["name"], "jit")
|
||||
with _extended_name_stack(extra_name_stack):
|
||||
with core.new_sublevel():
|
||||
if call_primitive == core.named_call_p:
|
||||
with tf.name_scope(_sanitize_scope_name(params["name"])):
|
||||
vals_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
|
||||
interpreted_fun.call_wrapped(*vals)
|
||||
elif call_primitive == xla.xla_call_p:
|
||||
if call_primitive == xla.xla_call_p:
|
||||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||||
# Make a nested tf.function(jit_compile=True)
|
||||
store_tf_res_avals: Sequence[core.ShapedArray] = []
|
||||
@ -1208,7 +1182,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
|
||||
|
||||
|
||||
# Call primitives are inlined
|
||||
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p, maps.xmap_p]:
|
||||
for unexpected in [core.call_p, xla.xla_call_p, maps.xmap_p]:
|
||||
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
|
||||
|
||||
# Primitives that are not yet implemented must be explicitly declared here.
|
||||
@ -2615,7 +2589,6 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
|
||||
for jaxpr in branches
|
||||
for i, jaxpr in enumerate(branches)
|
||||
]
|
||||
if config.jax_experimental_name_stack:
|
||||
# Same name stack as XLA translation of cond_p
|
||||
# Note: extend_name_stack is a contextmanager, which is callable as a decorator.
|
||||
branches_tf = list(map(source_info_util.extend_name_stack("cond"), # type: ignore[arg-type]
|
||||
|
@ -771,12 +771,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_tf_hlo = self.TfToHlo(f_tf, arg)
|
||||
if jax.config.jax_remat_opt_barrier:
|
||||
self.assertRegex(f_tf_hlo, r"opt-barrier")
|
||||
elif config.jax_experimental_name_stack:
|
||||
self.assertRegex(f_tf_hlo,
|
||||
r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin')
|
||||
else:
|
||||
self.assertRegex(f_tf_hlo,
|
||||
r'switch_case/indexed_case/Sin')
|
||||
r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin')
|
||||
|
||||
def test_remat_free_var(self):
|
||||
def f(x):
|
||||
|
@ -319,8 +319,6 @@ class JVPTrace(Trace):
|
||||
which_nz = [ type(t) is not Zero for t in tangents]
|
||||
tangents = [t if type(t) is not Zero else None for t in tangents]
|
||||
args, in_tree = tree_flatten((primals, tangents))
|
||||
if 'name' in params and not config.jax_experimental_name_stack:
|
||||
params = dict(params, name=wrap_name(params['name'], 'jvp'))
|
||||
f_jvp = jvp_subtrace(f, self.main)
|
||||
f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp)
|
||||
if isinstance(call_primitive, core.MapPrimitive):
|
||||
@ -609,8 +607,6 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
|
||||
reduce_axes, False)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
if 'name' in params and not config.jax_experimental_name_stack:
|
||||
params = dict(params, name=wrap_name(params['name'], 'transpose'))
|
||||
update_params = call_transpose_param_updaters.get(primitive)
|
||||
if update_params:
|
||||
params = update_params(params, map(is_undefined_primal, args),
|
||||
@ -627,8 +623,6 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
|
||||
out_flat = primitive.bind(fun, *all_args, **params)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
|
||||
primitive_transposes[core.named_call_p] = \
|
||||
partial(call_transpose, core.named_call_p)
|
||||
|
||||
|
||||
def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes):
|
||||
|
@ -31,10 +31,9 @@ from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, Zero)
|
||||
from jax import linear_util as lu
|
||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name,
|
||||
split_list, canonicalize_axis, moveaxis,
|
||||
as_hashable_function, curry, memoize,
|
||||
weakref_lru_cache)
|
||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
|
||||
canonicalize_axis, moveaxis, as_hashable_function,
|
||||
curry, memoize, weakref_lru_cache)
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
Array = Any
|
||||
@ -352,10 +351,7 @@ class BatchTrace(Trace):
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
assert call_primitive.multiple_results
|
||||
if config.jax_experimental_name_stack:
|
||||
params = dict(params, name=params.get('name', f.__name__))
|
||||
else:
|
||||
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
|
||||
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
if all(bdim is not_mapped for bdim in dims):
|
||||
return call_primitive.bind(f, *vals, **params)
|
||||
|
@ -309,13 +309,9 @@ register_constant_handler(
|
||||
def _source_info_to_location(
|
||||
primitive: core.Primitive, params: Dict,
|
||||
source_info: source_info_util.SourceInfo,
|
||||
name_stack: Union[str, source_info_util.NameStack] = "") -> ir.Location:
|
||||
if config.jax_experimental_name_stack:
|
||||
name_stack: source_info_util.NameStack) -> ir.Location:
|
||||
eqn_str = (f'{str(source_info.name_stack)}/'
|
||||
f'{core.str_eqn_compact(primitive.name, params)}')
|
||||
else:
|
||||
assert isinstance(name_stack, str)
|
||||
eqn_str = name_stack + core.str_eqn_compact(primitive.name, params)
|
||||
frame = source_info_util.user_frame(source_info)
|
||||
if frame is None:
|
||||
loc = ir.Location.unknown()
|
||||
@ -328,8 +324,6 @@ def _source_info_to_location(
|
||||
|
||||
|
||||
# Translation rules
|
||||
NameStack = Union[str, source_info_util.NameStack]
|
||||
|
||||
def make_ir_context() -> ir.Context:
|
||||
"""Creates an MLIR context suitable for JAX IR."""
|
||||
context = ir.Context()
|
||||
@ -414,7 +408,7 @@ class ModuleContext:
|
||||
backend_or_name: Optional[Union[str, xb.XlaBackend]]
|
||||
platform: str
|
||||
axis_context: AxisContext
|
||||
name_stack: NameStack
|
||||
name_stack: source_info_util.NameStack
|
||||
keepalives: List[Any]
|
||||
channel_iterator: Iterator[int]
|
||||
host_callbacks: List[Any]
|
||||
@ -432,7 +426,7 @@ class ModuleContext:
|
||||
backend_or_name: Optional[Union[str, xb.XlaBackend]],
|
||||
platform: str,
|
||||
axis_context: AxisContext,
|
||||
name_stack: NameStack,
|
||||
name_stack: source_info_util.NameStack,
|
||||
keepalives: List[Any],
|
||||
channel_iterator: Iterator[int],
|
||||
host_callbacks: List[Any],
|
||||
@ -583,7 +577,7 @@ def lower_jaxpr_to_module(
|
||||
backend_or_name: Optional[Union[str, xb.XlaBackend]],
|
||||
platform: str,
|
||||
axis_context: AxisContext,
|
||||
name_stack: NameStack,
|
||||
name_stack: source_info_util.NameStack,
|
||||
donated_args: Sequence[bool],
|
||||
replicated_args: Optional[Sequence[bool]] = None,
|
||||
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
|
||||
@ -1007,14 +1001,11 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
map(write, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_nodes = map(read, eqn.invars)
|
||||
if config.jax_experimental_name_stack:
|
||||
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
|
||||
source_info = eqn.source_info.replace(
|
||||
name_stack=ctx.name_stack + eqn.source_info.name_stack)
|
||||
else:
|
||||
source_info = eqn.source_info
|
||||
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
|
||||
name_stack=ctx.name_stack)
|
||||
ctx.name_stack)
|
||||
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
||||
if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
|
||||
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
|
||||
@ -1029,8 +1020,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
|
||||
f"found for platform {ctx.platform}")
|
||||
|
||||
eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if
|
||||
config.jax_experimental_name_stack else ctx)
|
||||
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
|
||||
effects = [eff for eff in eqn.effects if eff in core.ordered_effects]
|
||||
tokens_in = tokens.subset(effects)
|
||||
avals_in = map(aval, eqn.invars)
|
||||
@ -1160,21 +1150,16 @@ def _xla_call_lower(ctx, *args,
|
||||
|
||||
register_lowering(xla.xla_call_p, _xla_call_lower)
|
||||
|
||||
def _named_call_lowering(ctx, *args, name, backend=None,
|
||||
call_jaxpr):
|
||||
def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
|
||||
out_nodes, tokens = _call_lowering(
|
||||
name, name, call_jaxpr, backend, ctx.module_context,
|
||||
ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args)
|
||||
ctx.set_tokens_out(tokens)
|
||||
return out_nodes
|
||||
|
||||
register_lowering(core.named_call_p, _named_call_lowering)
|
||||
register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
|
||||
register_lowering(core.call_p, partial(_core_call_lowering, name="core_call"))
|
||||
register_lowering(core.closed_call_p,
|
||||
partial(_named_call_lowering, name="core_closed_call"))
|
||||
|
||||
register_lowering(core.closed_call_p,
|
||||
partial(_named_call_lowering, name="core_closed_call"))
|
||||
partial(_core_call_lowering, name="core_closed_call"))
|
||||
|
||||
|
||||
def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
|
||||
|
@ -1291,9 +1291,6 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
|
||||
lambda _, __, ___, ____, _____, x, y: (x, y))
|
||||
partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
|
||||
partial(closed_call_partial_eval_custom_rule, 'call_jaxpr')
|
||||
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
|
||||
partial(call_partial_eval_custom_rule, 'call_jaxpr',
|
||||
lambda _, __, ___, ____, _____, x, y: (x, y))
|
||||
|
||||
|
||||
def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
|
||||
@ -1394,7 +1391,6 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
|
||||
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
|
||||
return used_inputs, new_eqn
|
||||
dce_rules[core.call_p] = dce_jaxpr_call_rule
|
||||
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
|
||||
|
||||
|
||||
def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn
|
||||
|
@ -104,11 +104,8 @@ def make_op_metadata(primitive: core.Primitive,
|
||||
source_info: source_info_util.SourceInfo,
|
||||
name_stack: Union[str, source_info_util.NameStack] = "",
|
||||
) -> xc.OpMetadata:
|
||||
if config.jax_experimental_name_stack:
|
||||
eqn_str = str(source_info.name_stack) + '/' + str_eqn_compact(primitive.name, params)
|
||||
else:
|
||||
assert isinstance(name_stack, str)
|
||||
eqn_str = name_stack + str_eqn_compact(primitive.name, params)
|
||||
eqn_str = (str(source_info.name_stack) + '/'
|
||||
+ str_eqn_compact(primitive.name, params))
|
||||
tracebacks[eqn_str] = source_info.traceback
|
||||
frame = source_info_util.user_frame(source_info)
|
||||
return xc.OpMetadata(
|
||||
|
@ -9286,12 +9286,9 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
return my_test_function(x)
|
||||
|
||||
c = jax.xla_computation(f)(2)
|
||||
if config.jax_experimental_name_stack:
|
||||
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
|
||||
print_opts.print_metadata = True
|
||||
hlo_text = c.as_hlo_module().to_string(print_opts)
|
||||
else:
|
||||
hlo_text = c.as_hlo_text()
|
||||
self.assertIn("my_test_function", hlo_text)
|
||||
|
||||
def test_non_jaxtype_arg(self):
|
||||
|
@ -18,7 +18,6 @@ from absl.testing import absltest
|
||||
import jax.numpy as jnp
|
||||
from jax.tools import jax_to_ir
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
@ -90,14 +89,9 @@ class JaxToIRTest(absltest.TestCase):
|
||||
])
|
||||
|
||||
# Check that tf debug txt contains a broadcast, add, and multiply.
|
||||
if config.jax_experimental_name_stack:
|
||||
self.assertIn('BroadcastTo', tf_text)
|
||||
self.assertIn('AddV2', tf_text)
|
||||
self.assertIn('Mul', tf_text)
|
||||
else:
|
||||
self.assertIn('name: "BroadcastTo"', tf_text)
|
||||
self.assertIn('name: "AddV2"', tf_text)
|
||||
self.assertIn('name: "Mul"', tf_text)
|
||||
|
||||
# Check that we can re-import our graphdef.
|
||||
gdef = tf.compat.v1.GraphDef()
|
||||
|
@ -63,14 +63,9 @@ class MetadataTest(jtu.JaxTestCase):
|
||||
def foo(x):
|
||||
return jnp.sin(x)
|
||||
hlo = jax.xla_computation(jax.grad(foo))(1.).get_hlo_module().to_string()
|
||||
if config.jax_experimental_name_stack:
|
||||
self.assertRegex(hlo, 'op_name=".*jvp\\(jit\\(foo\\)\\)/sin"')
|
||||
self.assertRegex(hlo, 'op_name=".*jvp\\(jit\\(foo\\)\\)/cos"')
|
||||
self.assertRegex(hlo, 'op_name=".*transpose\\(jvp\\(jit\\(foo\\)\\)\\)/mul"')
|
||||
else:
|
||||
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
|
||||
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
|
||||
self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\(jvp\\(foo\\)\\)\\)/mul"')
|
||||
|
||||
def test_cond_metadata(self):
|
||||
def true_fun(x):
|
||||
|
@ -33,17 +33,8 @@ def _get_hlo(f):
|
||||
return c.as_hlo_module().to_string(print_opts)
|
||||
return wrapped
|
||||
|
||||
class _EnableNameStackTestCase(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = config._read("jax_experimental_name_stack")
|
||||
config.update("jax_experimental_name_stack", True)
|
||||
|
||||
def tearDown(self):
|
||||
config.update("jax_experimental_name_stack", self.cfg)
|
||||
|
||||
|
||||
class NameStackTest(_EnableNameStackTestCase):
|
||||
class NameStackTest(jtu.JaxTestCase):
|
||||
|
||||
def test_trivial_name_stack(self):
|
||||
|
||||
@ -135,7 +126,7 @@ class NameStackTest(_EnableNameStackTestCase):
|
||||
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
|
||||
|
||||
|
||||
class NameStackTransformationTest(_EnableNameStackTestCase):
|
||||
class NameStackTransformationTest(jtu.JaxTestCase):
|
||||
|
||||
def test_vmap_should_transform_name_stack(self):
|
||||
@jax.vmap
|
||||
@ -238,7 +229,7 @@ class NameStackTransformationTest(_EnableNameStackTestCase):
|
||||
self.assertIn('transpose(jvp(foo))/jit(f)/bar/mul', hlo_text)
|
||||
|
||||
|
||||
class NameStackControlFlowTest(_EnableNameStackTestCase):
|
||||
class NameStackControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
def test_while_loop_body_should_not_have_name_stack(self):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user