Delete jax_experimental_name_stack flag

PiperOrigin-RevId: 487601864
This commit is contained in:
Sharad Vikram 2022-11-10 11:59:16 -08:00 committed by jax authors
parent 0ebb6b4215
commit 74b136e62c
16 changed files with 59 additions and 179 deletions

View File

@@ -3290,20 +3290,8 @@ def named_call(
if name is None:
name = fun.__name__
_, in_tree = tree_flatten(())
if config.jax_experimental_name_stack:
return source_info_util.extend_name_stack(name)(fun)
@functools.wraps(fun)
def named_call_f(*args, **kwargs):
lu_f = lu.wrap_init(lambda: fun(*args, **kwargs))
flat_f, out_tree = flatten_fun_nokwargs(lu_f, in_tree)
out_flat = core.named_call_p.bind(flat_f, name=name)
return tree_unflatten(out_tree(), out_flat)
return named_call_f
@contextmanager
def named_scope(
name: str,

View File

@@ -1001,11 +1001,6 @@ config.define_bool_state(
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(dynamic_shapes=val))
config.define_bool_state(
name='jax_experimental_name_stack',
default=True,
help='Enable using the context manager-based name stack.')
# This flag is temporary during rollout of the remat barrier.
# TODO(parkers): Remove if there are no complaints.
config.define_bool_state(

View File

@@ -333,21 +333,16 @@ def wrap_name(name, transform_name):
return transform_name + '(' + name + ')'
def new_name_stack(name: str = ''):
if config.jax_experimental_name_stack:
from jax._src import source_info_util
name_stack = source_info_util.NameStack()
if name:
name_stack = name_stack.extend(name)
return name_stack
return name + '/'
def extend_name_stack(stack, name: str):
if config.jax_experimental_name_stack:
from jax._src import source_info_util
assert isinstance(stack, source_info_util.NameStack), stack
return stack.extend(name)
assert isinstance(stack, str)
return stack + name + '/'
def canonicalize_axis(axis, num_dims) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""

View File

@@ -2073,9 +2073,6 @@ call_p: CallPrimitive = CallPrimitive('call')
call = call_p.bind
call_p.def_impl(call_impl)
named_call_p: CallPrimitive = CallPrimitive('named_call')
named_call_p.def_impl(call_impl)
class ClosedCallPrimitive(CallPrimitive):
def get_bind_params(self, params):

View File

@@ -1651,16 +1651,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
# compilation to XLA, which does not use those parameters.
bwd="illegal param",
out_trees="illegal param")))
elif eqn.primitive is core.named_call_p:
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
)))
elif eqn.primitive is pjit.pjit_p:
jaxpr = cast(core.ClosedJaxpr, eqn.params["jaxpr"])
eqns.append(

View File

@@ -187,9 +187,7 @@ class _ThreadLocalState(threading.local):
_thread_local_state = _ThreadLocalState()
def _get_current_name_stack() -> Union[NameStack, str]:
if config.jax_experimental_name_stack:
return source_info_util.current_name_stack()
return _thread_local_state.name_stack
@contextlib.contextmanager
def inside_call_tf():
@@ -530,24 +528,12 @@ def make_custom_gradient_fn_tf(
@contextlib.contextmanager
def _extended_name_stack(extra_name_stack: Optional[str]):
if config.jax_experimental_name_stack:
name_ctx = (source_info_util.extend_name_stack(extra_name_stack)
if extra_name_stack
else contextlib.nullcontext())
with name_ctx:
yield
return
prev_name_stack = _thread_local_state.name_stack
if extra_name_stack:
if not prev_name_stack:
_thread_local_state.name_stack = extra_name_stack
else:
_thread_local_state.name_stack = util.extend_name_stack(
_thread_local_state.name_stack, extra_name_stack)
try:
yield
finally:
_thread_local_state.name_stack = prev_name_stack
def _interpret_fun_jax(
@@ -1050,20 +1036,14 @@ class TensorFlowTrace(core.Trace):
return impl(*args_tf, **params)
current_name_stack = _get_current_name_stack()
if config.jax_experimental_name_stack:
# We don't use `str(name_stack)` because it uses parentheses for
# transformations, which aren't allowed in `name_scope`.
scope = '/'.join([s.name for s in current_name_stack.stack]) # type: ignore[union-attr]
else:
scope = str(current_name_stack)
# We need to add a '/' to the name stack string to force `tf.name_scope`
# to interpret it as an absolute scope, not a relative scope.
scope = scope + '/'
name_scope = (
tf.name_scope(_sanitize_scope_name(scope)) if
config.jax_experimental_name_stack else contextlib.nullcontext())
with name_scope:
with tf.name_scope(_sanitize_scope_name(scope)):
if _thread_local_state.include_xla_op_metadata:
op_metadata = xla.make_op_metadata(primitive, params,
name_stack=current_name_stack,
@@ -1109,17 +1089,11 @@ class TensorFlowTrace(core.Trace):
avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
interpreted_fun = _interpret_subtrace(fun, self.main, avals)
extra_name_stack = None
if call_primitive == core.named_call_p:
extra_name_stack = util.wrap_name(params["name"], "named")
elif call_primitive == xla.xla_call_p:
if call_primitive == xla.xla_call_p:
extra_name_stack = util.wrap_name(params["name"], "jit")
with _extended_name_stack(extra_name_stack):
with core.new_sublevel():
if call_primitive == core.named_call_p:
with tf.name_scope(_sanitize_scope_name(params["name"])):
vals_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
interpreted_fun.call_wrapped(*vals)
elif call_primitive == xla.xla_call_p:
if call_primitive == xla.xla_call_p:
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
# Make a nested tf.function(jit_compile=True)
store_tf_res_avals: Sequence[core.ShapedArray] = []
@@ -1208,7 +1182,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
# Call primitives are inlined
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p, maps.xmap_p]:
for unexpected in [core.call_p, xla.xla_call_p, maps.xmap_p]:
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.
@@ -2615,7 +2589,6 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
for jaxpr in branches
for i, jaxpr in enumerate(branches)
]
if config.jax_experimental_name_stack:
# Same name stack as XLA translation of cond_p
# Note: extend_name_stack is a contextmanager, which is callable as a decorator.
branches_tf = list(map(source_info_util.extend_name_stack("cond"), # type: ignore[arg-type]

View File

@@ -771,12 +771,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf_hlo = self.TfToHlo(f_tf, arg)
if jax.config.jax_remat_opt_barrier:
self.assertRegex(f_tf_hlo, r"opt-barrier")
elif config.jax_experimental_name_stack:
self.assertRegex(f_tf_hlo,
r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin')
else:
self.assertRegex(f_tf_hlo,
r'switch_case/indexed_case/Sin')
r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin')
def test_remat_free_var(self):
def f(x):

View File

@@ -319,8 +319,6 @@ class JVPTrace(Trace):
which_nz = [ type(t) is not Zero for t in tangents]
tangents = [t if type(t) is not Zero else None for t in tangents]
args, in_tree = tree_flatten((primals, tangents))
if 'name' in params and not config.jax_experimental_name_stack:
params = dict(params, name=wrap_name(params['name'], 'jvp'))
f_jvp = jvp_subtrace(f, self.main)
f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp)
if isinstance(call_primitive, core.MapPrimitive):
@@ -609,8 +607,6 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
reduce_axes, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
if 'name' in params and not config.jax_experimental_name_stack:
params = dict(params, name=wrap_name(params['name'], 'transpose'))
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
params = update_params(params, map(is_undefined_primal, args),
@@ -627,8 +623,6 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[core.named_call_p] = \
partial(call_transpose, core.named_call_p)
def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes):

View File

@@ -31,10 +31,9 @@ from jax._src.tree_util import (tree_unflatten, tree_flatten,
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, Zero)
from jax import linear_util as lu
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name,
split_list, canonicalize_axis, moveaxis,
as_hashable_function, curry, memoize,
weakref_lru_cache)
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache)
from jax.interpreters import partial_eval as pe
Array = Any
@@ -352,10 +351,7 @@ class BatchTrace(Trace):
def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
if config.jax_experimental_name_stack:
params = dict(params, name=params.get('name', f.__name__))
else:
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)

View File

@@ -309,13 +309,9 @@ register_constant_handler(
def _source_info_to_location(
primitive: core.Primitive, params: Dict,
source_info: source_info_util.SourceInfo,
name_stack: Union[str, source_info_util.NameStack] = "") -> ir.Location:
if config.jax_experimental_name_stack:
name_stack: source_info_util.NameStack) -> ir.Location:
eqn_str = (f'{str(source_info.name_stack)}/'
f'{core.str_eqn_compact(primitive.name, params)}')
else:
assert isinstance(name_stack, str)
eqn_str = name_stack + core.str_eqn_compact(primitive.name, params)
frame = source_info_util.user_frame(source_info)
if frame is None:
loc = ir.Location.unknown()
@@ -328,8 +324,6 @@ def _source_info_to_location(
# Translation rules
NameStack = Union[str, source_info_util.NameStack]
def make_ir_context() -> ir.Context:
"""Creates an MLIR context suitable for JAX IR."""
context = ir.Context()
@@ -414,7 +408,7 @@ class ModuleContext:
backend_or_name: Optional[Union[str, xb.XlaBackend]]
platform: str
axis_context: AxisContext
name_stack: NameStack
name_stack: source_info_util.NameStack
keepalives: List[Any]
channel_iterator: Iterator[int]
host_callbacks: List[Any]
@@ -432,7 +426,7 @@ class ModuleContext:
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
axis_context: AxisContext,
name_stack: NameStack,
name_stack: source_info_util.NameStack,
keepalives: List[Any],
channel_iterator: Iterator[int],
host_callbacks: List[Any],
@@ -583,7 +577,7 @@ def lower_jaxpr_to_module(
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
axis_context: AxisContext,
name_stack: NameStack,
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
@@ -1007,14 +1001,11 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
if config.jax_experimental_name_stack:
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
source_info = eqn.source_info.replace(
name_stack=ctx.name_stack + eqn.source_info.name_stack)
else:
source_info = eqn.source_info
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
name_stack=ctx.name_stack)
ctx.name_stack)
with source_info_util.user_context(eqn.source_info.traceback), loc:
if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
@@ -1029,8 +1020,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {ctx.platform}")
eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if
config.jax_experimental_name_stack else ctx)
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
effects = [eff for eff in eqn.effects if eff in core.ordered_effects]
tokens_in = tokens.subset(effects)
avals_in = map(aval, eqn.invars)
@@ -1160,21 +1150,16 @@ def _xla_call_lower(ctx, *args,
register_lowering(xla.xla_call_p, _xla_call_lower)
def _named_call_lowering(ctx, *args, name, backend=None,
call_jaxpr):
def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
out_nodes, tokens = _call_lowering(
name, name, call_jaxpr, backend, ctx.module_context,
ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args)
ctx.set_tokens_out(tokens)
return out_nodes
register_lowering(core.named_call_p, _named_call_lowering)
register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
register_lowering(core.call_p, partial(_core_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
partial(_named_call_lowering, name="core_closed_call"))
register_lowering(core.closed_call_p,
partial(_named_call_lowering, name="core_closed_call"))
partial(_core_call_lowering, name="core_closed_call"))
def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:

View File

@@ -1291,9 +1291,6 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
lambda _, __, ___, ____, _____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
partial(closed_call_partial_eval_custom_rule, 'call_jaxpr')
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, ___, ____, _____, x, y: (x, y))
def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
@@ -1394,7 +1391,6 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn

View File

@@ -104,11 +104,8 @@ def make_op_metadata(primitive: core.Primitive,
source_info: source_info_util.SourceInfo,
name_stack: Union[str, source_info_util.NameStack] = "",
) -> xc.OpMetadata:
if config.jax_experimental_name_stack:
eqn_str = str(source_info.name_stack) + '/' + str_eqn_compact(primitive.name, params)
else:
assert isinstance(name_stack, str)
eqn_str = name_stack + str_eqn_compact(primitive.name, params)
eqn_str = (str(source_info.name_stack) + '/'
+ str_eqn_compact(primitive.name, params))
tracebacks[eqn_str] = source_info.traceback
frame = source_info_util.user_frame(source_info)
return xc.OpMetadata(

View File

@@ -9286,12 +9286,9 @@ class NamedCallTest(jtu.JaxTestCase):
return my_test_function(x)
c = jax.xla_computation(f)(2)
if config.jax_experimental_name_stack:
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
print_opts.print_metadata = True
hlo_text = c.as_hlo_module().to_string(print_opts)
else:
hlo_text = c.as_hlo_text()
self.assertIn("my_test_function", hlo_text)
def test_non_jaxtype_arg(self):

View File

@@ -18,7 +18,6 @@ from absl.testing import absltest
import jax.numpy as jnp
from jax.tools import jax_to_ir
from jax._src import test_util as jtu
from jax.config import config
try:
import tensorflow as tf
@@ -90,14 +89,9 @@ class JaxToIRTest(absltest.TestCase):
])
# Check that tf debug txt contains a broadcast, add, and multiply.
if config.jax_experimental_name_stack:
self.assertIn('BroadcastTo', tf_text)
self.assertIn('AddV2', tf_text)
self.assertIn('Mul', tf_text)
else:
self.assertIn('name: "BroadcastTo"', tf_text)
self.assertIn('name: "AddV2"', tf_text)
self.assertIn('name: "Mul"', tf_text)
# Check that we can re-import our graphdef.
gdef = tf.compat.v1.GraphDef()

View File

@@ -63,14 +63,9 @@ class MetadataTest(jtu.JaxTestCase):
def foo(x):
return jnp.sin(x)
hlo = jax.xla_computation(jax.grad(foo))(1.).get_hlo_module().to_string()
if config.jax_experimental_name_stack:
self.assertRegex(hlo, 'op_name=".*jvp\\(jit\\(foo\\)\\)/sin"')
self.assertRegex(hlo, 'op_name=".*jvp\\(jit\\(foo\\)\\)/cos"')
self.assertRegex(hlo, 'op_name=".*transpose\\(jvp\\(jit\\(foo\\)\\)\\)/mul"')
else:
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\(jvp\\(foo\\)\\)\\)/mul"')
def test_cond_metadata(self):
def true_fun(x):

View File

@@ -33,17 +33,8 @@ def _get_hlo(f):
return c.as_hlo_module().to_string(print_opts)
return wrapped
class _EnableNameStackTestCase(jtu.JaxTestCase):
def setUp(self):
self.cfg = config._read("jax_experimental_name_stack")
config.update("jax_experimental_name_stack", True)
def tearDown(self):
config.update("jax_experimental_name_stack", self.cfg)
class NameStackTest(_EnableNameStackTestCase):
class NameStackTest(jtu.JaxTestCase):
def test_trivial_name_stack(self):
@@ -135,7 +126,7 @@ class NameStackTest(_EnableNameStackTestCase):
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
class NameStackTransformationTest(_EnableNameStackTestCase):
class NameStackTransformationTest(jtu.JaxTestCase):
def test_vmap_should_transform_name_stack(self):
@jax.vmap
@@ -238,7 +229,7 @@ class NameStackTransformationTest(_EnableNameStackTestCase):
self.assertIn('transpose(jvp(foo))/jit(f)/bar/mul', hlo_text)
class NameStackControlFlowTest(_EnableNameStackTestCase):
class NameStackControlFlowTest(jtu.JaxTestCase):
def test_while_loop_body_should_not_have_name_stack(self):