Add separate mechanism for threading name stacks to the lowering

This commit is contained in:
Sharad Vikram 2021-10-28 11:06:58 -07:00
parent e96b91d405
commit 1b79caa6bd
24 changed files with 937 additions and 142 deletions

View File

@ -353,7 +353,7 @@ def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, consts, dummy_args,
in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts, dummy_args,
out_cts)
in_cts, cell.treedef = tree_flatten(in_cts_)
return in_cts

View File

@ -56,10 +56,11 @@ from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
Partial, PyTreeDef, all_leaves)
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, wrap_name, cache, wraps,
extend_name_stack, new_name_stack, wrap_name, cache, wraps,
HashableFunction)
from jax._src import device_array
from jax._src import dispatch
from jax._src import source_info_util
from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
@ -895,9 +896,8 @@ def xla_computation(fun: Callable,
should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
xla_args, donated_invars = xla._xla_callable_args(
c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
ctx = xla.TranslationContext(
c, backend, axis_env_,
extend_name_stack(wrap_name(fun_name, "xla_computation")))
name_stack = new_name_stack(wrap_name(fun_name, "xla_computation"))
ctx = xla.TranslationContext(c, backend, axis_env_, name_stack)
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
if out_parts is not None:
@ -2615,7 +2615,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
in_cotangents = map(
ad.instantiate_zeros,
ad.backward_pass(jaxpr, reduce_axes, consts, dummies, out_cotangents))
ad.backward_pass(jaxpr, reduce_axes, True, consts, dummies, out_cotangents))
return tree_unflatten(in_tree, in_cotangents)
# Ensure that transposed_fun is a PyTree
@ -3197,6 +3197,9 @@ def named_call(
_, in_tree = tree_flatten(())
if config.jax_experimental_name_stack:
return source_info_util.extend_name_stack(name)(fun)
@functools.wraps(fun)
def named_call_f(*args, **kwargs):
lu_f = lu.wrap_init(lambda: fun(*args, **kwargs))

View File

@ -140,7 +140,6 @@ class Config:
for name, val in self.values.items():
flag_type, meta_args, meta_kwargs = self.meta[name]
absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)
app.call_after_init(lambda: self.complete_absl_config(absl_flags))
def complete_absl_config(self, absl_flags):
@ -688,6 +687,11 @@ config.define_bool_state(
help=('Enables experimental features for staging out computations with '
'dynamic shapes.'))
config.define_bool_state(
name='jax_experimental_name_stack',
default=False,
help='Enable using the context manager-based name stack.')
# This flag is temporary during rollout of the remat barrier.
# TODO(parkers): Remove if there are no complaints.
config.define_bool_state(

View File

@ -407,7 +407,7 @@ def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
jvp_jaxpr_thunk, num_consts):
del jvp_jaxpr_thunk, num_consts
return ad.backward_pass(
fun_jaxpr.jaxpr, reduce_axes, fun_jaxpr.consts, args, cts)
fun_jaxpr.jaxpr, reduce_axes, False, fun_jaxpr.consts, args, cts)
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
def custom_jvp_jaxpr_custom_partial_eval_rule(

View File

@ -250,7 +250,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
# pass long arg lists as tuple for TPU
tuple_args = len(abstract_args) > 100
axis_env = xla.AxisEnv(nreps, (), ())
name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit'))
name_stack = xla.new_name_stack(xla.wrap_name(name, 'jit'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module: Union[str, xc.XlaComputation]
module_name = f"jit_{fun.__name__}"

View File

@ -347,8 +347,9 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
cond_carry = xla.parameter(cond_c, 0, c.get_shape(init_carry))
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
name_stack = extend_name_stack(ctx.name_stack, 'while')
cond_ctx = ctx.replace(builder=cond_c,
name_stack=extend_name_stack(ctx.name_stack, 'cond'))
name_stack=extend_name_stack(name_stack, 'cond'))
pred, = xla.jaxpr_subcomp(
cond_ctx, cond_jaxpr.jaxpr,
_map(partial(xla.pyval_to_ir_constant, cond_c), cond_jaxpr.consts),
@ -365,14 +366,14 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
body_ctx = ctx.replace(builder=body_c,
name_stack=extend_name_stack(ctx.name_stack, 'body'))
name_stack=extend_name_stack(name_stack, 'body'))
new_z = xla.jaxpr_subcomp(
body_ctx, body_jaxpr.jaxpr,
_map(partial(xla.pyval_to_ir_constant, body_c), body_jaxpr.consts),
*(y + z))
if batched:
body_pred_ctx = body_ctx.replace(
name_stack=extend_name_stack(ctx.name_stack, 'body_pred'))
name_stack=extend_name_stack(name_stack, 'body_pred'))
body_pred, = xla.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr,
_map(partial(xla.pyval_to_ir_constant, body_c), cond_jaxpr.consts),
@ -1201,9 +1202,11 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
linear_2 = (False,) * num_res + linear
params = dict(branches=branches_2, linear=linear_2)
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe(
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
source_info_util.current())
source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -1297,7 +1300,7 @@ def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
res, cts_out = split_list(args, [num_res])
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
cts_in = ad.backward_pass(
jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals, cts_out)
jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
@ -1924,9 +1927,10 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])]
other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]]
in_pvals_1 = invariant_pvals + other_pvals
jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
with source_info_util.reset_name_stack():
jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ())
num_consts_1 = num_consts + len(consts_1)
# any now-known residuals are intensive, so we want to revise jaxpr_2 to take
@ -1990,6 +1994,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
for pv, const in zip(out_pvs, out_consts)]
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
linear_2 = ([False] * len(int_res_tracers) +
[lin or not uk for uk, lin in zip(unknowns, linear)] +
[False] * len(ext_res_tracers))
@ -1999,7 +2005,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
num_consts=num_consts_2,
num_carry=num_carry, linear=tuple(linear_2),
unroll=unroll),
source_info_util.current())
source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -2068,7 +2074,7 @@ def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
res1_cbar_bbar_res2, [num_res1, num_c, num_b])
primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
[ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts,
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
primals, b_bar)
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)

View File

@ -50,7 +50,7 @@ from jax.interpreters import masking
import jax._src.pretty_printer as pp
from jax._src import util
from jax._src.util import (cache, safe_zip, prod, safe_map, canonicalize_axis,
split_list)
split_list, new_name_stack)
from jax.tree_util import tree_map
import jax._src.lib
from jax._src.lib import pytree
@ -3424,7 +3424,7 @@ def _reduction_computation(ctx, jaxpr, consts, init_values, singleton=True):
subc = xc.XlaBuilder("reduction_computation")
assert len(consts) == 0, "Reduction computations cannot have constants"
args = [xla.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
ctx = xla.TranslationContext(subc, platform, axis_env, '')
ctx = xla.TranslationContext(subc, platform, axis_env, new_name_stack())
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args)
if singleton:
return subc.build(out_nodes[0])

View File

@ -13,12 +13,13 @@
# limitations under the License.
import contextlib
import dataclasses
import functools
import itertools
import os.path
import threading
import types
from typing import Optional, Iterator, NamedTuple
from typing import Optional, Iterator, NamedTuple, Union, Tuple
import jax.version
from jax._src.lib import xla_client, xla_extension_version
@ -40,15 +41,66 @@ _exclude_paths = [os.path.dirname(jax.version.__file__)]
def register_exclusion(path):
_exclude_paths.append(path)
class Scope(NamedTuple):
name: str
def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
return (self.name, *stack)
class Transform(NamedTuple):
name: str
def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
return tuple(map(lambda x: f'{self.name}({x})', stack))
@dataclasses.dataclass(frozen=True)
class NameStack:
stack: Tuple[Union[Scope, Transform], ...] = ()
def extend(self, name: Union[Tuple[str, ...], str]) -> 'NameStack':
if not isinstance(name, tuple):
name = (name,)
scopes = tuple(map(Scope, name))
return NameStack(self.stack + scopes)
def wrap_name(self, name: str) -> str:
if not self.stack:
return name
return f'{str(self)}/{name}'
def transform(self, transform_name: str) -> 'NameStack':
return NameStack((*self.stack, Transform(transform_name)))
def __getitem__(self, idx) -> 'NameStack':
return NameStack(self.stack[idx])
def __len__(self):
return len(self.stack)
def __add__(self, other: 'NameStack') -> 'NameStack':
return NameStack(self.stack + other.stack)
def __radd__(self, other: 'NameStack') -> 'NameStack':
return NameStack(other.stack + self.stack)
def __str__(self) -> str:
scope: Tuple[str, ...] = ()
for elem in self.stack[::-1]:
scope = elem.wrap(scope)
return '/'.join(scope)
class SourceInfo(NamedTuple):
traceback: Optional[Traceback]
name_stack: NameStack
def replace(self, *, traceback: Optional[Traceback] = None) -> 'SourceInfo':
def replace(self, *, traceback: Optional[Traceback] = None,
name_stack: Optional[NameStack] = None) -> 'SourceInfo':
traceback = traceback or self.traceback
return self._replace(traceback=traceback)
name_stack = self.name_stack if name_stack is None else name_stack
return self._replace(traceback=traceback, name_stack=name_stack)
def new_source_info() -> SourceInfo:
return SourceInfo(None)
return SourceInfo(None, NameStack())
def is_user_filename(filename: str) -> bool:
"""Heuristic that guesses the identity of the user's code in a stack trace."""
@ -97,11 +149,10 @@ class _SourceInfoContext(threading.local):
_source_info_context = _SourceInfoContext()
def current() -> SourceInfo:
context = _source_info_context.context
if not context.traceback:
return context.replace(traceback=xla_client.Traceback.get_traceback())
return context
source_info = _source_info_context.context
if not source_info.traceback:
source_info = source_info.replace(traceback=xla_client.Traceback.get_traceback())
return source_info
class JaxStackTraceBeforeTransformation(Exception): pass
@ -118,9 +169,10 @@ def has_user_context(e):
return False
@contextlib.contextmanager
def user_context(c: Optional[Traceback]):
def user_context(c: Optional[Traceback], *, name_stack: Optional[NameStack] = None):
prev = _source_info_context.context
_source_info_context.context = _source_info_context.context.replace(traceback=c)
_source_info_context.context = _source_info_context.context.replace(
traceback=c, name_stack=name_stack)
filtered_tb = None
try:
yield
@ -141,3 +193,43 @@ def user_context(c: Optional[Traceback]):
finally:
_source_info_context.context = prev
del filtered_tb
def current_name_stack() -> NameStack:
return _source_info_context.context.name_stack
@contextlib.contextmanager
def extend_name_stack(name: str) -> Iterator[NameStack]:
prev_context = _source_info_context.context
curr_name_stack = prev_context.name_stack
new_context = prev_context.replace(name_stack=curr_name_stack.extend(name))
_source_info_context.context = new_context
try:
yield _source_info_context.context.name_stack
finally:
_source_info_context.context = prev_context
@contextlib.contextmanager
def set_name_stack(name_stack: NameStack) -> Iterator[None]:
prev_context = _source_info_context.context
new_context = prev_context.replace(name_stack=name_stack)
_source_info_context.context = new_context
try:
yield
finally:
_source_info_context.context = prev_context
@contextlib.contextmanager
def reset_name_stack() -> Iterator[None]:
with set_name_stack(NameStack()):
yield
@contextlib.contextmanager
def transform_name_stack(name: str) -> Iterator[NameStack]:
prev_context = _source_info_context.context
curr_name_stack = prev_context.name_stack
new_context = prev_context.replace(name_stack=curr_name_stack.transform(name))
_source_info_context.context = new_context
try:
yield _source_info_context.context.name_stack
finally:
_source_info_context.context = prev_context

View File

@ -277,7 +277,21 @@ def get_module_functions(module):
def wrap_name(name, transform_name):
return transform_name + '(' + name + ')'
def extend_name_stack(stack, name=''):
def new_name_stack(name: str = ''):
if config.jax_experimental_name_stack:
from jax._src import source_info_util
name_stack = source_info_util.NameStack()
if name:
name_stack = name_stack.extend(name)
return name_stack
return name + '/'
def extend_name_stack(stack, name: str):
if config.jax_experimental_name_stack:
from jax._src import source_info_util
assert isinstance(stack, source_info_util.NameStack), stack
return stack.extend(name)
assert isinstance(stack, str)
return stack + name + '/'
def canonicalize_axis(axis, num_dims) -> int:

View File

@ -81,10 +81,11 @@ class Jaxpr:
__repr__ = __str__
def pretty_print(self, *, source_info=False, print_shapes=True,
custom_pp_eqn_rules=True, **kw):
custom_pp_eqn_rules=True, name_stack=False, **kw):
doc = pp_jaxpr(self, JaxprPpContext(), source_info=source_info,
print_shapes=print_shapes,
custom_pp_eqn_rules=custom_pp_eqn_rules)
custom_pp_eqn_rules=custom_pp_eqn_rules,
name_stack=name_stack)
return doc.format(**kw)
def _repr_pretty_(self, p, cycle):
@ -141,9 +142,10 @@ class ClosedJaxpr:
def __str__(self): return str(self.jaxpr)
def __repr__(self): return repr(self.jaxpr)
def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
def pretty_print(self, *, source_info=False, print_shapes=True,
name_stack=False, **kw):
return pp_jaxpr(self.jaxpr, JaxprPpContext(), source_info=source_info,
print_shapes=print_shapes).format(**kw)
print_shapes=print_shapes, name_stack=name_stack).format(**kw)
def _repr_pretty_(self, p, cycle):
@ -333,7 +335,8 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
with source_info_util.user_context(eqn.source_info.traceback):
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
@ -2272,50 +2275,52 @@ def pp_vars(vs: Sequence[Any], context: JaxprPpContext,
[pp.text(pp_var(v, context)) for v in vs])
))
def pp_kv_pair(k:str, v: Any, context: JaxprPpContext) -> pp.Doc:
def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc:
if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
pp_v = pp_jaxprs(v, context)
pp_v = pp_jaxprs(v, context, name_stack=name_stack)
elif isinstance(v, Jaxpr):
pp_v = pp_jaxpr(v, context)
pp_v = pp_jaxpr(v, context, name_stack=name_stack)
elif isinstance(v, ClosedJaxpr):
pp_v = pp_jaxpr(v.jaxpr, context)
pp_v = pp_jaxpr(v.jaxpr, context, name_stack=name_stack)
else:
pp_v = pp.text(str(v))
return pp.text(f'{k}=') + pp_v
def pp_kv_pairs(kv_pairs, context: JaxprPpContext) -> pp.Doc:
def pp_kv_pairs(kv_pairs, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc:
if not kv_pairs:
return pp.nil()
return pp.group(
pp.nest(2, pp.concat([
pp.text("["), pp.brk(""),
pp.join(pp.brk(), [pp_kv_pair(k, v, context) for k, v in kv_pairs])
pp.join(pp.brk(), [pp_kv_pair(k, v, context, name_stack=name_stack) for k, v in kv_pairs])
]))
+ pp.brk("") + pp.text("]")
)
def pp_eqn(eqn, context: JaxprPpContext, *, print_shapes=True,
source_info=False, custom_pp_eqn_rules=True) -> pp.Doc:
source_info=False, custom_pp_eqn_rules=True, name_stack=False) -> pp.Doc:
lhs = pp_vars(eqn.outvars, context, print_shapes=print_shapes)
annotation = (source_info_util.summarize(eqn.source_info)
if source_info else None)
rule = pp_eqn_rules.get(eqn.primitive)
name_stack_annotation = f'[{eqn.source_info.name_stack}]' if name_stack else None
if rule and custom_pp_eqn_rules:
rhs = rule(eqn, context)
else:
rhs = [pp.text(eqn.primitive.name),
pp_kv_pairs(sorted(eqn.params.items()), context),
rhs = [pp.text(eqn.primitive.name, annotation=name_stack_annotation),
pp_kv_pairs(sorted(eqn.params.items()), context, name_stack=name_stack),
pp.text(" ") + pp_vars(eqn.invars, context)]
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext], Sequence[pp.Doc]]
pp_eqn_rules: Dict[Primitive, CustomPpEqnRule] = {}
def pp_eqns(eqns, context: JaxprPpContext, *, print_shapes=True,
source_info=False, custom_pp_eqn_rules=True
source_info=False, custom_pp_eqn_rules=True, name_stack=False,
) -> pp.Doc:
return pp.join(
pp.brk("; "),
[pp_eqn(e, context, print_shapes=print_shapes, source_info=source_info,
name_stack=name_stack,
custom_pp_eqn_rules=custom_pp_eqn_rules) for e in eqns])
def _compact_eqn_should_include(k: str, v: Any) -> bool:
@ -2349,23 +2354,25 @@ def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext, *,
def pp_jaxpr(jaxpr, context: JaxprPpContext, *, print_shapes=True,
source_info=False, custom_pp_eqn_rules=True) -> pp.Doc:
source_info=False, custom_pp_eqn_rules=True, name_stack=False) -> pp.Doc:
eqns_fn = lambda: pp_eqns(jaxpr.eqns, context, print_shapes=print_shapes,
source_info=source_info,
custom_pp_eqn_rules=custom_pp_eqn_rules)
custom_pp_eqn_rules=custom_pp_eqn_rules,
name_stack=name_stack)
return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, print_shapes=print_shapes)
def pp_jaxprs(jaxprs, context: JaxprPpContext) -> pp.Doc:
def pp_jaxprs(jaxprs, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc:
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
return pp.group(pp.nest(2, pp.concat([
pp.text('('), pp.brk(""),
pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context), jaxprs))]
pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context, name_stack=name_stack), jaxprs))]
)) + pp.brk("") + pp.text(')')
)
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, context: JaxprPpContext,
print_shapes=True, source_info: bool = False) -> pp.Doc:
print_shapes=True, source_info: bool = False,
name_stack: bool = False) -> pp.Doc:
lo = max(lo, 0)
hi = max(lo, min(hi, len(jaxpr.eqns)))
eqns = jaxpr.eqns[lo:hi]
@ -2377,7 +2384,8 @@ def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, context: JaxprPpContext,
if lo != 0:
pps.append(pp.text('...'))
pps.extend(map((lambda e: pp_eqn(e, context, print_shapes=print_shapes,
source_info=source_info)), eqns))
source_info=source_info,
name_stack=name_stack)), eqns))
if hi != len(jaxpr.eqns):
pps.append(pp.text('...'))
return pp.join(pp.brk("; "), pps)

View File

@ -24,7 +24,7 @@ from jax._src import source_info_util
from jax.core import Var, Literal, Atom, Tracer
from jax._src import util
from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list,
tuple_delete)
tuple_delete, new_name_stack)
import jax._src.pretty_printer as pp
map = safe_map
@ -806,7 +806,8 @@ def traceable_to_padded_translation(traceable):
operands_ = it.chain.from_iterable([*dims.values(), *operands])
platform = "cpu" # TODO: don't hardwire in the CPU translation.
ctx = xla.TranslationContext(c, platform, xla.AxisEnv(1, (), ()), '')
ctx = xla.TranslationContext(c, platform, xla.AxisEnv(1, (), ()),
new_name_stack())
outs = xla.jaxpr_subcomp(ctx, jaxpr, xla._xla_consts(c, consts), *operands_)
return util.unflatten(outs,
[aval_to_num_buffers(aval) for aval in out_avals])

View File

@ -261,7 +261,7 @@ def convert(fun: Callable,
"""
api._check_callable(fun)
fun_name = getattr(fun, "__name__", "unknown")
name_stack = util.extend_name_stack(util.wrap_name(fun_name, "jax2tf"))
name_stack = util.wrap_name(fun_name, "jax2tf") + "/"
def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal:
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:

View File

@ -117,6 +117,7 @@ from jax._src.lax import control_flow as lax_control_flow
from jax import tree_util
from jax.errors import UnexpectedTracerError
from jax.interpreters import partial_eval as pe
from jax._src import source_info_util
from jax._src.util import safe_map
@ -291,10 +292,11 @@ class Scope(object):
"""Starts a nested trace, returns the Trace object."""
# TODO: This follows the __enter__ part of core.new_main.
level = core.thread_local_state.trace_state.trace_stack.next_level()
main = core.MainTrace(level, pe.JaxprTrace)
name_stack = source_info_util.current_name_stack()
main = core.MainTrace(level, pe.JaxprTrace, name_stack=name_stack)
core.thread_local_state.trace_state.trace_stack.push(main)
self._count_subtraces += 1
return pe.JaxprTrace(main, core.cur_sublevel())
return pe.JaxprTrace(main, core.cur_sublevel(), name_stack=name_stack)
def end_subtrace(self):
# TODO: This follows the __exit__ part of core.new_main

View File

@ -910,7 +910,7 @@ def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes)
all_args, in_tree_def = tree_flatten(((), args, cts_in)) # empty consts
fun = lu.hashable_partial(
lu.wrap_init(ad.backward_pass),
call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys()))
call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys()), False)
fun, nz_arg_cts = ad.nonzero_outputs(fun)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).

View File

@ -846,7 +846,7 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
return tuple(x for x, mz in zip(xs, maybe_zeros) if not type(mz) is ty)
body = lu.wrap_init(ad.closed_backward_pass)
body = lu.hashable_partial(body, jaxpr, reduce_axes)
body = lu.hashable_partial(body, jaxpr, reduce_axes, False)
primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in))
body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
from functools import partial
import itertools as it
@ -40,19 +40,22 @@ map = safe_map
def identity(x): return x
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
transform_stack=True) -> Any:
if not has_aux:
return jvpfun(jvp_subtrace(fun), instantiate)
return jvpfun(jvp_subtrace(fun), instantiate, transform_stack)
else:
fun, aux = jvp_subtrace_aux(fun)
return jvpfun(fun, instantiate), aux
return jvpfun(fun, instantiate, transform_stack), aux
@lu.transformation
def jvpfun(instantiate, primals, tangents):
def jvpfun(instantiate, transform_stack, primals, tangents):
tangents = [Zero.from_value(t) if not isinstance(t, Zero)
and dtype(t) is float0 else t for t in tangents]
with core.new_main(JVPTrace) as main:
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
else contextlib.nullcontext())
with core.new_main(JVPTrace) as main, ctx:
out_primals, out_tangents = yield (main, primals, tangents), {}
del main
if type(instantiate) is bool:
@ -120,7 +123,7 @@ def vjp(traceable, primals, has_aux=False, reduce_axes=()):
def unbound_vjp(pvals, jaxpr, consts, *cts):
cts = tuple(map(ignore_consts, cts, pvals))
dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
return map(instantiate_zeros, arg_cts)
# Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward
@ -162,7 +165,7 @@ def recast_to_float0(primal, tangent):
return tangent
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in, cotangents_in):
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in):
if all(type(ct) is Zero for ct in cotangents_in):
return map(lambda v: Zero(v.aval), jaxpr.invars)
@ -207,36 +210,40 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in, cotangents
map(write_primal, jaxpr.invars, primals_in)
ct_env: Dict[Any, Any] = {}
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
# FIXME: Some invars correspond to tangents
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
else:
cts_in, = map(read_cotangent, eqn.outvars)
with source_info_util.user_context(eqn.source_info.traceback):
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
cts_in_avals = [v.aval for v in eqn.outvars]
params = dict(eqn.params)
call_jaxpr = params.pop('call_jaxpr')
cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
elif eqn.primitive in reducing_transposes:
cts_out = reducing_transposes[eqn.primitive](
reduce_axes, cts_in, *invals, **eqn.params)
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
else contextlib.nullcontext())
with ctx:
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
# FIXME: Some invars correspond to tangents
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
else:
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
**eqn.params)
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
# FIXME: Some invars correspond to primals!
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
cts_in, = map(read_cotangent, eqn.outvars)
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
cts_in_avals = [v.aval for v in eqn.outvars]
params = dict(eqn.params)
call_jaxpr = params.pop('call_jaxpr')
cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
elif eqn.primitive in reducing_transposes:
cts_out = reducing_transposes[eqn.primitive](
reduce_axes, cts_in, *invals, **eqn.params)
else:
cts_out = get_primitive_transpose(eqn.primitive)(
cts_in, *invals, **eqn.params)
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
# FIXME: Some invars correspond to primals!
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
cotangents_out = map(read_cotangent, jaxpr.invars)
return cotangents_out
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, primals_in, cotangents_in):
return backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals_in, cotangents_in)
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack, primals_in, cotangents_in):
return backward_pass(jaxpr.jaxpr, reduce_axes, transform_stack, jaxpr.consts, primals_in, cotangents_in)
class UndefinedPrimal:
@ -297,7 +304,7 @@ class JVPTrace(Trace):
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
nz_tangents = [type(t) is not Zero for t in tangents]
if 'name' in params:
if 'name' in params and not config.jax_experimental_name_stack:
params = dict(params, name=wrap_name(params['name'], 'jvp'))
f_jvp = jvp_subtrace(f, self.main)
f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
@ -547,9 +554,12 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
if config.jax_experimental_name_stack:
new_params = params
else:
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
new_params = update_params(new_params, map(is_undefined_primal, args),
@ -575,7 +585,7 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
# Now that we have a purely linear jaxpr, we can transpose it
cotangents_out = backward_pass(
tangent_jaxpr.jaxpr, reduce_axes, (), primals_in + residuals, cotangents_in)
tangent_jaxpr.jaxpr, reduce_axes, False, (), primals_in + residuals, cotangents_in)
# backward_pass will return cotangents computed for all invars, but some of them
# are residuals appended by partial eval, so we need to skip those before we return.
return cotangents_out[:len(primals_in)]
@ -594,7 +604,7 @@ def nonzero_outputs(*args, **kwargs):
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
fun, nz_arg_cts = nonzero_outputs(fun)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
@ -642,7 +652,8 @@ def jvp_jaxpr(jaxpr, nonzeros, instantiate):
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)

View File

@ -22,6 +22,7 @@ import jax
from jax.config import config
from jax import core
from jax.core import raise_to_shaped, Trace, Tracer
from jax._src import source_info_util
from jax._src.tree_util import tree_unflatten, tree_flatten
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, Zero)
@ -29,7 +30,6 @@ from jax import linear_util as lu
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name,
split_list, canonicalize_axis, moveaxis,
as_hashable_function, curry, memoize, cache)
from jax._src import source_info_util
from jax.interpreters import partial_eval as pe
map = safe_map
@ -205,7 +205,10 @@ class BatchTrace(Trace):
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
assert call_primitive.multiple_results
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
if config.jax_experimental_name_stack:
params = dict(params, name=params.get('name', f.__name__))
else:
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
@ -372,7 +375,8 @@ def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size,
def _batch_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
with core.new_main(main_type, axis_name=axis_name) as main:
with core.extend_axis_env(axis_name, axis_size, main):
outs = yield (main, in_dims, *in_vals), {}
with source_info_util.transform_name_stack('vmap'):
outs = yield (main, in_dims, *in_vals), {}
del main
yield outs

View File

@ -266,8 +266,8 @@ register_constant_handler(
def _source_info_to_location(
primitive: core.Primitive, params: Dict,
source_info: source_info_util.SourceInfo,
name_stack: str = "") -> ir.Location:
eqn_str = name_stack + core.str_eqn_compact(primitive.name, params)
name_stack: Union[str, source_info_util.NameStack] = "") -> ir.Location:
eqn_str = str(name_stack) + core.str_eqn_compact(primitive.name, params)
frame = source_info_util.user_frame(source_info)
if frame is None:
loc = ir.Location.unknown()
@ -280,6 +280,7 @@ def _source_info_to_location(
# Translation rules
NameStack = Union[str, source_info_util.NameStack]
def make_ir_context() -> ir.Context:
"""Creates an MLIR context suitable for JAX IR."""
@ -334,7 +335,7 @@ class ModuleContext:
symbol_table: ir.SymbolTable
platform: str
axis_context: AxisContext
name_stack: str
name_stack: NameStack
# Cached primitive lowerings.
cached_primitive_lowerings: Dict[Any, builtin.FuncOp]
@ -344,7 +345,7 @@ class ModuleContext:
return self.axis_context.axis_env
def __init__(
self, platform: str, axis_context: AxisContext, name_stack: str,
self, platform: str, axis_context: AxisContext, name_stack: NameStack,
context: Optional[ir.Context] = None,
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
@ -412,7 +413,7 @@ _module_name_regex = re.compile(r"[^\w.-]")
def lower_jaxpr_to_module(
module_name: str, jaxpr: core.ClosedJaxpr, platform: str,
axis_context: AxisContext,
name_stack: str, donated_args: Sequence[bool],
name_stack: NameStack, donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None

View File

@ -97,6 +97,11 @@ class PartialVal(tuple):
class JaxprTrace(Trace):
def __init__(self, *args, name_stack: source_info_util.NameStack):
super().__init__(*args)
self.name_stack = name_stack
def pure(self, val) -> 'JaxprTracer':
return self.new_const(val)
@ -163,7 +168,8 @@ class JaxprTrace(Trace):
tracers = map(self.instantiate_const, tracers)
avals = [t.aval for t in tracers]
out_aval = primitive.abstract_eval(*avals, **params)
source = source_info_util.current()
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
if primitive.multiple_results:
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
for aval in out_aval]
@ -213,11 +219,11 @@ class JaxprTrace(Trace):
# The outputs of the staged-out call are Tracers with the new eqn as recipe.
out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
for a in out_avals]
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params,
source_info_util.current())
out_tracers, primitive, staged_params, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
def process_map(self, primitive, f: lu.WrappedFun, tracers, params):
@ -305,8 +311,9 @@ class JaxprTrace(Trace):
update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
new_params = update_params(params, [], len(in_tracers))
new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
source_info_util.current())
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
@ -342,8 +349,10 @@ class JaxprTrace(Trace):
for d, a in zip(staged_out_axes, out_avals_mapped)]
out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None)
for a in out_avals]
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
primitive, staged_params, source_info_util.current())
primitive, staged_params, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
@ -355,6 +364,9 @@ class JaxprTrace(Trace):
return out, (todo, out_axes_transform)
def _current_truncated_name_stack(self):
return source_info_util.current_name_stack()[len(self.name_stack):]
def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
instantiate: bool):
@ -394,11 +406,13 @@ class JaxprTrace(Trace):
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
return converted_jaxpr, (*consts, *env)
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
dict(fun_jaxpr=closed_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts) + len(env)),
source_info_util.current())
source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -434,12 +448,14 @@ class JaxprTrace(Trace):
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
return converted_jaxpr, (*consts, *env)
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
dict(fun_jaxpr=closed_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(consts) + len(env),
bwd=bwd, out_trees=out_trees),
source_info_util.current())
source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -551,7 +567,8 @@ def trace_to_jaxpr(
returned jaxpr takes as inputs the known residual values followed by values
of the originally unknown inputs.
"""
with core.new_main(JaxprTrace) as main:
current_name_stack = source_info_util.current_name_stack()
with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
fun = trace_to_subjaxpr(fun, main, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
@ -565,7 +582,7 @@ def trace_to_subjaxpr_nounits(
main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
in_pvals: Sequence[PartialVal]):
assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals
trace = JaxprTrace(main, core.cur_sublevel())
trace = main.with_cur_sublevel()
in_knowns = [pval.is_known() for pval in in_pvals]
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
@ -1464,11 +1481,13 @@ class DynamicJaxprTrace(core.Trace):
dim_tracers = _get_tracers_only_in_shapes(tracers)
in_avals = _tracers_to_avals(dim_tracers + tracers)
keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers)
name_stack = source_info_util.current_name_stack()
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, in_avals, keep_inputs=keep_inputs)
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
with source_info_util.set_name_stack(name_stack):
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
source_info = source_info_util.current()
env = {v: t for v, t in zip((*jaxpr.constvars, *jaxpr.invars),
(*consts, *dim_tracers, *tracers))
@ -1695,7 +1714,7 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame):
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
@ -1835,7 +1854,7 @@ def _get_tracers_in_shapes(seen: Set[TracerId], in_tracers: Sequence[Tracer]
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
pvals: Sequence[PartialVal]):
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
trace = JaxprTrace(main, core.cur_sublevel())
trace = main.with_cur_sublevel()
in_tracers = map(trace.new_arg, pvals)
ans = yield in_tracers, {}
assert isinstance(ans, (list, tuple)), (

View File

@ -54,7 +54,7 @@ from jax._src import device_array
from jax._src import source_info_util
from jax._src import util
from jax._src.util import (unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
extend_name_stack, new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
from jax.errors import JAXTypeError
from jax._src import dispatch
@ -1038,7 +1038,7 @@ def lower_parallel_callable(
axis_env = xla.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
name_stack = extend_name_stack(wrap_name(name, 'pmap'))
name_stack = new_name_stack(wrap_name(name, 'pmap'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
replicated_args = [axis is None for axis in in_axes]
module: Union[str, xc.XlaComputation]
@ -2145,7 +2145,7 @@ def lower_mesh_computation(
in_is_gda: Sequence[bool]):
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = extend_name_stack(wrap_name(fun_name, api_name))
name_stack = new_name_stack(wrap_name(fun_name, api_name))
global_axis_sizes = mesh.shape
@ -2236,7 +2236,7 @@ def lower_mesh_computation(
partitions_are_protos=partitions_proto)
return MeshComputation(
name_stack, module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
global_out_avals=global_out_avals, in_axes=in_axes, out_axes=out_axes,
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_gda=in_is_gda)

View File

@ -35,7 +35,7 @@ from jax._src.api_util import (argnums_partial, flatten_axes, flatten_fun,
_ensure_index_tuple)
import jax._src.util as util
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src.util import (extend_name_stack, wrap_name, wraps, safe_map,
from jax._src.util import (new_name_stack, wrap_name, wraps, safe_map,
safe_zip, HashableFunction)
from jax._src.config import config
@ -149,7 +149,7 @@ def _sharded_callable(
xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
axis_env = xla.AxisEnv(nrep, (), ())
ctx = xla.TranslationContext(
c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit")))
c, platform, axis_env, new_name_stack(wrap_name(name, "sharded_jit")))
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
built = c.Build(out_tuple)
@ -202,7 +202,7 @@ def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
sub_ctx = ctx.replace(
builder=subc,
name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
out_parts = out_parts_thunk()
assert len(out_parts) == len(out_nodes)
@ -234,7 +234,7 @@ def _sharded_jit_lowering(ctx, *in_nodes,
args.append(ns)
sub_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
core.ClosedJaxpr(call_jaxpr, ()))

View File

@ -42,7 +42,7 @@ from jax.core import (ConcreteArray, ShapedArray,
Literal, str_eqn_compact, abstract_token)
import jax._src.pretty_printer as pp
from jax._src import util
from jax._src.util import (prod, extend_name_stack, wrap_name,
from jax._src.util import (prod, extend_name_stack, new_name_stack, wrap_name,
safe_zip, safe_map, partition_list)
from jax._src.lib import xla_client as xc
from jax.interpreters import partial_eval as pe
@ -101,11 +101,15 @@ tracebacks = {}
def make_op_metadata(primitive: core.Primitive,
params: Dict, *,
source_info: source_info_util.SourceInfo,
name_stack: str = "",
name_stack: Union[str, source_info_util.NameStack] = "",
) -> xc.OpMetadata:
eqn_str = name_stack + str_eqn_compact(primitive.name, params)
if config.jax_experimental_name_stack:
eqn_str = str(source_info.name_stack) + '/' + str_eqn_compact(primitive.name, params)
else:
assert isinstance(name_stack, str)
eqn_str = name_stack + str_eqn_compact(primitive.name, params)
tracebacks[eqn_str] = source_info.traceback
frame = source_info_util.user_frame(source_info) if source_info else None
frame = source_info_util.user_frame(source_info)
return xc.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
@ -438,7 +442,7 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
xla_args, _ = _xla_callable_args(c, avals, tuple_args=False,
filter_tokens=False)
ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
name_stack="")
name_stack=new_name_stack())
ans = f(ctx.replace(builder=c), avals, None, *xla_args, **params)
if prim.multiple_results:
ans = xops.Tuple(c, ans)
@ -551,7 +555,7 @@ class TranslationContext:
# with a specific platform in mind.
platform: Optional[str]
axis_env: AxisEnv
name_stack: str
name_stack: Union[str, source_info_util.NameStack]
def replace(self, **kw): return dataclasses.replace(self, **kw)
@ -581,9 +585,15 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
_partitionmap(write, jaxpr.constvars, consts)
_partitionmap(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if config.jax_experimental_name_stack:
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
source_info = eqn.source_info.replace(
name_stack=ctx.name_stack + eqn.source_info.name_stack)
else:
source_info = eqn.source_info
op_metadata = make_op_metadata(
eqn.primitive, eqn.params, name_stack=ctx.name_stack,
source_info=eqn.source_info)
source_info=source_info)
ctx.builder.set_op_metadata(op_metadata)
in_nodes = _flatmap(read, eqn.invars)
if (ctx.platform is not None and
@ -596,7 +606,9 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
with source_info_util.user_context(eqn.source_info.traceback):
ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if
config.jax_experimental_name_stack else ctx)
ans = rule(eqn_ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
*in_nodes, **eqn.params)
assert isinstance(ans, collections.abc.Sequence), (ans, eqn)
@ -755,8 +767,8 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):
@profiler.annotate_function
def lower_jaxpr_to_xla_module(
fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
name_stack: str, tuple_args: bool, donated_invars: Sequence[bool],
replicated_args: Optional[Sequence[bool]],
name_stack: Union[source_info_util.NameStack, str], tuple_args: bool,
donated_invars: Sequence[bool], replicated_args: Optional[Sequence[bool]],
arg_partitions: Optional[Any],
out_partitions: Optional[Any],
partitions_are_protos: bool = False
@ -1042,7 +1054,7 @@ def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
wrapped_fun = _tuple_output(wrapped_fun)
with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
ctx = TranslationContext(c, backend, axis_env, '')
ctx = TranslationContext(c, backend, axis_env, new_name_stack())
outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
if (multiple_results or
any(len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)):

View File

@ -7311,7 +7311,13 @@ class NamedCallTest(jtu.JaxTestCase):
return my_test_function(x)
c = jax.xla_computation(f)(2)
self.assertIn("my_test_function", c.as_hlo_text())
if config.jax_experimental_name_stack:
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
print_opts.print_metadata = True
hlo_text = c.as_hlo_module().to_string(print_opts)
else:
hlo_text = c.as_hlo_text()
self.assertIn("my_test_function", hlo_text)
def test_non_jaxtype_arg(self):
# For the test to fail without the invalid JaxType filter we need to pass

612
tests/name_stack_test.py Normal file
View File

@ -0,0 +1,612 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import core
from jax import lax
from jax import linear_util as lu
from jax.config import config
from jax._src import test_util as jtu
from jax._src import source_info_util
from jax._src.lib import xla_client
config.parse_flags_with_absl()
extend_name_stack = source_info_util.extend_name_stack
def _get_hlo(f):
def wrapped(*args, **kwargs):
c = jax.xla_computation(f)(*args, **kwargs)
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
print_opts.print_metadata = True
return c.as_hlo_module().to_string(print_opts)
return wrapped
class _EnableNameStackTestCase(jtu.JaxTestCase):
def setUp(self):
self.cfg = config._read("jax_experimental_name_stack")
config.update("jax_experimental_name_stack", True)
def tearDown(self):
config.update("jax_experimental_name_stack", self.cfg)
class NameStackTest(_EnableNameStackTestCase):
def test_trivial_name_stack(self):
def f(x):
return x + 1
jaxpr = jax.make_jaxpr(f)(2).jaxpr
for eqn in jaxpr.eqns:
self.assertEqual(str(eqn.source_info.name_stack), '')
def test_name_call_name_stack(self):
@jax.named_call
def f(x):
return x + 1
jaxpr = jax.make_jaxpr(f)(2).jaxpr
for eqn in jaxpr.eqns:
self.assertEqual(str(eqn.source_info.name_stack), 'f')
def test_manual_name_stack(self):
@extend_name_stack('foo')
def f(x):
return x + 1
jaxpr = jax.make_jaxpr(f)(2).jaxpr
for eqn in jaxpr.eqns:
self.assertEqual(str(eqn.source_info.name_stack), 'foo')
def test_nested_name_stack(self):
@extend_name_stack('foo')
def f(x):
with extend_name_stack('bar'):
return x + 1
jaxpr = jax.make_jaxpr(f)(2).jaxpr
for eqn in jaxpr.eqns:
self.assertEqual(str(eqn.source_info.name_stack), 'foo/bar')
def test_multiple_name_stack(self):
def f(x):
with extend_name_stack('foo'):
y = x + 1
with extend_name_stack('bar'):
with extend_name_stack('baz'):
return y + 1
jaxpr = jax.make_jaxpr(f)(2).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'bar/baz')
def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
@extend_name_stack('foo')
def f(x):
@lu.wrap_init
@extend_name_stack('bar')
def _f(x):
return [x + 1]
return core.call(_f, x)[0]
jaxpr = jax.make_jaxpr(f)(2).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
hlo_text = _get_hlo(f)(2)
self.assertIn('foo/core_call/bar', hlo_text)
def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
@extend_name_stack('foo')
def f(x):
@jax.jit
@extend_name_stack('bar')
def _f(x):
return x + 1
return _f(x)
jaxpr = jax.make_jaxpr(f)(2).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
hlo_text = _get_hlo(f)(2)
self.assertIn('foo/jit(_f)/bar', hlo_text)
def test_pmap_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
@extend_name_stack('foo')
@jax.pmap
def f(x):
with extend_name_stack('bar'):
return x + 1
jaxpr = jax.make_jaxpr(f)(jnp.ones(1)).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
class NameStackTransformationTest(_EnableNameStackTestCase):
def test_vmap_should_transform_name_stack(self):
@jax.vmap
def f(x):
with extend_name_stack('foo'):
return x + 1
jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(foo)')
def test_vmap_should_transform_inner_name_stacks(self):
@extend_name_stack('foo')
@jax.vmap
def f(x):
with extend_name_stack('bar'):
with extend_name_stack('baz'):
return x + 1
jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo/vmap(bar)/vmap(baz)')
def test_vmap_should_apply_to_call_jaxpr(self):
@extend_name_stack('foo')
@jax.vmap
def f(x):
@jax.jit
@extend_name_stack('bar')
def _f(x):
return x + 1
return _f(x)
jaxpr = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
hlo_text = _get_hlo(f)(jnp.ones(2))
self.assertIn('foo/vmap(jit(_f))/vmap(bar)', hlo_text)
def test_jvp_should_transform_stacks(self):
def f(x):
with extend_name_stack('bar'):
with extend_name_stack('baz'):
return jnp.square(x)
g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack),
'foo/jvp(bar)/jvp(baz)')
def test_jvp_should_apply_to_call_jaxpr(self):
@jax.jit
def f(x):
with extend_name_stack('bar'):
with extend_name_stack('baz'):
return jnp.square(x)
g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(
str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack),
'bar/baz')
hlo_text = _get_hlo(g)(1., 1.)
self.assertIn('foo/jvp(jit(f))/jvp(bar)', hlo_text)
def test_grad_should_add_jvp_and_transpose_to_name_stack(self):
@jax.grad
def f(x):
with extend_name_stack('foo'):
return jnp.sin(x)
jaxpr = jax.make_jaxpr(f)(1.).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)')
self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'jvp(foo)')
self.assertEqual(str(jaxpr.eqns[2].source_info.name_stack),
'transpose(jvp(foo))')
hlo_text = _get_hlo(f)(1.)
self.assertIn('jvp(foo)/sin', hlo_text)
self.assertIn('jvp(foo)/cos', hlo_text)
self.assertIn('transpose(jvp(foo))/mul', hlo_text)
def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self):
@jax.grad
@extend_name_stack('foo')
@jax.jit
def f(x):
with extend_name_stack('bar'):
return jnp.sin(x)
jaxpr = jax.make_jaxpr(f)(1.).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)')
self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'transpose(jvp(foo))')
self.assertEqual(str(
jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
self.assertEqual(str(
jaxpr.eqns[0].params['call_jaxpr'].eqns[1].source_info.name_stack), 'bar')
self.assertEqual(str(
jaxpr.eqns[1].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
hlo_text = _get_hlo(f)(1.)
self.assertIn('jvp(foo)/jvp(jit(f))/jvp(bar)/sin', hlo_text)
self.assertIn('jvp(foo)/jvp(jit(f))/jvp(bar)/cos', hlo_text)
self.assertIn(
'transpose(jvp(foo))/transpose(jvp(jit(f)))/transpose(jvp(bar))/mul',
hlo_text)
class NameStackControlFlowTest(_EnableNameStackTestCase):
def test_while_loop_body_should_not_have_name_stack(self):
@extend_name_stack('foo')
def f(x):
@extend_name_stack('bar')
def body(x):
return x + 1
@extend_name_stack('bar_cond')
def cond(x):
return x < 5
return lax.while_loop(cond, body, x)
jaxpr = jax.make_jaxpr(f)(0)
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(
jaxpr.eqns[0].params['body_jaxpr'].eqns[0].source_info.name_stack),
'bar')
self.assertEqual(str(
jaxpr.eqns[0].params['cond_jaxpr'].eqns[0].source_info.name_stack),
'bar_cond')
hlo_text = _get_hlo(f)(1.)
self.assertIn('foo/while/body/bar', hlo_text)
self.assertIn('foo/while/cond/bar_cond', hlo_text)
def test_vmap_of_while_loop_should_transform_name_stack(self):
@jax.vmap
@extend_name_stack('foo')
def f(x):
@extend_name_stack('bar')
def body(x):
return x + 1
@extend_name_stack('bar_cond')
def cond(x):
return x < 5
return lax.while_loop(cond, body, x)
jaxpr = jax.make_jaxpr(f)(jnp.arange(2))
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(foo)')
self.assertEqual(str(
jaxpr.eqns[0].params['body_jaxpr'].eqns[0].source_info.name_stack),
'bar')
self.assertEqual(str(
jaxpr.eqns[0].params['cond_jaxpr'].eqns[0].source_info.name_stack),
'bar_cond')
hlo_text = _get_hlo(f)(jnp.arange(2.))
self.assertIn('vmap(foo)/vmap(while)/vmap(body)/vmap(bar)', hlo_text)
self.assertIn('vmap(foo)/vmap(while)/vmap(cond)/vmap(bar_cond)', hlo_text)
def test_jvp_of_while_loop_transforms_name_stack(self):
@extend_name_stack('foo')
def f(x):
@extend_name_stack('bar')
def body(x):
return x + 1.
@extend_name_stack('bar_cond')
def cond(x):
return x < 5.
return lax.while_loop(cond, body, x)
g = lambda x, t: jax.jvp(f, (x,), (t,))
jaxpr = jax.make_jaxpr(g)(1., 1.)
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)')
self.assertEqual(str(
jaxpr.eqns[0].params['body_jaxpr'].eqns[0].source_info.name_stack),
'bar')
self.assertEqual(str(
jaxpr.eqns[0].params['cond_jaxpr'].eqns[0].source_info.name_stack),
'bar_cond')
hlo_text = _get_hlo(g)(1., 1.)
self.assertIn('jvp(foo)/jvp(while)/jvp(body)/jvp(bar)', hlo_text)
self.assertIn('jvp(foo)/jvp(while)/jvp(cond)/jvp(bar_cond)', hlo_text)
def test_vmap_of_jvp_of_while_loop_transforms_name_stack(self):
@extend_name_stack('foo')
def f(x):
@extend_name_stack('bar')
def body(x):
return x + 1.
@extend_name_stack('bar_cond')
def cond(x):
return x < 5.
return lax.while_loop(cond, body, x)
g = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,)))
jaxpr = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(jvp(foo))')
self.assertEqual(str(
jaxpr.eqns[0].params['body_jaxpr'].eqns[0].source_info.name_stack),
'bar')
self.assertEqual(str(
jaxpr.eqns[0].params['cond_jaxpr'].eqns[0].source_info.name_stack),
'bar_cond')
hlo_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
self.assertIn(
'vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(body))/vmap(jvp(bar))',
hlo_text)
self.assertIn(
'vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(cond))/vmap(jvp(bar_cond))',
hlo_text)
def test_cond_body_should_not_have_name_stack(self):
@extend_name_stack('foo')
def f(x):
@extend_name_stack('true')
def true_fn(x):
return x + 1
@extend_name_stack('false')
def false_fn(x):
return x - 1
return lax.cond(True, true_fn, false_fn, x)
jaxpr = jax.make_jaxpr(f)(0)
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(
jaxpr.eqns[0].params['branches'][0].eqns[0].source_info.name_stack),
'false')
self.assertEqual(str(
jaxpr.eqns[0].params['branches'][1].eqns[0].source_info.name_stack),
'true')
hlo_text = _get_hlo(f)(1.)
self.assertIn('foo/cond/branch_0_fun/false/sub', hlo_text)
self.assertIn('foo/cond/branch_1_fun/true/add', hlo_text)
def test_vmap_of_cond_should_transform_name_stack(self):
@extend_name_stack('foo')
@jax.vmap
def f(x):
@extend_name_stack('true')
def true_fn(x):
return x + 1
@extend_name_stack('false')
def false_fn(x):
return x - 1
return lax.cond(True, true_fn, false_fn, x)
jaxpr = jax.make_jaxpr(f)(jnp.arange(2))
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(
jaxpr.eqns[0].params['branches'][0].eqns[0].source_info.name_stack),
'false')
self.assertEqual(str(
jaxpr.eqns[0].params['branches'][1].eqns[0].source_info.name_stack),
'true')
hlo_text = _get_hlo(f)(jnp.arange(2.))
self.assertIn('foo/vmap(cond)/vmap(branch_0_fun)/vmap(false)/sub', hlo_text)
self.assertIn('foo/vmap(cond)/vmap(branch_1_fun)/vmap(true)/add', hlo_text)
def test_jvp_of_cond_transforms_name_stack(self):
@extend_name_stack('foo')
def f(x):
@extend_name_stack('true')
def true_fn(x):
return x + 1
@extend_name_stack('false')
def false_fn(x):
return x - 1
return lax.cond(True, true_fn, false_fn, x)
g = lambda x, t: jax.jvp(f, (x,), (t,))
jaxpr = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)')
self.assertEqual(str(
jaxpr.eqns[0].params['branches'][0].eqns[0].source_info.name_stack),
'false')
self.assertEqual(str(
jaxpr.eqns[0].params['branches'][1].eqns[0].source_info.name_stack),
'true')
hlo_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
self.assertIn('jvp(foo)/jvp(cond)/jvp(branch_0_fun)/jvp(false)/sub', hlo_text)
self.assertIn('jvp(foo)/jvp(cond)/jvp(branch_1_fun)/jvp(true)/add', hlo_text)
def test_vmap_of_jvp_of_cond_transforms_name_stack(self):
@extend_name_stack('foo')
def f(x):
@extend_name_stack('true')
def true_fn(x):
return x + 1
@extend_name_stack('false')
def false_fn(x):
return x - 1
return lax.cond(True, true_fn, false_fn, x)
g = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,)))
jaxpr = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(jvp(foo))')
self.assertEqual(str(
jaxpr.eqns[0].params['branches'][0].eqns[0].source_info.name_stack),
'false')
self.assertEqual(str(
jaxpr.eqns[0].params['branches'][1].eqns[0].source_info.name_stack),
'true')
hlo_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
self.assertIn(
'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_0_fun))/vmap(jvp(false))/sub',
hlo_text)
self.assertIn(
'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_1_fun))/vmap(jvp(true))/add',
hlo_text)
def test_grad_of_cond_transforms_name_stack(self):
@jax.grad
@extend_name_stack('foo')
def f(x):
@extend_name_stack('true')
def true_fn(x):
return x * 2.
@extend_name_stack('false')
def false_fn(x):
return x / 2.
return lax.cond(True, true_fn, false_fn, x)
jaxpr = jax.make_jaxpr(f)(1.)
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)')
self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack),
'transpose(jvp(foo))')
hlo_text = _get_hlo(f)(1.)
self.assertIn(
'jvp(foo)/jvp(cond)/jvp(branch_0_fun)/jvp(false)/div',
hlo_text)
self.assertIn(
'jvp(foo)/jvp(cond)/jvp(branch_1_fun)/jvp(true)/mul',
hlo_text)
self.assertIn(
'transpose(jvp(foo))/transpose(jvp(cond))/transpose(jvp(branch_0_fun))/transpose(jvp(false))/div',
hlo_text)
self.assertIn(
'transpose(jvp(foo))/transpose(jvp(cond))/transpose(jvp(branch_1_fun))/transpose(jvp(true))/mul',
hlo_text)
def test_vmap_of_grad_of_cond_transforms_name_stack(self):
@jax.vmap
@jax.grad
@extend_name_stack('foo')
def f(x):
@extend_name_stack('true')
def true_fn(x):
return x * 2.
@extend_name_stack('false')
def false_fn(x):
return x / 2.
return lax.cond(True, true_fn, false_fn, x)
jaxpr = jax.make_jaxpr(f)(jnp.arange(2.))
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(jvp(foo))')
self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack),
'vmap(transpose(jvp(foo)))')
hlo_text = _get_hlo(f)(jnp.arange(2.))
self.assertIn(
'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_0_fun))/vmap(jvp(false))/div',
hlo_text)
self.assertIn(
'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_1_fun))/vmap(jvp(true))/mul',
hlo_text)
self.assertIn(
'vmap(transpose(jvp(foo)))/vmap(transpose(jvp(cond)))/vmap(transpose(jvp(branch_0_fun)))/vmap(transpose(jvp(false)))/div',
hlo_text)
self.assertIn(
'vmap(transpose(jvp(foo)))/vmap(transpose(jvp(cond)))/vmap(transpose(jvp(branch_1_fun)))/vmap(transpose(jvp(true)))/mul',
hlo_text)
def test_scan_body_should_not_have_name_stack(self):
@extend_name_stack('foo')
def f(x):
@extend_name_stack('scan_body')
def body(carry, x):
return carry + x, carry + x
return lax.scan(body, x, jnp.arange(5.))
jaxpr = jax.make_jaxpr(f)(1.)
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(str(
jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack),
'scan_body')
hlo_text = _get_hlo(f)(1.)
self.assertIn('foo/while/body/scan_body', hlo_text)
def test_vmap_of_scan_should_transform_stack(self):
@jax.vmap
@extend_name_stack('foo')
def f(x):
@extend_name_stack('scan_body')
def body(carry, x):
return carry + x, carry + x
return lax.scan(body, x, jnp.arange(8.))
jaxpr = jax.make_jaxpr(f)(jnp.arange(2.))
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'vmap(foo)')
self.assertEqual(str(
jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack),
'scan_body')
hlo_text = _get_hlo(f)(jnp.arange(2.))
self.assertIn('vmap(foo)/vmap(while)/vmap(body)/vmap(scan_body)/add', hlo_text)
def test_jvp_of_scan_should_transform_stack(self):
@extend_name_stack('foo')
def f(x):
@extend_name_stack('scan_body')
def body(carry, x):
return carry + x, carry + x
return lax.scan(body, x, jnp.arange(8.))
g = lambda x, t: jax.jvp(f, (x,), (t,))
jaxpr = jax.make_jaxpr(g)(1., 1.)
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'jvp(foo)')
self.assertEqual(str(
jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack),
'scan_body')
hlo_text = _get_hlo(g)(1., 1.)
self.assertIn('jvp(foo)/jvp(while)/jvp(body)/jvp(scan_body)/add', hlo_text)
def test_grad_of_scan_should_transform_stack(self):
@jax.grad
@extend_name_stack('foo')
def f(x):
@extend_name_stack('scan_body')
def body(carry, x):
return carry * x, carry + x
return lax.scan(body, x, jnp.arange(8.))[0]
jaxpr = jax.make_jaxpr(f)(1.)
self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'jvp(foo)')
self.assertEqual(str(jaxpr.eqns[3].source_info.name_stack),
'transpose(jvp(foo))')
self.assertEqual(str(
jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack),
'scan_body')
hlo_text = _get_hlo(f)(1.)
self.assertIn('jvp(foo)/jvp(while)/jvp(body)/jvp(scan_body)/mul', hlo_text)
self.assertIn('transpose(jvp(foo))/transpose(jvp(while))/transpose(jvp(body))/transpose(jvp(scan_body))/mul', hlo_text)
def test_vmap_of_grad_of_scan_should_transform_stack(self):
@jax.vmap
@jax.grad
@extend_name_stack('foo')
def f(x):
@extend_name_stack('scan_body')
def body(carry, x):
return carry * x, carry + x
return lax.scan(body, x, jnp.arange(8.))[0]
jaxpr = jax.make_jaxpr(f)(jnp.arange(2.))
self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'vmap(jvp(foo))')
self.assertEqual(str(jaxpr.eqns[3].source_info.name_stack),
'vmap(transpose(jvp(foo)))')
self.assertEqual(str(
jaxpr.eqns[1].params['jaxpr'].eqns[0].source_info.name_stack),
'scan_body')
hlo_text = _get_hlo(f)(jnp.arange(2.))
self.assertIn('vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(body))/vmap(jvp(scan_body))/mul', hlo_text)
self.assertIn('vmap(transpose(jvp(foo)))/vmap(transpose(jvp(while)))/vmap(transpose(jvp(body)))/vmap(transpose(jvp(scan_body)))/mul', hlo_text)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())