mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add separate mechanism for threading name stacks to the lowering
This commit is contained in:
parent
e96b91d405
commit
1b79caa6bd
@ -353,7 +353,7 @@ def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
|
||||
primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
|
||||
tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
|
||||
dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
|
||||
in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, consts, dummy_args,
|
||||
in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts, dummy_args,
|
||||
out_cts)
|
||||
in_cts, cell.treedef = tree_flatten(in_cts_)
|
||||
return in_cts
|
||||
|
@ -56,10 +56,11 @@ from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
Partial, PyTreeDef, all_leaves)
|
||||
from jax._src.tree_util import broadcast_prefix
|
||||
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
extend_name_stack, wrap_name, cache, wraps,
|
||||
extend_name_stack, new_name_stack, wrap_name, cache, wraps,
|
||||
HashableFunction)
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import source_info_util
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -895,9 +896,8 @@ def xla_computation(fun: Callable,
|
||||
should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
|
||||
xla_args, donated_invars = xla._xla_callable_args(
|
||||
c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
|
||||
ctx = xla.TranslationContext(
|
||||
c, backend, axis_env_,
|
||||
extend_name_stack(wrap_name(fun_name, "xla_computation")))
|
||||
name_stack = new_name_stack(wrap_name(fun_name, "xla_computation"))
|
||||
ctx = xla.TranslationContext(c, backend, axis_env_, name_stack)
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
|
||||
build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
|
||||
if out_parts is not None:
|
||||
@ -2615,7 +2615,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
|
||||
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
|
||||
in_cotangents = map(
|
||||
ad.instantiate_zeros,
|
||||
ad.backward_pass(jaxpr, reduce_axes, consts, dummies, out_cotangents))
|
||||
ad.backward_pass(jaxpr, reduce_axes, True, consts, dummies, out_cotangents))
|
||||
return tree_unflatten(in_tree, in_cotangents)
|
||||
|
||||
# Ensure that transposed_fun is a PyTree
|
||||
@ -3197,6 +3197,9 @@ def named_call(
|
||||
|
||||
_, in_tree = tree_flatten(())
|
||||
|
||||
if config.jax_experimental_name_stack:
|
||||
return source_info_util.extend_name_stack(name)(fun)
|
||||
|
||||
@functools.wraps(fun)
|
||||
def named_call_f(*args, **kwargs):
|
||||
lu_f = lu.wrap_init(lambda: fun(*args, **kwargs))
|
||||
|
@ -140,7 +140,6 @@ class Config:
|
||||
for name, val in self.values.items():
|
||||
flag_type, meta_args, meta_kwargs = self.meta[name]
|
||||
absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)
|
||||
|
||||
app.call_after_init(lambda: self.complete_absl_config(absl_flags))
|
||||
|
||||
def complete_absl_config(self, absl_flags):
|
||||
@ -688,6 +687,11 @@ config.define_bool_state(
|
||||
help=('Enables experimental features for staging out computations with '
|
||||
'dynamic shapes.'))
|
||||
|
||||
# Opt-in switch (off by default) for the context manager-based name stack.
# Code that builds or extends name stacks checks this flag to decide between
# NameStack objects and the legacy '/'-joined strings.
config.define_bool_state(
    name='jax_experimental_name_stack',
    default=False,
    help='Enable using the context manager-based name stack.',
)
|
||||
|
||||
# This flag is temporary during rollout of the remat barrier.
|
||||
# TODO(parkers): Remove if there are no complaints.
|
||||
config.define_bool_state(
|
||||
|
@ -407,7 +407,7 @@ def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
|
||||
jvp_jaxpr_thunk, num_consts):
|
||||
del jvp_jaxpr_thunk, num_consts
|
||||
return ad.backward_pass(
|
||||
fun_jaxpr.jaxpr, reduce_axes, fun_jaxpr.consts, args, cts)
|
||||
fun_jaxpr.jaxpr, reduce_axes, False, fun_jaxpr.consts, args, cts)
|
||||
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
|
||||
|
||||
def custom_jvp_jaxpr_custom_partial_eval_rule(
|
||||
|
@ -250,7 +250,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
# pass long arg lists as tuple for TPU
|
||||
tuple_args = len(abstract_args) > 100
|
||||
axis_env = xla.AxisEnv(nreps, (), ())
|
||||
name_stack = xla.extend_name_stack(xla.wrap_name(name, 'jit'))
|
||||
name_stack = xla.new_name_stack(xla.wrap_name(name, 'jit'))
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
module: Union[str, xc.XlaComputation]
|
||||
module_name = f"jit_{fun.__name__}"
|
||||
|
@ -347,8 +347,9 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
|
||||
cond_carry = xla.parameter(cond_c, 0, c.get_shape(init_carry))
|
||||
cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
|
||||
x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
|
||||
name_stack = extend_name_stack(ctx.name_stack, 'while')
|
||||
cond_ctx = ctx.replace(builder=cond_c,
|
||||
name_stack=extend_name_stack(ctx.name_stack, 'cond'))
|
||||
name_stack=extend_name_stack(name_stack, 'cond'))
|
||||
pred, = xla.jaxpr_subcomp(
|
||||
cond_ctx, cond_jaxpr.jaxpr,
|
||||
_map(partial(xla.pyval_to_ir_constant, cond_c), cond_jaxpr.consts),
|
||||
@ -365,14 +366,14 @@ def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
|
||||
body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
|
||||
x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
|
||||
body_ctx = ctx.replace(builder=body_c,
|
||||
name_stack=extend_name_stack(ctx.name_stack, 'body'))
|
||||
name_stack=extend_name_stack(name_stack, 'body'))
|
||||
new_z = xla.jaxpr_subcomp(
|
||||
body_ctx, body_jaxpr.jaxpr,
|
||||
_map(partial(xla.pyval_to_ir_constant, body_c), body_jaxpr.consts),
|
||||
*(y + z))
|
||||
if batched:
|
||||
body_pred_ctx = body_ctx.replace(
|
||||
name_stack=extend_name_stack(ctx.name_stack, 'body_pred'))
|
||||
name_stack=extend_name_stack(name_stack, 'body_pred'))
|
||||
body_pred, = xla.jaxpr_subcomp(
|
||||
body_pred_ctx, cond_jaxpr.jaxpr,
|
||||
_map(partial(xla.pyval_to_ir_constant, body_c), cond_jaxpr.consts),
|
||||
@ -1201,9 +1202,11 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
|
||||
|
||||
linear_2 = (False,) * num_res + linear
|
||||
params = dict(branches=branches_2, linear=linear_2)
|
||||
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = pe.new_eqn_recipe(
|
||||
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
|
||||
source_info_util.current())
|
||||
source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
|
||||
@ -1297,7 +1300,7 @@ def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
|
||||
res, cts_out = split_list(args, [num_res])
|
||||
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
|
||||
cts_in = ad.backward_pass(
|
||||
jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals, cts_out)
|
||||
jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
|
||||
_, cts_in = split_list(cts_in, [num_res])
|
||||
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
|
||||
|
||||
@ -1924,9 +1927,10 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
for uk, t in zip(unknowns[:num_consts], tracers[:num_consts])]
|
||||
other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]]
|
||||
in_pvals_1 = invariant_pvals + other_pvals
|
||||
jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
|
||||
lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
|
||||
instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
|
||||
with source_info_util.reset_name_stack():
|
||||
jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
|
||||
lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
|
||||
instantiate=[True] * (num_carry + num_ys) + [False] * num_res)
|
||||
jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ())
|
||||
num_consts_1 = num_consts + len(consts_1)
|
||||
# any now-known residuals are intensive, so we want to revise jaxpr_2 to take
|
||||
@ -1990,6 +1994,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
ext_res_tracers = _map(trace.new_instantiated_const, extensive_residuals)
|
||||
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
|
||||
for pv, const in zip(out_pvs, out_consts)]
|
||||
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
linear_2 = ([False] * len(int_res_tracers) +
|
||||
[lin or not uk for uk, lin in zip(unknowns, linear)] +
|
||||
[False] * len(ext_res_tracers))
|
||||
@ -1999,7 +2005,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
num_consts=num_consts_2,
|
||||
num_carry=num_carry, linear=tuple(linear_2),
|
||||
unroll=unroll),
|
||||
source_info_util.current())
|
||||
source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
|
||||
@ -2068,7 +2074,7 @@ def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
|
||||
res1_cbar_bbar_res2, [num_res1, num_c, num_b])
|
||||
primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
|
||||
[ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
|
||||
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts,
|
||||
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
|
||||
primals, b_bar)
|
||||
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
|
||||
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
|
||||
|
@ -50,7 +50,7 @@ from jax.interpreters import masking
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src import util
|
||||
from jax._src.util import (cache, safe_zip, prod, safe_map, canonicalize_axis,
|
||||
split_list)
|
||||
split_list, new_name_stack)
|
||||
from jax.tree_util import tree_map
|
||||
import jax._src.lib
|
||||
from jax._src.lib import pytree
|
||||
@ -3424,7 +3424,7 @@ def _reduction_computation(ctx, jaxpr, consts, init_values, singleton=True):
|
||||
subc = xc.XlaBuilder("reduction_computation")
|
||||
assert len(consts) == 0, "Reduction computations cannot have constants"
|
||||
args = [xla.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
|
||||
ctx = xla.TranslationContext(subc, platform, axis_env, '')
|
||||
ctx = xla.TranslationContext(subc, platform, axis_env, new_name_stack())
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, consts, *args)
|
||||
if singleton:
|
||||
return subc.build(out_nodes[0])
|
||||
|
@ -13,12 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import os.path
|
||||
import threading
|
||||
import types
|
||||
from typing import Optional, Iterator, NamedTuple
|
||||
from typing import Optional, Iterator, NamedTuple, Union, Tuple
|
||||
|
||||
import jax.version
|
||||
from jax._src.lib import xla_client, xla_extension_version
|
||||
@ -40,15 +41,66 @@ _exclude_paths = [os.path.dirname(jax.version.__file__)]
|
||||
def register_exclusion(path):
|
||||
_exclude_paths.append(path)
|
||||
|
||||
class Scope(NamedTuple):
  """A plain named scope: contributes one path component to a name stack."""
  name: str

  def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
    """Prepends this scope's name to the already-rendered inner components."""
    return (self.name, *stack)


class Transform(NamedTuple):
  """A transformation marker: decorates every component rendered inside it."""
  name: str

  def wrap(self, stack: Tuple[str, ...]) -> Tuple[str, ...]:
    """Rewrites each inner component `x` as `name(x)`."""
    return tuple(f'{self.name}({x})' for x in stack)


@dataclasses.dataclass(frozen=True)
class NameStack:
  """Immutable sequence of `Scope` / `Transform` entries, outermost first.

  `str(name_stack)` renders the entries as a '/'-separated path. A
  `Transform` entry wraps every component that comes after it in the stack.
  Note that because `__len__` is defined, an empty stack is falsy.
  """
  stack: Tuple[Union[Scope, Transform], ...] = ()

  def extend(self, name: Union[Tuple[str, ...], str]) -> 'NameStack':
    """Returns a new stack with one `Scope` appended per given name."""
    names = name if isinstance(name, tuple) else (name,)
    return NameStack(self.stack + tuple(Scope(n) for n in names))

  def wrap_name(self, name: str) -> str:
    """Returns `name` prefixed by the rendered stack (or as-is when empty)."""
    if not self.stack:
      return name
    return f'{str(self)}/{name}'

  def transform(self, transform_name: str) -> 'NameStack':
    """Returns a new stack with a `Transform` entry appended."""
    return NameStack((*self.stack, Transform(transform_name)))

  def __getitem__(self, idx: Union[int, slice]) -> 'NameStack':
    """Returns the selected entries as a new `NameStack`.

    A slice selects a sub-stack. An integer selects a single entry, which is
    returned as a one-element stack; `IndexError` is raised when out of range.
    """
    if isinstance(idx, slice):
      return NameStack(self.stack[idx])
    # A single entry is itself a tuple (NamedTuple) of strings; passing it
    # straight through as `stack` would build a malformed NameStack whose
    # elements have no `wrap` method, so wrap it in a 1-tuple instead.
    return NameStack((self.stack[idx],))

  def __len__(self):
    return len(self.stack)

  def __add__(self, other: 'NameStack') -> 'NameStack':
    """Concatenates: entries of `self` followed by entries of `other`."""
    return NameStack(self.stack + other.stack)

  def __radd__(self, other: 'NameStack') -> 'NameStack':
    """Reflected concatenation: entries of `other` followed by `self`."""
    return NameStack(other.stack + self.stack)

  def __str__(self) -> str:
    # Render from the innermost entry outwards so that each Transform gets to
    # wrap every component that was rendered inside it.
    rendered: Tuple[str, ...] = ()
    for entry in reversed(self.stack):
      rendered = entry.wrap(rendered)
    return '/'.join(rendered)
|
||||
|
||||
class SourceInfo(NamedTuple):
|
||||
traceback: Optional[Traceback]
|
||||
name_stack: NameStack
|
||||
|
||||
def replace(self, *, traceback: Optional[Traceback] = None) -> 'SourceInfo':
|
||||
def replace(self, *, traceback: Optional[Traceback] = None,
|
||||
name_stack: Optional[NameStack] = None) -> 'SourceInfo':
|
||||
traceback = traceback or self.traceback
|
||||
return self._replace(traceback=traceback)
|
||||
name_stack = self.name_stack if name_stack is None else name_stack
|
||||
return self._replace(traceback=traceback, name_stack=name_stack)
|
||||
|
||||
def new_source_info() -> SourceInfo:
|
||||
return SourceInfo(None)
|
||||
return SourceInfo(None, NameStack())
|
||||
|
||||
def is_user_filename(filename: str) -> bool:
|
||||
"""Heuristic that guesses the identity of the user's code in a stack trace."""
|
||||
@ -97,11 +149,10 @@ class _SourceInfoContext(threading.local):
|
||||
_source_info_context = _SourceInfoContext()
|
||||
|
||||
def current() -> SourceInfo:
|
||||
context = _source_info_context.context
|
||||
if not context.traceback:
|
||||
return context.replace(traceback=xla_client.Traceback.get_traceback())
|
||||
return context
|
||||
|
||||
source_info = _source_info_context.context
|
||||
if not source_info.traceback:
|
||||
source_info = source_info.replace(traceback=xla_client.Traceback.get_traceback())
|
||||
return source_info
|
||||
|
||||
class JaxStackTraceBeforeTransformation(Exception): pass
|
||||
|
||||
@ -118,9 +169,10 @@ def has_user_context(e):
|
||||
return False
|
||||
|
||||
@contextlib.contextmanager
|
||||
def user_context(c: Optional[Traceback]):
|
||||
def user_context(c: Optional[Traceback], *, name_stack: Optional[NameStack] = None):
|
||||
prev = _source_info_context.context
|
||||
_source_info_context.context = _source_info_context.context.replace(traceback=c)
|
||||
_source_info_context.context = _source_info_context.context.replace(
|
||||
traceback=c, name_stack=name_stack)
|
||||
filtered_tb = None
|
||||
try:
|
||||
yield
|
||||
@ -141,3 +193,43 @@ def user_context(c: Optional[Traceback]):
|
||||
finally:
|
||||
_source_info_context.context = prev
|
||||
del filtered_tb
|
||||
|
||||
def current_name_stack() -> NameStack:
  """Returns the name stack held by the active source-info context."""
  active_context = _source_info_context.context
  return active_context.name_stack
|
||||
|
||||
@contextlib.contextmanager
def extend_name_stack(name: str) -> Iterator[NameStack]:
  """Context manager that pushes `name` onto the current name stack.

  Yields the extended name stack. The previous source-info context is put
  back on exit, whether or not the body raised.
  """
  saved = _source_info_context.context
  extended = saved.name_stack.extend(name)
  _source_info_context.context = saved.replace(name_stack=extended)
  try:
    yield _source_info_context.context.name_stack
  finally:
    _source_info_context.context = saved
|
||||
|
||||
@contextlib.contextmanager
def set_name_stack(name_stack: NameStack) -> Iterator[None]:
  """Context manager that installs `name_stack` as the current name stack.

  The previous source-info context is put back on exit, whether or not the
  body raised.
  """
  saved = _source_info_context.context
  _source_info_context.context = saved.replace(name_stack=name_stack)
  try:
    yield
  finally:
    _source_info_context.context = saved
|
||||
|
||||
@contextlib.contextmanager
def reset_name_stack() -> Iterator[None]:
  """Context manager that runs its body under an empty name stack."""
  empty = NameStack()
  with set_name_stack(empty):
    yield
|
||||
|
||||
@contextlib.contextmanager
def transform_name_stack(name: str) -> Iterator[NameStack]:
  """Context manager that applies transform `name` to the current name stack.

  Yields the name stack returned by `NameStack.transform(name)`. The previous
  source-info context is put back on exit, whether or not the body raised.
  """
  saved = _source_info_context.context
  transformed = saved.name_stack.transform(name)
  _source_info_context.context = saved.replace(name_stack=transformed)
  try:
    yield _source_info_context.context.name_stack
  finally:
    _source_info_context.context = saved
|
||||
|
@ -277,7 +277,21 @@ def get_module_functions(module):
|
||||
def wrap_name(name, transform_name):
  """Returns `name` wrapped as `transform_name(name)`."""
  opening = transform_name + '('
  return opening + name + ')'
|
||||
|
||||
def extend_name_stack(stack, name=''):
|
||||
def new_name_stack(name: str = ''):
  """Creates a root name stack, optionally seeded with `name`.

  With `jax_experimental_name_stack` enabled this returns a
  `source_info_util.NameStack` (extended with `name` when it is non-empty).
  Otherwise it returns the legacy string form, `name + '/'`.
  """
  if not config.jax_experimental_name_stack:
    # NOTE(review): an empty `name` yields '/' here, while call sites that
    # used to pass a literal '' as the name stack now call new_name_stack()
    # -- confirm the extra leading '/' is intended.
    return name + '/'
  # Imported lazily, inside the function, exactly as the original did.
  from jax._src import source_info_util
  fresh = source_info_util.NameStack()
  return fresh.extend(name) if name else fresh
|
||||
|
||||
def extend_name_stack(stack, name: str):
  """Returns `stack` extended with one more component, `name`.

  With `jax_experimental_name_stack` enabled, `stack` must be a
  `source_info_util.NameStack` and the result is `stack.extend(name)`.
  Otherwise `stack` must be a string and the result is `stack + name + '/'`.
  """
  if not config.jax_experimental_name_stack:
    assert isinstance(stack, str)
    return stack + name + '/'
  # Imported lazily, inside the function, exactly as the original did.
  from jax._src import source_info_util
  assert isinstance(stack, source_info_util.NameStack), stack
  return stack.extend(name)
|
||||
|
||||
def canonicalize_axis(axis, num_dims) -> int:
|
||||
|
50
jax/core.py
50
jax/core.py
@ -81,10 +81,11 @@ class Jaxpr:
|
||||
__repr__ = __str__
|
||||
|
||||
def pretty_print(self, *, source_info=False, print_shapes=True,
|
||||
custom_pp_eqn_rules=True, **kw):
|
||||
custom_pp_eqn_rules=True, name_stack=False, **kw):
|
||||
doc = pp_jaxpr(self, JaxprPpContext(), source_info=source_info,
|
||||
print_shapes=print_shapes,
|
||||
custom_pp_eqn_rules=custom_pp_eqn_rules)
|
||||
custom_pp_eqn_rules=custom_pp_eqn_rules,
|
||||
name_stack=name_stack)
|
||||
return doc.format(**kw)
|
||||
|
||||
def _repr_pretty_(self, p, cycle):
|
||||
@ -141,9 +142,10 @@ class ClosedJaxpr:
|
||||
def __str__(self): return str(self.jaxpr)
|
||||
def __repr__(self): return repr(self.jaxpr)
|
||||
|
||||
def pretty_print(self, *, source_info=False, print_shapes=True, **kw):
|
||||
def pretty_print(self, *, source_info=False, print_shapes=True,
|
||||
name_stack=False, **kw):
|
||||
return pp_jaxpr(self.jaxpr, JaxprPpContext(), source_info=source_info,
|
||||
print_shapes=print_shapes).format(**kw)
|
||||
print_shapes=print_shapes, name_stack=name_stack).format(**kw)
|
||||
|
||||
|
||||
def _repr_pretty_(self, p, cycle):
|
||||
@ -333,7 +335,8 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
|
||||
map(write, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
with source_info_util.user_context(eqn.source_info.traceback):
|
||||
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
|
||||
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write, eqn.outvars, ans)
|
||||
@ -2272,50 +2275,52 @@ def pp_vars(vs: Sequence[Any], context: JaxprPpContext,
|
||||
[pp.text(pp_var(v, context)) for v in vs])
|
||||
))
|
||||
|
||||
def pp_kv_pair(k:str, v: Any, context: JaxprPpContext) -> pp.Doc:
|
||||
def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc:
|
||||
if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
|
||||
pp_v = pp_jaxprs(v, context)
|
||||
pp_v = pp_jaxprs(v, context, name_stack=name_stack)
|
||||
elif isinstance(v, Jaxpr):
|
||||
pp_v = pp_jaxpr(v, context)
|
||||
pp_v = pp_jaxpr(v, context, name_stack=name_stack)
|
||||
elif isinstance(v, ClosedJaxpr):
|
||||
pp_v = pp_jaxpr(v.jaxpr, context)
|
||||
pp_v = pp_jaxpr(v.jaxpr, context, name_stack=name_stack)
|
||||
else:
|
||||
pp_v = pp.text(str(v))
|
||||
return pp.text(f'{k}=') + pp_v
|
||||
|
||||
def pp_kv_pairs(kv_pairs, context: JaxprPpContext) -> pp.Doc:
|
||||
def pp_kv_pairs(kv_pairs, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc:
|
||||
if not kv_pairs:
|
||||
return pp.nil()
|
||||
return pp.group(
|
||||
pp.nest(2, pp.concat([
|
||||
pp.text("["), pp.brk(""),
|
||||
pp.join(pp.brk(), [pp_kv_pair(k, v, context) for k, v in kv_pairs])
|
||||
pp.join(pp.brk(), [pp_kv_pair(k, v, context, name_stack=name_stack) for k, v in kv_pairs])
|
||||
]))
|
||||
+ pp.brk("") + pp.text("]")
|
||||
)
|
||||
|
||||
def pp_eqn(eqn, context: JaxprPpContext, *, print_shapes=True,
|
||||
source_info=False, custom_pp_eqn_rules=True) -> pp.Doc:
|
||||
source_info=False, custom_pp_eqn_rules=True, name_stack=False) -> pp.Doc:
|
||||
lhs = pp_vars(eqn.outvars, context, print_shapes=print_shapes)
|
||||
annotation = (source_info_util.summarize(eqn.source_info)
|
||||
if source_info else None)
|
||||
rule = pp_eqn_rules.get(eqn.primitive)
|
||||
name_stack_annotation = f'[{eqn.source_info.name_stack}]' if name_stack else None
|
||||
if rule and custom_pp_eqn_rules:
|
||||
rhs = rule(eqn, context)
|
||||
else:
|
||||
rhs = [pp.text(eqn.primitive.name),
|
||||
pp_kv_pairs(sorted(eqn.params.items()), context),
|
||||
rhs = [pp.text(eqn.primitive.name, annotation=name_stack_annotation),
|
||||
pp_kv_pairs(sorted(eqn.params.items()), context, name_stack=name_stack),
|
||||
pp.text(" ") + pp_vars(eqn.invars, context)]
|
||||
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
|
||||
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext], Sequence[pp.Doc]]
|
||||
pp_eqn_rules: Dict[Primitive, CustomPpEqnRule] = {}
|
||||
|
||||
def pp_eqns(eqns, context: JaxprPpContext, *, print_shapes=True,
|
||||
source_info=False, custom_pp_eqn_rules=True
|
||||
source_info=False, custom_pp_eqn_rules=True, name_stack=False,
|
||||
) -> pp.Doc:
|
||||
return pp.join(
|
||||
pp.brk("; "),
|
||||
[pp_eqn(e, context, print_shapes=print_shapes, source_info=source_info,
|
||||
name_stack=name_stack,
|
||||
custom_pp_eqn_rules=custom_pp_eqn_rules) for e in eqns])
|
||||
|
||||
def _compact_eqn_should_include(k: str, v: Any) -> bool:
|
||||
@ -2349,23 +2354,25 @@ def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext, *,
|
||||
|
||||
|
||||
def pp_jaxpr(jaxpr, context: JaxprPpContext, *, print_shapes=True,
|
||||
source_info=False, custom_pp_eqn_rules=True) -> pp.Doc:
|
||||
source_info=False, custom_pp_eqn_rules=True, name_stack=False) -> pp.Doc:
|
||||
eqns_fn = lambda: pp_eqns(jaxpr.eqns, context, print_shapes=print_shapes,
|
||||
source_info=source_info,
|
||||
custom_pp_eqn_rules=custom_pp_eqn_rules)
|
||||
custom_pp_eqn_rules=custom_pp_eqn_rules,
|
||||
name_stack=name_stack)
|
||||
return pp_jaxpr_skeleton(jaxpr, eqns_fn, context, print_shapes=print_shapes)
|
||||
|
||||
def pp_jaxprs(jaxprs, context: JaxprPpContext) -> pp.Doc:
|
||||
def pp_jaxprs(jaxprs, context: JaxprPpContext, name_stack: bool = False) -> pp.Doc:
|
||||
jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
|
||||
return pp.group(pp.nest(2, pp.concat([
|
||||
pp.text('('), pp.brk(""),
|
||||
pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context), jaxprs))]
|
||||
pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context, name_stack=name_stack), jaxprs))]
|
||||
)) + pp.brk("") + pp.text(')')
|
||||
)
|
||||
|
||||
|
||||
def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, context: JaxprPpContext,
|
||||
print_shapes=True, source_info: bool = False) -> pp.Doc:
|
||||
print_shapes=True, source_info: bool = False,
|
||||
name_stack: bool = False) -> pp.Doc:
|
||||
lo = max(lo, 0)
|
||||
hi = max(lo, min(hi, len(jaxpr.eqns)))
|
||||
eqns = jaxpr.eqns[lo:hi]
|
||||
@ -2377,7 +2384,8 @@ def pp_jaxpr_eqn_range(jaxpr: Jaxpr, lo: int, hi: int, context: JaxprPpContext,
|
||||
if lo != 0:
|
||||
pps.append(pp.text('...'))
|
||||
pps.extend(map((lambda e: pp_eqn(e, context, print_shapes=print_shapes,
|
||||
source_info=source_info)), eqns))
|
||||
source_info=source_info,
|
||||
name_stack=name_stack)), eqns))
|
||||
if hi != len(jaxpr.eqns):
|
||||
pps.append(pp.text('...'))
|
||||
return pp.join(pp.brk("; "), pps)
|
||||
|
@ -24,7 +24,7 @@ from jax._src import source_info_util
|
||||
from jax.core import Var, Literal, Atom, Tracer
|
||||
from jax._src import util
|
||||
from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list,
|
||||
tuple_delete)
|
||||
tuple_delete, new_name_stack)
|
||||
import jax._src.pretty_printer as pp
|
||||
|
||||
map = safe_map
|
||||
@ -806,7 +806,8 @@ def traceable_to_padded_translation(traceable):
|
||||
|
||||
operands_ = it.chain.from_iterable([*dims.values(), *operands])
|
||||
platform = "cpu" # TODO: don't hardwire in the CPU translation.
|
||||
ctx = xla.TranslationContext(c, platform, xla.AxisEnv(1, (), ()), '')
|
||||
ctx = xla.TranslationContext(c, platform, xla.AxisEnv(1, (), ()),
|
||||
new_name_stack())
|
||||
outs = xla.jaxpr_subcomp(ctx, jaxpr, xla._xla_consts(c, consts), *operands_)
|
||||
return util.unflatten(outs,
|
||||
[aval_to_num_buffers(aval) for aval in out_avals])
|
||||
|
@ -261,7 +261,7 @@ def convert(fun: Callable,
|
||||
"""
|
||||
api._check_callable(fun)
|
||||
fun_name = getattr(fun, "__name__", "unknown")
|
||||
name_stack = util.extend_name_stack(util.wrap_name(fun_name, "jax2tf"))
|
||||
name_stack = util.wrap_name(fun_name, "jax2tf") + "/"
|
||||
def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal:
|
||||
# TODO: is there a better way to check if we are inside a transformation?
|
||||
if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
|
||||
|
@ -117,6 +117,7 @@ from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax import tree_util
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src import source_info_util
|
||||
from jax._src.util import safe_map
|
||||
|
||||
|
||||
@ -291,10 +292,11 @@ class Scope(object):
|
||||
"""Starts a nested trace, returns the Trace object."""
|
||||
# TODO: This follows the __enter__ part of core.new_main.
|
||||
level = core.thread_local_state.trace_state.trace_stack.next_level()
|
||||
main = core.MainTrace(level, pe.JaxprTrace)
|
||||
name_stack = source_info_util.current_name_stack()
|
||||
main = core.MainTrace(level, pe.JaxprTrace, name_stack=name_stack)
|
||||
core.thread_local_state.trace_state.trace_stack.push(main)
|
||||
self._count_subtraces += 1
|
||||
return pe.JaxprTrace(main, core.cur_sublevel())
|
||||
return pe.JaxprTrace(main, core.cur_sublevel(), name_stack=name_stack)
|
||||
|
||||
def end_subtrace(self):
|
||||
# TODO: This follows the __exit__ part of core.new_main
|
||||
|
@ -910,7 +910,7 @@ def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes)
|
||||
all_args, in_tree_def = tree_flatten(((), args, cts_in)) # empty consts
|
||||
fun = lu.hashable_partial(
|
||||
lu.wrap_init(ad.backward_pass),
|
||||
call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys()))
|
||||
call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys()), False)
|
||||
fun, nz_arg_cts = ad.nonzero_outputs(fun)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
|
||||
|
@ -846,7 +846,7 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
return tuple(x for x, mz in zip(xs, maybe_zeros) if not type(mz) is ty)
|
||||
|
||||
body = lu.wrap_init(ad.closed_backward_pass)
|
||||
body = lu.hashable_partial(body, jaxpr, reduce_axes)
|
||||
body = lu.hashable_partial(body, jaxpr, reduce_axes, False)
|
||||
primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in))
|
||||
body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef)
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
@ -40,19 +40,22 @@ map = safe_map
|
||||
def identity(x): return x
|
||||
|
||||
|
||||
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
|
||||
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
|
||||
transform_stack=True) -> Any:
|
||||
if not has_aux:
|
||||
return jvpfun(jvp_subtrace(fun), instantiate)
|
||||
return jvpfun(jvp_subtrace(fun), instantiate, transform_stack)
|
||||
else:
|
||||
fun, aux = jvp_subtrace_aux(fun)
|
||||
return jvpfun(fun, instantiate), aux
|
||||
return jvpfun(fun, instantiate, transform_stack), aux
|
||||
|
||||
|
||||
@lu.transformation
|
||||
def jvpfun(instantiate, primals, tangents):
|
||||
def jvpfun(instantiate, transform_stack, primals, tangents):
|
||||
tangents = [Zero.from_value(t) if not isinstance(t, Zero)
|
||||
and dtype(t) is float0 else t for t in tangents]
|
||||
with core.new_main(JVPTrace) as main:
|
||||
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
|
||||
else contextlib.nullcontext())
|
||||
with core.new_main(JVPTrace) as main, ctx:
|
||||
out_primals, out_tangents = yield (main, primals, tangents), {}
|
||||
del main
|
||||
if type(instantiate) is bool:
|
||||
@ -120,7 +123,7 @@ def vjp(traceable, primals, has_aux=False, reduce_axes=()):
|
||||
def unbound_vjp(pvals, jaxpr, consts, *cts):
|
||||
cts = tuple(map(ignore_consts, cts, pvals))
|
||||
dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
|
||||
arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
|
||||
arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
|
||||
return map(instantiate_zeros, arg_cts)
|
||||
|
||||
# Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward
|
||||
@ -162,7 +165,7 @@ def recast_to_float0(primal, tangent):
|
||||
return tangent
|
||||
|
||||
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
|
||||
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in, cotangents_in):
|
||||
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in):
|
||||
if all(type(ct) is Zero for ct in cotangents_in):
|
||||
return map(lambda v: Zero(v.aval), jaxpr.invars)
|
||||
|
||||
@ -207,36 +210,40 @@ def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in, cotangents
|
||||
map(write_primal, jaxpr.invars, primals_in)
|
||||
|
||||
ct_env: Dict[Any, Any] = {}
|
||||
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
# FIXME: Some invars correspond to tangents
|
||||
invals = map(read_primal, eqn.invars)
|
||||
if eqn.primitive.multiple_results:
|
||||
cts_in = map(read_cotangent, eqn.outvars)
|
||||
else:
|
||||
cts_in, = map(read_cotangent, eqn.outvars)
|
||||
with source_info_util.user_context(eqn.source_info.traceback):
|
||||
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
|
||||
cts_in_avals = [v.aval for v in eqn.outvars]
|
||||
params = dict(eqn.params)
|
||||
call_jaxpr = params.pop('call_jaxpr')
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
|
||||
elif eqn.primitive in reducing_transposes:
|
||||
cts_out = reducing_transposes[eqn.primitive](
|
||||
reduce_axes, cts_in, *invals, **eqn.params)
|
||||
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
|
||||
else contextlib.nullcontext())
|
||||
with ctx:
|
||||
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
# FIXME: Some invars correspond to tangents
|
||||
invals = map(read_primal, eqn.invars)
|
||||
if eqn.primitive.multiple_results:
|
||||
cts_in = map(read_cotangent, eqn.outvars)
|
||||
else:
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
|
||||
**eqn.params)
|
||||
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
|
||||
# FIXME: Some invars correspond to primals!
|
||||
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
|
||||
cts_in, = map(read_cotangent, eqn.outvars)
|
||||
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
|
||||
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
|
||||
cts_in_avals = [v.aval for v in eqn.outvars]
|
||||
params = dict(eqn.params)
|
||||
call_jaxpr = params.pop('call_jaxpr')
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
|
||||
elif eqn.primitive in reducing_transposes:
|
||||
cts_out = reducing_transposes[eqn.primitive](
|
||||
reduce_axes, cts_in, *invals, **eqn.params)
|
||||
else:
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
cts_in, *invals, **eqn.params)
|
||||
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
|
||||
# FIXME: Some invars correspond to primals!
|
||||
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
|
||||
|
||||
cotangents_out = map(read_cotangent, jaxpr.invars)
|
||||
return cotangents_out
|
||||
|
||||
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, primals_in, cotangents_in):
|
||||
return backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals_in, cotangents_in)
|
||||
def closed_backward_pass(jaxpr: core.ClosedJaxpr, reduce_axes, transform_stack, primals_in, cotangents_in):
|
||||
return backward_pass(jaxpr.jaxpr, reduce_axes, transform_stack, jaxpr.consts, primals_in, cotangents_in)
|
||||
|
||||
|
||||
class UndefinedPrimal:
|
||||
@ -297,7 +304,7 @@ class JVPTrace(Trace):
|
||||
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
|
||||
nz_tangents = [type(t) is not Zero for t in tangents]
|
||||
if 'name' in params:
|
||||
if 'name' in params and not config.jax_experimental_name_stack:
|
||||
params = dict(params, name=wrap_name(params['name'], 'jvp'))
|
||||
f_jvp = jvp_subtrace(f, self.main)
|
||||
f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
|
||||
@ -547,9 +554,12 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
|
||||
|
||||
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
|
||||
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
|
||||
if config.jax_experimental_name_stack:
|
||||
new_params = params
|
||||
else:
|
||||
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
|
||||
update_params = call_transpose_param_updaters.get(primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, map(is_undefined_primal, args),
|
||||
@ -575,7 +585,7 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
|
||||
residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
|
||||
# Now that we have a purely linear jaxpr, we can transpose it
|
||||
cotangents_out = backward_pass(
|
||||
tangent_jaxpr.jaxpr, reduce_axes, (), primals_in + residuals, cotangents_in)
|
||||
tangent_jaxpr.jaxpr, reduce_axes, False, (), primals_in + residuals, cotangents_in)
|
||||
# backward_pass will return cotangents computed for all invars, but some of them
|
||||
# are residuals appended by partial eval, so we need to skip those before we return.
|
||||
return cotangents_out[:len(primals_in)]
|
||||
@ -594,7 +604,7 @@ def nonzero_outputs(*args, **kwargs):
|
||||
|
||||
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
|
||||
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
|
||||
fun, nz_arg_cts = nonzero_outputs(fun)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
|
||||
@ -642,7 +652,8 @@ def jvp_jaxpr(jaxpr, nonzeros, instantiate):
|
||||
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
|
||||
assert len(jaxpr.in_avals) == len(nonzeros)
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
|
||||
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
|
||||
nonzeros)
|
||||
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
|
||||
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
|
||||
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
||||
|
@ -22,6 +22,7 @@ import jax
|
||||
from jax.config import config
|
||||
from jax import core
|
||||
from jax.core import raise_to_shaped, Trace, Tracer
|
||||
from jax._src import source_info_util
|
||||
from jax._src.tree_util import tree_unflatten, tree_flatten
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, Zero)
|
||||
@ -29,7 +30,6 @@ from jax import linear_util as lu
|
||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name,
|
||||
split_list, canonicalize_axis, moveaxis,
|
||||
as_hashable_function, curry, memoize, cache)
|
||||
from jax._src import source_info_util
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
map = safe_map
|
||||
@ -205,7 +205,10 @@ class BatchTrace(Trace):
|
||||
|
||||
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
|
||||
assert call_primitive.multiple_results
|
||||
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
|
||||
if config.jax_experimental_name_stack:
|
||||
params = dict(params, name=params.get('name', f.__name__))
|
||||
else:
|
||||
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
|
||||
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
if all(bdim is not_mapped for bdim in dims):
|
||||
return call_primitive.bind(f, *vals, **params)
|
||||
@ -372,7 +375,8 @@ def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size,
|
||||
def _batch_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
|
||||
with core.new_main(main_type, axis_name=axis_name) as main:
|
||||
with core.extend_axis_env(axis_name, axis_size, main):
|
||||
outs = yield (main, in_dims, *in_vals), {}
|
||||
with source_info_util.transform_name_stack('vmap'):
|
||||
outs = yield (main, in_dims, *in_vals), {}
|
||||
del main
|
||||
yield outs
|
||||
|
||||
|
@ -266,8 +266,8 @@ register_constant_handler(
|
||||
def _source_info_to_location(
|
||||
primitive: core.Primitive, params: Dict,
|
||||
source_info: source_info_util.SourceInfo,
|
||||
name_stack: str = "") -> ir.Location:
|
||||
eqn_str = name_stack + core.str_eqn_compact(primitive.name, params)
|
||||
name_stack: Union[str, source_info_util.NameStack] = "") -> ir.Location:
|
||||
eqn_str = str(name_stack) + core.str_eqn_compact(primitive.name, params)
|
||||
frame = source_info_util.user_frame(source_info)
|
||||
if frame is None:
|
||||
loc = ir.Location.unknown()
|
||||
@ -280,6 +280,7 @@ def _source_info_to_location(
|
||||
|
||||
|
||||
# Translation rules
|
||||
NameStack = Union[str, source_info_util.NameStack]
|
||||
|
||||
def make_ir_context() -> ir.Context:
|
||||
"""Creates an MLIR context suitable for JAX IR."""
|
||||
@ -334,7 +335,7 @@ class ModuleContext:
|
||||
symbol_table: ir.SymbolTable
|
||||
platform: str
|
||||
axis_context: AxisContext
|
||||
name_stack: str
|
||||
name_stack: NameStack
|
||||
|
||||
# Cached primitive lowerings.
|
||||
cached_primitive_lowerings: Dict[Any, builtin.FuncOp]
|
||||
@ -344,7 +345,7 @@ class ModuleContext:
|
||||
return self.axis_context.axis_env
|
||||
|
||||
def __init__(
|
||||
self, platform: str, axis_context: AxisContext, name_stack: str,
|
||||
self, platform: str, axis_context: AxisContext, name_stack: NameStack,
|
||||
context: Optional[ir.Context] = None,
|
||||
module: Optional[ir.Module] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
@ -412,7 +413,7 @@ _module_name_regex = re.compile(r"[^\w.-]")
|
||||
def lower_jaxpr_to_module(
|
||||
module_name: str, jaxpr: core.ClosedJaxpr, platform: str,
|
||||
axis_context: AxisContext,
|
||||
name_stack: str, donated_args: Sequence[bool],
|
||||
name_stack: NameStack, donated_args: Sequence[bool],
|
||||
replicated_args: Optional[Sequence[bool]] = None,
|
||||
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
|
||||
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
|
||||
|
@ -97,6 +97,11 @@ class PartialVal(tuple):
|
||||
|
||||
|
||||
class JaxprTrace(Trace):
|
||||
|
||||
def __init__(self, *args, name_stack: source_info_util.NameStack):
  """Initialize the trace and remember the name stack handed to it.

  Args:
    *args: positional arguments forwarded unchanged to the parent
      initializer.
    name_stack: keyword-only name stack stored on the instance as
      `self.name_stack`.
  """
  super().__init__(*args)
  self.name_stack = name_stack
|
||||
|
||||
def pure(self, val) -> 'JaxprTracer':
|
||||
return self.new_const(val)
|
||||
|
||||
@ -163,7 +168,8 @@ class JaxprTrace(Trace):
|
||||
tracers = map(self.instantiate_const, tracers)
|
||||
avals = [t.aval for t in tracers]
|
||||
out_aval = primitive.abstract_eval(*avals, **params)
|
||||
source = source_info_util.current()
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
if primitive.multiple_results:
|
||||
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
|
||||
for aval in out_aval]
|
||||
@ -213,11 +219,11 @@ class JaxprTrace(Trace):
|
||||
# The outputs of the staged-out call are Tracers with the new eqn as recipe.
|
||||
out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
|
||||
for a in out_avals]
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
|
||||
out_tracers, primitive, staged_params,
|
||||
source_info_util.current())
|
||||
out_tracers, primitive, staged_params, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
|
||||
def process_map(self, primitive, f: lu.WrappedFun, tracers, params):
|
||||
@ -305,8 +311,9 @@ class JaxprTrace(Trace):
|
||||
update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
|
||||
new_params = update_params(params, [], len(in_tracers))
|
||||
new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
|
||||
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
|
||||
source_info_util.current())
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
|
||||
@ -342,8 +349,10 @@ class JaxprTrace(Trace):
|
||||
for d, a in zip(staged_out_axes, out_avals_mapped)]
|
||||
out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None)
|
||||
for a in out_avals]
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
|
||||
primitive, staged_params, source_info_util.current())
|
||||
primitive, staged_params, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
|
||||
@ -355,6 +364,9 @@ class JaxprTrace(Trace):
|
||||
|
||||
return out, (todo, out_axes_transform)
|
||||
|
||||
def _current_truncated_name_stack(self):
  """Return the ambient name stack with this trace's own prefix dropped.

  The first `len(self.name_stack)` entries of the current name stack are
  sliced off and the remainder is returned.
  """
  ambient_stack = source_info_util.current_name_stack()
  prefix_length = len(self.name_stack)
  return ambient_stack[prefix_length:]
|
||||
|
||||
def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
|
||||
instantiate: bool):
|
||||
@ -394,11 +406,13 @@ class JaxprTrace(Trace):
|
||||
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
|
||||
return converted_jaxpr, (*consts, *env)
|
||||
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
|
||||
dict(fun_jaxpr=closed_jaxpr,
|
||||
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
|
||||
num_consts=len(consts) + len(env)),
|
||||
source_info_util.current())
|
||||
source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
|
||||
@ -434,12 +448,14 @@ class JaxprTrace(Trace):
|
||||
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
|
||||
return converted_jaxpr, (*consts, *env)
|
||||
|
||||
name_stack = self._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
|
||||
dict(fun_jaxpr=closed_jaxpr,
|
||||
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
|
||||
num_consts=len(consts) + len(env),
|
||||
bwd=bwd, out_trees=out_trees),
|
||||
source_info_util.current())
|
||||
source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
|
||||
@ -551,7 +567,8 @@ def trace_to_jaxpr(
|
||||
returned jaxpr takes as inputs the known residual values followed by values
|
||||
of the originally unknown inputs.
|
||||
"""
|
||||
with core.new_main(JaxprTrace) as main:
|
||||
current_name_stack = source_info_util.current_name_stack()
|
||||
with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
|
||||
fun = trace_to_subjaxpr(fun, main, instantiate)
|
||||
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
||||
assert not env
|
||||
@ -565,7 +582,7 @@ def trace_to_subjaxpr_nounits(
|
||||
main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals
|
||||
trace = JaxprTrace(main, core.cur_sublevel())
|
||||
trace = main.with_cur_sublevel()
|
||||
in_knowns = [pval.is_known() for pval in in_pvals]
|
||||
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
|
||||
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
|
||||
@ -1464,11 +1481,13 @@ class DynamicJaxprTrace(core.Trace):
|
||||
dim_tracers = _get_tracers_only_in_shapes(tracers)
|
||||
in_avals = _tracers_to_avals(dim_tracers + tracers)
|
||||
keep_inputs = [False] * len(dim_tracers) + [True] * len(tracers)
|
||||
name_stack = source_info_util.current_name_stack()
|
||||
with core.new_sublevel():
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, in_avals, keep_inputs=keep_inputs)
|
||||
if params.get('inline', False):
|
||||
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
|
||||
with source_info_util.set_name_stack(name_stack):
|
||||
return core.eval_jaxpr(jaxpr, consts, *dim_tracers, *tracers)
|
||||
source_info = source_info_util.current()
|
||||
env = {v: t for v, t in zip((*jaxpr.constvars, *jaxpr.invars),
|
||||
(*consts, *dim_tracers, *tracers))
|
||||
@ -1695,7 +1714,7 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
|
||||
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
|
||||
|
||||
frame = JaxprStackFrame()
|
||||
with extend_jaxpr_stack(main, frame):
|
||||
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
|
||||
trace = DynamicJaxprTrace(main, core.cur_sublevel())
|
||||
in_tracers = _avals_to_tracers(trace, in_avals)
|
||||
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
|
||||
@ -1835,7 +1854,7 @@ def _get_tracers_in_shapes(seen: Set[TracerId], in_tracers: Sequence[Tracer]
|
||||
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
|
||||
pvals: Sequence[PartialVal]):
|
||||
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
|
||||
trace = JaxprTrace(main, core.cur_sublevel())
|
||||
trace = main.with_cur_sublevel()
|
||||
in_tracers = map(trace.new_arg, pvals)
|
||||
ans = yield in_tracers, {}
|
||||
assert isinstance(ans, (list, tuple)), (
|
||||
|
@ -54,7 +54,7 @@ from jax._src import device_array
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.util import (unzip3, prod, safe_map, safe_zip,
|
||||
extend_name_stack, wrap_name, assert_unreachable,
|
||||
extend_name_stack, new_name_stack, wrap_name, assert_unreachable,
|
||||
tuple_insert, tuple_delete, distributed_debug_log)
|
||||
from jax.errors import JAXTypeError
|
||||
from jax._src import dispatch
|
||||
@ -1038,7 +1038,7 @@ def lower_parallel_callable(
|
||||
|
||||
axis_env = xla.AxisEnv(
|
||||
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
|
||||
name_stack = extend_name_stack(wrap_name(name, 'pmap'))
|
||||
name_stack = new_name_stack(wrap_name(name, 'pmap'))
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
replicated_args = [axis is None for axis in in_axes]
|
||||
module: Union[str, xc.XlaComputation]
|
||||
@ -2145,7 +2145,7 @@ def lower_mesh_computation(
|
||||
in_is_gda: Sequence[bool]):
|
||||
assert not mesh.empty
|
||||
backend = xb.get_device_backend(mesh.devices.flat[0])
|
||||
name_stack = extend_name_stack(wrap_name(fun_name, api_name))
|
||||
name_stack = new_name_stack(wrap_name(fun_name, api_name))
|
||||
|
||||
global_axis_sizes = mesh.shape
|
||||
|
||||
@ -2236,7 +2236,7 @@ def lower_mesh_computation(
|
||||
partitions_are_protos=partitions_proto)
|
||||
|
||||
return MeshComputation(
|
||||
name_stack, module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
|
||||
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
|
||||
global_out_avals=global_out_avals, in_axes=in_axes, out_axes=out_axes,
|
||||
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_gda=in_is_gda)
|
||||
|
||||
|
@ -35,7 +35,7 @@ from jax._src.api_util import (argnums_partial, flatten_axes, flatten_fun,
|
||||
_ensure_index_tuple)
|
||||
import jax._src.util as util
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src.util import (extend_name_stack, wrap_name, wraps, safe_map,
|
||||
from jax._src.util import (new_name_stack, wrap_name, wraps, safe_map,
|
||||
safe_zip, HashableFunction)
|
||||
from jax._src.config import config
|
||||
|
||||
@ -149,7 +149,7 @@ def _sharded_callable(
|
||||
xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
|
||||
axis_env = xla.AxisEnv(nrep, (), ())
|
||||
ctx = xla.TranslationContext(
|
||||
c, platform, axis_env, extend_name_stack(wrap_name(name, "sharded_jit")))
|
||||
c, platform, axis_env, new_name_stack(wrap_name(name, "sharded_jit")))
|
||||
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
|
||||
out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
|
||||
built = c.Build(out_tuple)
|
||||
@ -202,7 +202,7 @@ def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
|
||||
|
||||
sub_ctx = ctx.replace(
|
||||
builder=subc,
|
||||
name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
|
||||
name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
|
||||
out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
|
||||
out_parts = out_parts_thunk()
|
||||
assert len(out_parts) == len(out_nodes)
|
||||
@ -234,7 +234,7 @@ def _sharded_jit_lowering(ctx, *in_nodes,
|
||||
args.append(ns)
|
||||
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
|
||||
name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
|
||||
fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
|
||||
core.ClosedJaxpr(call_jaxpr, ()))
|
||||
|
||||
|
@ -42,7 +42,7 @@ from jax.core import (ConcreteArray, ShapedArray,
|
||||
Literal, str_eqn_compact, abstract_token)
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src import util
|
||||
from jax._src.util import (prod, extend_name_stack, wrap_name,
|
||||
from jax._src.util import (prod, extend_name_stack, new_name_stack, wrap_name,
|
||||
safe_zip, safe_map, partition_list)
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax.interpreters import partial_eval as pe
|
||||
@ -101,11 +101,15 @@ tracebacks = {}
|
||||
def make_op_metadata(primitive: core.Primitive,
|
||||
params: Dict, *,
|
||||
source_info: source_info_util.SourceInfo,
|
||||
name_stack: str = "",
|
||||
name_stack: Union[str, source_info_util.NameStack] = "",
|
||||
) -> xc.OpMetadata:
|
||||
eqn_str = name_stack + str_eqn_compact(primitive.name, params)
|
||||
if config.jax_experimental_name_stack:
|
||||
eqn_str = str(source_info.name_stack) + '/' + str_eqn_compact(primitive.name, params)
|
||||
else:
|
||||
assert isinstance(name_stack, str)
|
||||
eqn_str = name_stack + str_eqn_compact(primitive.name, params)
|
||||
tracebacks[eqn_str] = source_info.traceback
|
||||
frame = source_info_util.user_frame(source_info) if source_info else None
|
||||
frame = source_info_util.user_frame(source_info)
|
||||
return xc.OpMetadata(
|
||||
op_type=primitive.name,
|
||||
op_name=eqn_str,
|
||||
@ -438,7 +442,7 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
|
||||
xla_args, _ = _xla_callable_args(c, avals, tuple_args=False,
|
||||
filter_tokens=False)
|
||||
ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
|
||||
name_stack="")
|
||||
name_stack=new_name_stack())
|
||||
ans = f(ctx.replace(builder=c), avals, None, *xla_args, **params)
|
||||
if prim.multiple_results:
|
||||
ans = xops.Tuple(c, ans)
|
||||
@ -551,7 +555,7 @@ class TranslationContext:
|
||||
# with a specific platform in mind.
|
||||
platform: Optional[str]
|
||||
axis_env: AxisEnv
|
||||
name_stack: str
|
||||
name_stack: Union[str, source_info_util.NameStack]
|
||||
|
||||
def replace(self, **kw): return dataclasses.replace(self, **kw)
|
||||
|
||||
@ -581,9 +585,15 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
|
||||
_partitionmap(write, jaxpr.constvars, consts)
|
||||
_partitionmap(write, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
if config.jax_experimental_name_stack:
|
||||
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
|
||||
source_info = eqn.source_info.replace(
|
||||
name_stack=ctx.name_stack + eqn.source_info.name_stack)
|
||||
else:
|
||||
source_info = eqn.source_info
|
||||
op_metadata = make_op_metadata(
|
||||
eqn.primitive, eqn.params, name_stack=ctx.name_stack,
|
||||
source_info=eqn.source_info)
|
||||
source_info=source_info)
|
||||
ctx.builder.set_op_metadata(op_metadata)
|
||||
in_nodes = _flatmap(read, eqn.invars)
|
||||
if (ctx.platform is not None and
|
||||
@ -596,7 +606,9 @@ def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
|
||||
f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
|
||||
|
||||
with source_info_util.user_context(eqn.source_info.traceback):
|
||||
ans = rule(ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
|
||||
eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if
|
||||
config.jax_experimental_name_stack else ctx)
|
||||
ans = rule(eqn_ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
|
||||
*in_nodes, **eqn.params)
|
||||
|
||||
assert isinstance(ans, collections.abc.Sequence), (ans, eqn)
|
||||
@ -755,8 +767,8 @@ def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):
|
||||
@profiler.annotate_function
|
||||
def lower_jaxpr_to_xla_module(
|
||||
fn_name: str, jaxpr: core.ClosedJaxpr, platform: str, axis_env: AxisEnv,
|
||||
name_stack: str, tuple_args: bool, donated_invars: Sequence[bool],
|
||||
replicated_args: Optional[Sequence[bool]],
|
||||
name_stack: Union[source_info_util.NameStack, str], tuple_args: bool,
|
||||
donated_invars: Sequence[bool], replicated_args: Optional[Sequence[bool]],
|
||||
arg_partitions: Optional[Any],
|
||||
out_partitions: Optional[Any],
|
||||
partitions_are_protos: bool = False
|
||||
@ -1042,7 +1054,7 @@ def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
|
||||
wrapped_fun = _tuple_output(wrapped_fun)
|
||||
with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
|
||||
ctx = TranslationContext(c, backend, axis_env, '')
|
||||
ctx = TranslationContext(c, backend, axis_env, new_name_stack())
|
||||
outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
|
||||
if (multiple_results or
|
||||
any(len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)):
|
||||
|
@ -7311,7 +7311,13 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
return my_test_function(x)
|
||||
|
||||
c = jax.xla_computation(f)(2)
|
||||
self.assertIn("my_test_function", c.as_hlo_text())
|
||||
if config.jax_experimental_name_stack:
|
||||
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
|
||||
print_opts.print_metadata = True
|
||||
hlo_text = c.as_hlo_module().to_string(print_opts)
|
||||
else:
|
||||
hlo_text = c.as_hlo_text()
|
||||
self.assertIn("my_test_function", hlo_text)
|
||||
|
||||
def test_non_jaxtype_arg(self):
|
||||
# For the test to fail without the invalid JaxType filter we need to pass
|
||||
|
612
tests/name_stack_test.py
Normal file
612
tests/name_stack_test.py
Normal file
@ -0,0 +1,612 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import source_info_util
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
extend_name_stack = source_info_util.extend_name_stack
|
||||
|
||||
def _get_hlo(f):
|
||||
def wrapped(*args, **kwargs):
|
||||
c = jax.xla_computation(f)(*args, **kwargs)
|
||||
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
|
||||
print_opts.print_metadata = True
|
||||
return c.as_hlo_module().to_string(print_opts)
|
||||
return wrapped
|
||||
|
||||
class _EnableNameStackTestCase(jtu.JaxTestCase):
  """Test-case base that turns on `jax_experimental_name_stack` per test.

  `setUp` saves the flag's current value and sets it to True; `tearDown`
  writes the saved value back.
  """

  def setUp(self):
    # Run the parent fixture first so that overriding setUp here does not
    # silently skip whatever the base test case prepares.
    super().setUp()
    self.cfg = config._read("jax_experimental_name_stack")
    config.update("jax_experimental_name_stack", True)

  def tearDown(self):
    # Undo our own change before handing control back to the parent fixture.
    config.update("jax_experimental_name_stack", self.cfg)
    super().tearDown()
|
||||
|
||||
|
||||
class NameStackTest(_EnableNameStackTestCase):
|
||||
|
||||
def test_trivial_name_stack(self):
|
||||
|
||||
def f(x):
|
||||
return x + 1
|
||||
jaxpr = jax.make_jaxpr(f)(2).jaxpr
|
||||
for eqn in jaxpr.eqns:
|
||||
self.assertEqual(str(eqn.source_info.name_stack), '')
|
||||
|
||||
def test_name_call_name_stack(self):
|
||||
|
||||
@jax.named_call
|
||||
def f(x):
|
||||
return x + 1
|
||||
jaxpr = jax.make_jaxpr(f)(2).jaxpr
|
||||
for eqn in jaxpr.eqns:
|
||||
self.assertEqual(str(eqn.source_info.name_stack), 'f')
|
||||
|
||||
def test_manual_name_stack(self):
|
||||
|
||||
@extend_name_stack('foo')
|
||||
def f(x):
|
||||
return x + 1
|
||||
jaxpr = jax.make_jaxpr(f)(2).jaxpr
|
||||
for eqn in jaxpr.eqns:
|
||||
self.assertEqual(str(eqn.source_info.name_stack), 'foo')
|
||||
|
||||
def test_nested_name_stack(self):
|
||||
|
||||
@extend_name_stack('foo')
|
||||
def f(x):
|
||||
with extend_name_stack('bar'):
|
||||
return x + 1
|
||||
jaxpr = jax.make_jaxpr(f)(2).jaxpr
|
||||
for eqn in jaxpr.eqns:
|
||||
self.assertEqual(str(eqn.source_info.name_stack), 'foo/bar')
|
||||
|
||||
def test_multiple_name_stack(self):
|
||||
|
||||
def f(x):
|
||||
with extend_name_stack('foo'):
|
||||
y = x + 1
|
||||
with extend_name_stack('bar'):
|
||||
with extend_name_stack('baz'):
|
||||
return y + 1
|
||||
jaxpr = jax.make_jaxpr(f)(2).jaxpr
|
||||
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
|
||||
self.assertEqual(str(jaxpr.eqns[1].source_info.name_stack), 'bar/baz')
|
||||
|
||||
def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
|
||||
@extend_name_stack('foo')
|
||||
def f(x):
|
||||
@lu.wrap_init
|
||||
@extend_name_stack('bar')
|
||||
def _f(x):
|
||||
return [x + 1]
|
||||
return core.call(_f, x)[0]
|
||||
|
||||
jaxpr = jax.make_jaxpr(f)(2).jaxpr
|
||||
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
|
||||
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')
|
||||
|
||||
hlo_text = _get_hlo(f)(2)
|
||||
self.assertIn('foo/core_call/bar', hlo_text)
|
||||
|
||||
def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
  """The jaxpr nested under a `jax.jit` call keeps only its own scope.

  The outer equation is expected to be tagged 'foo', the first equation of
  its 'call_jaxpr' parameter just 'bar', and the text returned by
  `_get_hlo` must contain the combined 'foo/jit(_f)/bar'.
  """
  @extend_name_stack('foo')
  def f(x):
    # The inner function keeps the name `_f`: it appears in the expected
    # 'jit(_f)' fragment asserted below.
    @jax.jit
    @extend_name_stack('bar')
    def _f(x):
      return x + 1
    return _f(x)

  traced = jax.make_jaxpr(f)(2).jaxpr
  jit_eqn = traced.eqns[0]
  self.assertEqual(str(jit_eqn.source_info.name_stack), 'foo')
  inner_eqn = jit_eqn.params['call_jaxpr'].eqns[0]
  self.assertEqual(str(inner_eqn.source_info.name_stack), 'bar')

  lowered_text = _get_hlo(f)(2)
  self.assertIn('foo/jit(_f)/bar', lowered_text)
|
||||
|
||||
def test_pmap_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
  """The jaxpr nested under a `jax.pmap` call keeps only its own scope.

  The outer equation is expected to be tagged 'foo' and the first equation
  of its 'call_jaxpr' parameter just 'bar'.
  """
  @extend_name_stack('foo')
  @jax.pmap
  def f(x):
    with extend_name_stack('bar'):
      return x + 1

  traced = jax.make_jaxpr(f)(jnp.ones(1)).jaxpr
  outer_eqn = traced.eqns[0]
  self.assertEqual(str(outer_eqn.source_info.name_stack), 'foo')
  inner_eqn = outer_eqn.params['call_jaxpr'].eqns[0]
  self.assertEqual(str(inner_eqn.source_info.name_stack), 'bar')
|
||||
|
||||
|
||||
class NameStackTransformationTest(_EnableNameStackTestCase):
  """Checks the name stacks recorded when vmap, jvp and grad are applied."""

  @staticmethod
  def _stack_str(eqn):
    """Returns the name stack recorded on `eqn`, rendered as a string."""
    return str(eqn.source_info.name_stack)

  def test_vmap_should_transform_name_stack(self):
    """A scope opened inside a vmapped function is reported as 'vmap(foo)'."""
    @jax.vmap
    def f(x):
      with extend_name_stack('foo'):
        return x + 1

    traced = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
    self.assertEqual(self._stack_str(traced.eqns[0]), 'vmap(foo)')

  def test_vmap_should_transform_inner_name_stacks(self):
    """Only the scopes opened inside the vmapped function get wrapped.

    'foo' is applied outside `jax.vmap` and stays bare, while the nested
    'bar' and 'baz' scopes are each expected to be wrapped in 'vmap(...)'.
    """
    @extend_name_stack('foo')
    @jax.vmap
    def f(x):
      with extend_name_stack('bar'), extend_name_stack('baz'):
        return x + 1

    traced = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
    self.assertEqual(self._stack_str(traced.eqns[0]),
                     'foo/vmap(bar)/vmap(baz)')

  def test_vmap_should_apply_to_call_jaxpr(self):
    """The jitted jaxpr stores 'bar' untransformed; the HLO text shows vmap.

    The outer equation is expected to be tagged 'foo', the first equation
    of its 'call_jaxpr' parameter plain 'bar', and the text returned by
    `_get_hlo` must contain 'foo/vmap(jit(_f))/vmap(bar)'.
    """
    @extend_name_stack('foo')
    @jax.vmap
    def f(x):
      # The inner function keeps the name `_f`: it appears in the expected
      # 'jit(_f)' fragment asserted below.
      @jax.jit
      @extend_name_stack('bar')
      def _f(x):
        return x + 1
      return _f(x)

    traced = jax.make_jaxpr(f)(jnp.ones(2)).jaxpr
    outer_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(outer_eqn), 'foo')
    inner_eqn = outer_eqn.params['call_jaxpr'].eqns[0]
    self.assertEqual(self._stack_str(inner_eqn), 'bar')

    lowered_text = _get_hlo(f)(jnp.ones(2))
    self.assertIn('foo/vmap(jit(_f))/vmap(bar)', lowered_text)

  def test_jvp_should_transform_stacks(self):
    """Scopes opened inside a function passed to `jax.jvp` become 'jvp(...)'.

    'foo' wraps the call to `jax.jvp` itself and stays bare in the
    expected 'foo/jvp(bar)/jvp(baz)'.
    """
    def f(x):
      with extend_name_stack('bar'), extend_name_stack('baz'):
        return jnp.square(x)

    g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
    traced = jax.make_jaxpr(g)(1., 1.).jaxpr
    self.assertEqual(self._stack_str(traced.eqns[0]),
                     'foo/jvp(bar)/jvp(baz)')

  def test_jvp_should_apply_to_call_jaxpr(self):
    """The jitted jaxpr stores 'bar/baz' untransformed; the HLO text shows jvp.

    The outer equation is expected to be tagged 'foo', the first equation
    of its 'call_jaxpr' parameter 'bar/baz', and the text returned by
    `_get_hlo` must contain 'foo/jvp(jit(f))/jvp(bar)'.
    """
    @jax.jit
    def f(x):
      with extend_name_stack('bar'), extend_name_stack('baz'):
        return jnp.square(x)

    g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
    traced = jax.make_jaxpr(g)(1., 1.).jaxpr
    outer_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(outer_eqn), 'foo')
    inner_eqn = outer_eqn.params['call_jaxpr'].eqns[0]
    self.assertEqual(self._stack_str(inner_eqn), 'bar/baz')

    lowered_text = _get_hlo(g)(1., 1.)
    self.assertIn('foo/jvp(jit(f))/jvp(bar)', lowered_text)

  def test_grad_should_add_jvp_and_transpose_to_name_stack(self):
    """`jax.grad` yields 'jvp(foo)' and 'transpose(jvp(foo))' name stacks.

    The first two traced equations are expected under 'jvp(foo)' and the
    third under 'transpose(jvp(foo))'; the text returned by `_get_hlo` must
    contain the matching sin, cos and mul entries.
    """
    @jax.grad
    def f(x):
      with extend_name_stack('foo'):
        return jnp.sin(x)

    traced = jax.make_jaxpr(f)(1.).jaxpr
    expected_stacks = ('jvp(foo)', 'jvp(foo)', 'transpose(jvp(foo))')
    for position, expected in enumerate(expected_stacks):
      self.assertEqual(self._stack_str(traced.eqns[position]), expected)

    lowered_text = _get_hlo(f)(1.)
    for fragment in ('jvp(foo)/sin',
                     'jvp(foo)/cos',
                     'transpose(jvp(foo))/mul'):
      self.assertIn(fragment, lowered_text)

  def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self):
    """`jax.grad` of a jitted function transforms only the outer name stacks.

    The two outer equations are expected under 'jvp(foo)' and
    'transpose(jvp(foo))', while the equations stored in their 'call_jaxpr'
    parameters remain plain 'bar'. The text returned by `_get_hlo` must
    contain the fully transformed sin, cos and mul entries.
    """
    @jax.grad
    @extend_name_stack('foo')
    @jax.jit
    def f(x):
      with extend_name_stack('bar'):
        return jnp.sin(x)

    traced = jax.make_jaxpr(f)(1.).jaxpr
    self.assertEqual(self._stack_str(traced.eqns[0]), 'jvp(foo)')
    self.assertEqual(self._stack_str(traced.eqns[1]), 'transpose(jvp(foo))')
    # (outer equation index, inner equation index) pairs, all tagged 'bar'.
    for outer, inner in ((0, 0), (0, 1), (1, 0)):
      nested_eqn = traced.eqns[outer].params['call_jaxpr'].eqns[inner]
      self.assertEqual(self._stack_str(nested_eqn), 'bar')

    lowered_text = _get_hlo(f)(1.)
    self.assertIn('jvp(foo)/jvp(jit(f))/jvp(bar)/sin', lowered_text)
    self.assertIn('jvp(foo)/jvp(jit(f))/jvp(bar)/cos', lowered_text)
    self.assertIn(
        'transpose(jvp(foo))/transpose(jvp(jit(f)))/transpose(jvp(bar))/mul',
        lowered_text)
|
||||
|
||||
|
||||
class NameStackControlFlowTest(_EnableNameStackTestCase):
  """Checks name stacks recorded for while_loop, cond and scan.

  Each test inspects the name stack of the outer equation(s), the name
  stacks stored inside the nested jaxprs, and (via `_get_hlo`) the names
  that appear in the lowered text.
  """

  @staticmethod
  def _stack_str(eqn):
    """Returns the name stack recorded on `eqn`, rendered as a string."""
    return str(eqn.source_info.name_stack)

  def test_while_loop_body_should_not_have_name_stack(self):
    """Body and cond jaxprs of a while loop keep only their own scopes."""
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('bar_cond')
      def cond(x):
        return x < 5

      @extend_name_stack('bar')
      def body(x):
        return x + 1

      return lax.while_loop(cond, body, x)

    traced = jax.make_jaxpr(f)(0)
    loop_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(loop_eqn), 'foo')
    for param, expected in (('body_jaxpr', 'bar'), ('cond_jaxpr', 'bar_cond')):
      self.assertEqual(self._stack_str(loop_eqn.params[param].eqns[0]),
                       expected)

    lowered_text = _get_hlo(f)(1.)
    for fragment in ('foo/while/body/bar', 'foo/while/cond/bar_cond'):
      self.assertIn(fragment, lowered_text)

  def test_vmap_of_while_loop_should_transform_name_stack(self):
    """vmap wraps the outer stack; nested body/cond stacks stay plain."""
    @jax.vmap
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('bar_cond')
      def cond(x):
        return x < 5

      @extend_name_stack('bar')
      def body(x):
        return x + 1

      return lax.while_loop(cond, body, x)

    traced = jax.make_jaxpr(f)(jnp.arange(2))
    loop_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(loop_eqn), 'vmap(foo)')
    for param, expected in (('body_jaxpr', 'bar'), ('cond_jaxpr', 'bar_cond')):
      self.assertEqual(self._stack_str(loop_eqn.params[param].eqns[0]),
                       expected)

    lowered_text = _get_hlo(f)(jnp.arange(2.))
    self.assertIn('vmap(foo)/vmap(while)/vmap(body)/vmap(bar)', lowered_text)
    self.assertIn('vmap(foo)/vmap(while)/vmap(cond)/vmap(bar_cond)',
                  lowered_text)

  def test_jvp_of_while_loop_transforms_name_stack(self):
    """jvp wraps the outer stack; nested body/cond stacks stay plain."""
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('bar_cond')
      def cond(x):
        return x < 5.

      @extend_name_stack('bar')
      def body(x):
        return x + 1.

      return lax.while_loop(cond, body, x)

    g = lambda x, t: jax.jvp(f, (x,), (t,))
    traced = jax.make_jaxpr(g)(1., 1.)
    loop_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(loop_eqn), 'jvp(foo)')
    for param, expected in (('body_jaxpr', 'bar'), ('cond_jaxpr', 'bar_cond')):
      self.assertEqual(self._stack_str(loop_eqn.params[param].eqns[0]),
                       expected)

    lowered_text = _get_hlo(g)(1., 1.)
    self.assertIn('jvp(foo)/jvp(while)/jvp(body)/jvp(bar)', lowered_text)
    self.assertIn('jvp(foo)/jvp(while)/jvp(cond)/jvp(bar_cond)', lowered_text)

  def test_vmap_of_jvp_of_while_loop_transforms_name_stack(self):
    """vmap of jvp nests both wrappers around the outer while-loop stack."""
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('bar_cond')
      def cond(x):
        return x < 5.

      @extend_name_stack('bar')
      def body(x):
        return x + 1.

      return lax.while_loop(cond, body, x)

    g = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,)))
    traced = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))
    loop_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(loop_eqn), 'vmap(jvp(foo))')
    for param, expected in (('body_jaxpr', 'bar'), ('cond_jaxpr', 'bar_cond')):
      self.assertEqual(self._stack_str(loop_eqn.params[param].eqns[0]),
                       expected)

    lowered_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
    for fragment in (
        'vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(body))/vmap(jvp(bar))',
        'vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(cond))/vmap(jvp(bar_cond))',
    ):
      self.assertIn(fragment, lowered_text)

  def test_cond_body_should_not_have_name_stack(self):
    """Branch jaxprs of a cond keep only their own scopes.

    Branch 0 is expected to carry 'false' and branch 1 'true'.
    """
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('false')
      def false_fn(x):
        return x - 1

      @extend_name_stack('true')
      def true_fn(x):
        return x + 1

      return lax.cond(True, true_fn, false_fn, x)

    traced = jax.make_jaxpr(f)(0)
    cond_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(cond_eqn), 'foo')
    for branch, expected in enumerate(('false', 'true')):
      branch_eqn = cond_eqn.params['branches'][branch].eqns[0]
      self.assertEqual(self._stack_str(branch_eqn), expected)

    lowered_text = _get_hlo(f)(1.)
    self.assertIn('foo/cond/branch_0_fun/false/sub', lowered_text)
    self.assertIn('foo/cond/branch_1_fun/true/add', lowered_text)

  def test_vmap_of_cond_should_transform_name_stack(self):
    """vmap applied inside 'foo' leaves 'foo' bare and wraps the cond names."""
    @extend_name_stack('foo')
    @jax.vmap
    def f(x):
      @extend_name_stack('false')
      def false_fn(x):
        return x - 1

      @extend_name_stack('true')
      def true_fn(x):
        return x + 1

      return lax.cond(True, true_fn, false_fn, x)

    traced = jax.make_jaxpr(f)(jnp.arange(2))
    cond_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(cond_eqn), 'foo')
    for branch, expected in enumerate(('false', 'true')):
      branch_eqn = cond_eqn.params['branches'][branch].eqns[0]
      self.assertEqual(self._stack_str(branch_eqn), expected)

    lowered_text = _get_hlo(f)(jnp.arange(2.))
    self.assertIn('foo/vmap(cond)/vmap(branch_0_fun)/vmap(false)/sub',
                  lowered_text)
    self.assertIn('foo/vmap(cond)/vmap(branch_1_fun)/vmap(true)/add',
                  lowered_text)

  def test_jvp_of_cond_transforms_name_stack(self):
    """jvp wraps the outer cond stack; branch stacks stay plain."""
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('false')
      def false_fn(x):
        return x - 1

      @extend_name_stack('true')
      def true_fn(x):
        return x + 1

      return lax.cond(True, true_fn, false_fn, x)

    g = lambda x, t: jax.jvp(f, (x,), (t,))
    traced = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))
    cond_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(cond_eqn), 'jvp(foo)')
    for branch, expected in enumerate(('false', 'true')):
      branch_eqn = cond_eqn.params['branches'][branch].eqns[0]
      self.assertEqual(self._stack_str(branch_eqn), expected)

    lowered_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
    self.assertIn('jvp(foo)/jvp(cond)/jvp(branch_0_fun)/jvp(false)/sub',
                  lowered_text)
    self.assertIn('jvp(foo)/jvp(cond)/jvp(branch_1_fun)/jvp(true)/add',
                  lowered_text)

  def test_vmap_of_jvp_of_cond_transforms_name_stack(self):
    """vmap of jvp nests both wrappers around the outer cond stack."""
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('false')
      def false_fn(x):
        return x - 1

      @extend_name_stack('true')
      def true_fn(x):
        return x + 1

      return lax.cond(True, true_fn, false_fn, x)

    g = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,)))
    traced = jax.make_jaxpr(g)(jnp.arange(2.), jnp.ones(2))
    cond_eqn = traced.eqns[0]
    self.assertEqual(self._stack_str(cond_eqn), 'vmap(jvp(foo))')
    for branch, expected in enumerate(('false', 'true')):
      branch_eqn = cond_eqn.params['branches'][branch].eqns[0]
      self.assertEqual(self._stack_str(branch_eqn), expected)

    lowered_text = _get_hlo(g)(jnp.arange(2.), jnp.ones(2))
    for fragment in (
        'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_0_fun))/vmap(jvp(false))/sub',
        'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_1_fun))/vmap(jvp(true))/add',
    ):
      self.assertIn(fragment, lowered_text)

  def test_grad_of_cond_transforms_name_stack(self):
    """grad of a cond yields jvp and transpose(jvp) stacks for both branches."""
    @jax.grad
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('false')
      def false_fn(x):
        return x / 2.

      @extend_name_stack('true')
      def true_fn(x):
        return x * 2.

      return lax.cond(True, true_fn, false_fn, x)

    traced = jax.make_jaxpr(f)(1.)
    for position, expected in enumerate(('jvp(foo)', 'transpose(jvp(foo))')):
      self.assertEqual(self._stack_str(traced.eqns[position]), expected)

    lowered_text = _get_hlo(f)(1.)
    for fragment in (
        'jvp(foo)/jvp(cond)/jvp(branch_0_fun)/jvp(false)/div',
        'jvp(foo)/jvp(cond)/jvp(branch_1_fun)/jvp(true)/mul',
        'transpose(jvp(foo))/transpose(jvp(cond))/transpose(jvp(branch_0_fun))/transpose(jvp(false))/div',
        'transpose(jvp(foo))/transpose(jvp(cond))/transpose(jvp(branch_1_fun))/transpose(jvp(true))/mul',
    ):
      self.assertIn(fragment, lowered_text)

  def test_vmap_of_grad_of_cond_transforms_name_stack(self):
    """vmap of grad of a cond wraps every jvp/transpose name in vmap(...)."""
    @jax.vmap
    @jax.grad
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('false')
      def false_fn(x):
        return x / 2.

      @extend_name_stack('true')
      def true_fn(x):
        return x * 2.

      return lax.cond(True, true_fn, false_fn, x)

    traced = jax.make_jaxpr(f)(jnp.arange(2.))
    expected_stacks = ('vmap(jvp(foo))', 'vmap(transpose(jvp(foo)))')
    for position, expected in enumerate(expected_stacks):
      self.assertEqual(self._stack_str(traced.eqns[position]), expected)

    lowered_text = _get_hlo(f)(jnp.arange(2.))
    for fragment in (
        'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_0_fun))/vmap(jvp(false))/div',
        'vmap(jvp(foo))/vmap(jvp(cond))/vmap(jvp(branch_1_fun))/vmap(jvp(true))/mul',
        'vmap(transpose(jvp(foo)))/vmap(transpose(jvp(cond)))/vmap(transpose(jvp(branch_0_fun)))/vmap(transpose(jvp(false)))/div',
        'vmap(transpose(jvp(foo)))/vmap(transpose(jvp(cond)))/vmap(transpose(jvp(branch_1_fun)))/vmap(transpose(jvp(true)))/mul',
    ):
      self.assertIn(fragment, lowered_text)

  def test_scan_body_should_not_have_name_stack(self):
    """The jaxpr stored on the scan equation keeps only the body's scope."""
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('scan_body')
      def body(carry, x):
        return carry + x, carry + x

      xs = jnp.arange(5.)
      return lax.scan(body, x, xs)

    traced = jax.make_jaxpr(f)(1.)
    self.assertEqual(self._stack_str(traced.eqns[0]), 'foo')
    scan_inner_eqn = traced.eqns[1].params['jaxpr'].eqns[0]
    self.assertEqual(self._stack_str(scan_inner_eqn), 'scan_body')

    lowered_text = _get_hlo(f)(1.)
    self.assertIn('foo/while/body/scan_body', lowered_text)

  def test_vmap_of_scan_should_transform_stack(self):
    """vmap wraps the outer scan stack; the stored body stack stays plain."""
    @jax.vmap
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('scan_body')
      def body(carry, x):
        return carry + x, carry + x

      xs = jnp.arange(8.)
      return lax.scan(body, x, xs)

    traced = jax.make_jaxpr(f)(jnp.arange(2.))
    self.assertEqual(self._stack_str(traced.eqns[0]), 'vmap(foo)')
    scan_inner_eqn = traced.eqns[1].params['jaxpr'].eqns[0]
    self.assertEqual(self._stack_str(scan_inner_eqn), 'scan_body')

    lowered_text = _get_hlo(f)(jnp.arange(2.))
    self.assertIn('vmap(foo)/vmap(while)/vmap(body)/vmap(scan_body)/add',
                  lowered_text)

  def test_jvp_of_scan_should_transform_stack(self):
    """jvp wraps the outer scan stack; the stored body stack stays plain."""
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('scan_body')
      def body(carry, x):
        return carry + x, carry + x

      xs = jnp.arange(8.)
      return lax.scan(body, x, xs)

    g = lambda x, t: jax.jvp(f, (x,), (t,))
    traced = jax.make_jaxpr(g)(1., 1.)
    self.assertEqual(self._stack_str(traced.eqns[0]), 'jvp(foo)')
    scan_inner_eqn = traced.eqns[1].params['jaxpr'].eqns[0]
    self.assertEqual(self._stack_str(scan_inner_eqn), 'scan_body')

    lowered_text = _get_hlo(g)(1., 1.)
    self.assertIn('jvp(foo)/jvp(while)/jvp(body)/jvp(scan_body)/add',
                  lowered_text)

  def test_grad_of_scan_should_transform_stack(self):
    """grad of a scan yields jvp and transpose(jvp) stacks on outer equations.

    Equations 1 and 3 of the traced jaxpr are inspected; the jaxpr stored on
    equation 1 is expected to keep the plain 'scan_body' scope.
    """
    @jax.grad
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('scan_body')
      def body(carry, x):
        return carry * x, carry + x

      xs = jnp.arange(8.)
      return lax.scan(body, x, xs)[0]

    traced = jax.make_jaxpr(f)(1.)
    self.assertEqual(self._stack_str(traced.eqns[1]), 'jvp(foo)')
    self.assertEqual(self._stack_str(traced.eqns[3]), 'transpose(jvp(foo))')
    scan_inner_eqn = traced.eqns[1].params['jaxpr'].eqns[0]
    self.assertEqual(self._stack_str(scan_inner_eqn), 'scan_body')

    lowered_text = _get_hlo(f)(1.)
    for fragment in (
        'jvp(foo)/jvp(while)/jvp(body)/jvp(scan_body)/mul',
        'transpose(jvp(foo))/transpose(jvp(while))/transpose(jvp(body))/transpose(jvp(scan_body))/mul',
    ):
      self.assertIn(fragment, lowered_text)

  def test_vmap_of_grad_of_scan_should_transform_stack(self):
    """vmap of grad of a scan wraps every jvp/transpose name in vmap(...).

    Equations 1 and 3 of the traced jaxpr are inspected; the jaxpr stored on
    equation 1 is expected to keep the plain 'scan_body' scope.
    """
    @jax.vmap
    @jax.grad
    @extend_name_stack('foo')
    def f(x):
      @extend_name_stack('scan_body')
      def body(carry, x):
        return carry * x, carry + x

      xs = jnp.arange(8.)
      return lax.scan(body, x, xs)[0]

    traced = jax.make_jaxpr(f)(jnp.arange(2.))
    self.assertEqual(self._stack_str(traced.eqns[1]), 'vmap(jvp(foo))')
    self.assertEqual(self._stack_str(traced.eqns[3]),
                     'vmap(transpose(jvp(foo)))')
    scan_inner_eqn = traced.eqns[1].params['jaxpr'].eqns[0]
    self.assertEqual(self._stack_str(scan_inner_eqn), 'scan_body')

    lowered_text = _get_hlo(f)(jnp.arange(2.))
    for fragment in (
        'vmap(jvp(foo))/vmap(jvp(while))/vmap(jvp(body))/vmap(jvp(scan_body))/mul',
        'vmap(transpose(jvp(foo)))/vmap(transpose(jvp(while)))/vmap(transpose(jvp(body)))/vmap(transpose(jvp(scan_body)))/mul',
    ):
      self.assertIn(fragment, lowered_text)
|
||||
|
||||
if __name__ == '__main__':
  # When executed as a script, hand control to absltest using the test
  # loader provided by the `jtu` module.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|
Loading…
x
Reference in New Issue
Block a user