mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
applied simple find+sed for 'master' -> 'main' (#4174)
* applied simple find+sed for 'master' -> 'main'
* Rename master->main in JAX API and internals (#4178)
* Started with #4174
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master to core.new_main and core.new_base_main
Co-authored-by: George Necula <gcnecula@gmail.com>
This commit is contained in:
parent
1a87fd3bc1
commit
6b6789a53b
@ -85,7 +85,7 @@ templates_path = ['_templates']
|
||||
source_suffix = '.rst'
|
||||
|
||||
# The master toctree document.
|
||||
master_doc = 'index'
|
||||
main_doc = 'index'
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
@ -256,7 +256,7 @@ latex_elements = {
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, 'JAX.tex', 'JAX Documentation',
|
||||
(main_doc, 'JAX.tex', 'JAX Documentation',
|
||||
'The JAX authors', 'manual'),
|
||||
]
|
||||
|
||||
@ -266,7 +266,7 @@ latex_documents = [
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [
|
||||
(master_doc, 'jax', 'JAX Documentation',
|
||||
(main_doc, 'jax', 'JAX Documentation',
|
||||
[author], 1)
|
||||
]
|
||||
|
||||
@ -277,7 +277,7 @@ man_pages = [
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(master_doc, 'JAX', 'JAX Documentation',
|
||||
(main_doc, 'JAX', 'JAX Documentation',
|
||||
author, 'JAX', 'One line description of project.',
|
||||
'Miscellaneous'),
|
||||
]
|
||||
|
98
jax/core.py
98
jax/core.py
@ -356,15 +356,15 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
|
||||
|
||||
|
||||
class Trace:
|
||||
__slots__ = ['master', 'level', 'sublevel']
|
||||
__slots__ = ['main', 'level', 'sublevel']
|
||||
|
||||
master: 'MasterTrace'
|
||||
main: 'MainTrace'
|
||||
level: int
|
||||
sublevel: 'Sublevel'
|
||||
|
||||
def __init__(self, master: 'MasterTrace', sublevel: 'Sublevel') -> None:
|
||||
self.master = master
|
||||
self.level = master.level
|
||||
def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
|
||||
self.main = main
|
||||
self.level = main.level
|
||||
self.sublevel = sublevel
|
||||
|
||||
def full_raise(self, val) -> 'Tracer':
|
||||
@ -372,7 +372,7 @@ class Trace:
|
||||
return self.pure(val)
|
||||
level = self.level
|
||||
sublevel = self.sublevel
|
||||
if val._trace.master is self.master:
|
||||
if val._trace.main is self.main:
|
||||
if val._trace.sublevel == sublevel:
|
||||
return val
|
||||
elif val._trace.sublevel < sublevel:
|
||||
@ -589,7 +589,7 @@ class EvalTrace(Trace):
|
||||
process_map = process_call
|
||||
|
||||
|
||||
class MasterTrace:
|
||||
class MainTrace:
|
||||
level: int
|
||||
trace_type: Type[Trace]
|
||||
|
||||
@ -598,18 +598,18 @@ class MasterTrace:
|
||||
self.trace_type = trace_type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "MasterTrace({},{})".format(self.level, self.trace_type.__name__)
|
||||
return "MainTrace({},{})".format(self.level, self.trace_type.__name__)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.level, self.trace_type))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (isinstance(other, MasterTrace) and
|
||||
return (isinstance(other, MainTrace) and
|
||||
self.level == other.level and self.trace_type == other.trace_type)
|
||||
|
||||
class TraceStack:
|
||||
upward: List[MasterTrace]
|
||||
downward: List[MasterTrace]
|
||||
upward: List[MainTrace]
|
||||
downward: List[MainTrace]
|
||||
|
||||
def __init__(self):
|
||||
self.upward = []
|
||||
@ -621,11 +621,11 @@ class TraceStack:
|
||||
else:
|
||||
return len(self.upward)
|
||||
|
||||
def push(self, master_trace: MasterTrace, bottom: bool) -> None:
|
||||
def push(self, main_trace: MainTrace, bottom: bool) -> None:
|
||||
if bottom:
|
||||
self.downward.append(master_trace)
|
||||
self.downward.append(main_trace)
|
||||
else:
|
||||
self.upward.append(master_trace)
|
||||
self.upward.append(main_trace)
|
||||
|
||||
def pop(self, bottom: bool) -> None:
|
||||
if bottom:
|
||||
@ -687,19 +687,19 @@ def cur_sublevel() -> Sublevel:
|
||||
return thread_local_state.trace_state.substack[-1]
|
||||
|
||||
@contextmanager
|
||||
def new_master(trace_type: Type[Trace], bottom=False) -> Generator[MasterTrace, None, None]:
|
||||
def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]:
|
||||
level = thread_local_state.trace_state.trace_stack.next_level(bottom)
|
||||
master = MasterTrace(level, trace_type)
|
||||
thread_local_state.trace_state.trace_stack.push(master, bottom)
|
||||
main = MainTrace(level, trace_type)
|
||||
thread_local_state.trace_state.trace_stack.push(main, bottom)
|
||||
|
||||
try:
|
||||
yield master
|
||||
yield main
|
||||
finally:
|
||||
thread_local_state.trace_state.trace_stack.pop(bottom)
|
||||
|
||||
if check_leaks:
|
||||
t = ref(master)
|
||||
del master
|
||||
t = ref(main)
|
||||
del main
|
||||
if t() is not None:
|
||||
print(thread_local_state.trace_state.trace_stack)
|
||||
raise Exception('Leaked trace {}'.format(t()))
|
||||
@ -728,7 +728,7 @@ def full_lower(val):
|
||||
def find_top_trace(xs) -> Optional[Trace]:
|
||||
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
|
||||
key=attrgetter('level'), default=None)
|
||||
return top_trace and type(top_trace)(top_trace.master, cur_sublevel())
|
||||
return top_trace and type(top_trace)(top_trace.main, cur_sublevel())
|
||||
|
||||
@contextmanager
|
||||
def initial_style_staging():
|
||||
@ -1116,7 +1116,7 @@ def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
ans = max(tracers, key=lambda x: x._trace.level)
|
||||
else:
|
||||
break
|
||||
trace = type(ans._trace)(ans._trace.master, cur_sublevel())
|
||||
trace = type(ans._trace)(ans._trace.main, cur_sublevel())
|
||||
outs = map(trace.full_raise, outs)
|
||||
outs, cur_todo = primitive.post_process(trace, outs, params)
|
||||
todo.append(cur_todo)
|
||||
@ -1436,24 +1436,24 @@ axis_frame = None
|
||||
def omnistaging_enabler() -> None:
|
||||
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
|
||||
new_master, reset_trace_state, extend_axis_env, axis_frame, \
|
||||
new_base_master, eval_context, \
|
||||
new_base_main, eval_context, \
|
||||
TraceStack, TraceState
|
||||
del initial_style_staging
|
||||
|
||||
class TraceStack:
|
||||
stack: List[MasterTrace]
|
||||
dynamic: MasterTrace
|
||||
stack: List[MainTrace]
|
||||
dynamic: MainTrace
|
||||
|
||||
def __init__(self):
|
||||
eval_trace = MasterTrace(0, EvalTrace)
|
||||
eval_trace = MainTrace(0, EvalTrace)
|
||||
self.stack = [eval_trace]
|
||||
self.dynamic = eval_trace
|
||||
|
||||
def next_level(self) -> int:
|
||||
return len(self.stack)
|
||||
|
||||
def push(self, master_trace: MasterTrace) -> None:
|
||||
self.stack.append(master_trace)
|
||||
def push(self, main_trace: MainTrace) -> None:
|
||||
self.stack.append(main_trace)
|
||||
|
||||
def pop(self) -> None:
|
||||
self.stack.pop()
|
||||
@ -1491,8 +1491,8 @@ def omnistaging_enabler() -> None:
|
||||
"Reset the global trace state and return True if it was already clean."
|
||||
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
|
||||
thread_local_state.trace_state.axis_env != [] or
|
||||
thread_local_state.trace_state.trace_stack.stack != [MasterTrace(0, EvalTrace)] or
|
||||
thread_local_state.trace_state.trace_stack.dynamic != MasterTrace(0, EvalTrace)):
|
||||
thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or
|
||||
thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)):
|
||||
thread_local_state.trace_state.__init__() # type: ignore
|
||||
return False
|
||||
else:
|
||||
@ -1512,55 +1512,55 @@ def omnistaging_enabler() -> None:
|
||||
def maybe_new_sublevel(trace):
|
||||
# dynamic traces run the WrappedFun, so we raise the sublevel for them
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
return new_sublevel() if trace.master is dynamic else suppress()
|
||||
return new_sublevel() if trace.main is dynamic else suppress()
|
||||
|
||||
def find_top_trace(xs) -> Trace:
|
||||
top_master = max((x._trace.master for x in xs if isinstance(x, Tracer)),
|
||||
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
|
||||
default=None, key=attrgetter('level'))
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
top_master = (dynamic if top_master is None or dynamic.level > top_master.level
|
||||
else top_master)
|
||||
return top_master and top_master.trace_type(top_master, cur_sublevel()) # type: ignore
|
||||
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
|
||||
else top_main)
|
||||
return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore
|
||||
|
||||
@contextmanager
|
||||
def new_master(trace_type: Type[Trace], dynamic: bool = False,
|
||||
) -> Generator[MasterTrace, None, None]:
|
||||
def new_main(trace_type: Type[Trace], dynamic: bool = False,
|
||||
) -> Generator[MainTrace, None, None]:
|
||||
stack = thread_local_state.trace_state.trace_stack
|
||||
level = stack.next_level()
|
||||
master = MasterTrace(level, trace_type)
|
||||
stack.push(master)
|
||||
main = MainTrace(level, trace_type)
|
||||
stack.push(main)
|
||||
if dynamic:
|
||||
prev_dynamic, stack.dynamic = stack.dynamic, master
|
||||
prev_dynamic, stack.dynamic = stack.dynamic, main
|
||||
|
||||
try:
|
||||
yield master
|
||||
yield main
|
||||
finally:
|
||||
thread_local_state.trace_state.trace_stack.pop()
|
||||
if dynamic:
|
||||
stack.dynamic = prev_dynamic
|
||||
|
||||
if check_leaks:
|
||||
t = ref(master)
|
||||
del master
|
||||
t = ref(main)
|
||||
del main
|
||||
if t() is not None:
|
||||
print(thread_local_state.trace_state.trace_stack)
|
||||
raise Exception('Leaked trace {}'.format(t()))
|
||||
|
||||
@contextmanager
|
||||
def new_base_master(trace_type: Type[Trace]) -> Generator[MasterTrace, None, None]:
|
||||
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
|
||||
stack = thread_local_state.trace_state.trace_stack
|
||||
master = MasterTrace(0, trace_type)
|
||||
prev_dynamic, stack.dynamic = stack.dynamic, master
|
||||
prev_base, stack.stack[0] = stack.stack[0], master
|
||||
main = MainTrace(0, trace_type)
|
||||
prev_dynamic, stack.dynamic = stack.dynamic, main
|
||||
prev_base, stack.stack[0] = stack.stack[0], main
|
||||
try:
|
||||
yield master
|
||||
yield main
|
||||
finally:
|
||||
stack.dynamic = prev_dynamic
|
||||
stack.stack[0] = prev_base
|
||||
|
||||
@contextmanager
|
||||
def eval_context():
|
||||
with new_base_master(EvalTrace):
|
||||
with new_base_main(EvalTrace):
|
||||
yield
|
||||
|
||||
def bind(self, *args, **params):
|
||||
|
@ -103,11 +103,11 @@ def callback_subtrace(master, *in_vals, **params):
|
||||
|
||||
@lu.transformation
|
||||
def _callback_fun(callback, strip_calls, *in_vals, **params):
|
||||
with core.new_master(CallbackTrace) as master:
|
||||
master.callback = callback # NOTE: Is this OK?
|
||||
master.strip_calls = strip_calls
|
||||
out_vals = yield (master,) + in_vals, params
|
||||
del master
|
||||
with core.new_main(CallbackTrace) as main:
|
||||
main.callback = callback # NOTE: Is this OK?
|
||||
main.strip_calls = strip_calls
|
||||
out_vals = yield (main,) + in_vals, params
|
||||
del main
|
||||
yield out_vals
|
||||
|
||||
def _check_callable(fun):
|
||||
@ -144,15 +144,15 @@ class CallbackTrace(Trace):
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
vals_in = [t.val for t in tracers]
|
||||
vals_out = self.master.callback(primitive, vals_in, params) # type: ignore
|
||||
vals_out = self.main.callback(primitive, vals_in, params) # type: ignore
|
||||
if primitive.multiple_results:
|
||||
return [CallbackTracer(self, val) for val in vals_out]
|
||||
return CallbackTracer(self, vals_out)
|
||||
|
||||
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
|
||||
if self.master.strip_calls: # type: ignore
|
||||
if self.main.strip_calls: # type: ignore
|
||||
return f.call_wrapped(*tracers)
|
||||
vals_in = [t.val for t in tracers]
|
||||
f = callback_subtrace(f, self.master)
|
||||
f = callback_subtrace(f, self.main)
|
||||
vals_out = call_primitive.bind(f, *vals_in, **params)
|
||||
return [CallbackTracer(self, val) for val in vals_out]
|
||||
|
@ -76,7 +76,7 @@ class DoublingTrace(core.Trace):
|
||||
assert call_primitive.multiple_results
|
||||
heads, tails = unzip2((t.head, t.tail) for t in tracers)
|
||||
nonzero_tails, in_tree_def = tree_flatten(tails)
|
||||
f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.master),
|
||||
f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.main),
|
||||
len(heads), in_tree_def)
|
||||
name = params.get('name', f.__name__)
|
||||
new_params = dict(params, name=wrap_name(name, 'doubledouble'),
|
||||
@ -109,7 +109,7 @@ def screen_nones(num_heads, in_tree_def, *heads_and_tails):
|
||||
|
||||
@lu.transformation
|
||||
def doubling_transform(*args):
|
||||
with core.new_master(DoublingTrace) as master:
|
||||
with core.new_main(DoublingTrace) as master:
|
||||
trace = DoublingTrace(master, core.cur_sublevel())
|
||||
in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args]
|
||||
outputs = yield in_tracers, {}
|
||||
|
@ -212,7 +212,7 @@ def convert(fun, with_gradient=False):
|
||||
|
||||
def _interpret_fun(fun: lu.WrappedFun,
|
||||
in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]:
|
||||
with core.new_master(TensorFlowTrace) as master:
|
||||
with core.new_main(TensorFlowTrace) as master:
|
||||
fun = _interpret_subtrace(fun, master)
|
||||
out_vals: Sequence[TfValOrUnit] = fun.call_wrapped(*in_vals)
|
||||
del master
|
||||
@ -220,7 +220,7 @@ def _interpret_fun(fun: lu.WrappedFun,
|
||||
|
||||
|
||||
@lu.transformation
|
||||
def _interpret_subtrace(master: core.MasterTrace, *in_vals: TfValOrUnit):
|
||||
def _interpret_subtrace(master: core.MainTrace, *in_vals: TfValOrUnit):
|
||||
trace = TensorFlowTrace(master, core.cur_sublevel())
|
||||
in_tracers = tuple(TensorFlowTracer(trace, val) for val in in_vals)
|
||||
outs = yield in_tracers, {} # type: Sequence[TfValOrUnit]
|
||||
@ -332,7 +332,7 @@ class TensorFlowTrace(core.Trace):
|
||||
tracers: Sequence[TensorFlowTracer], params):
|
||||
assert call_primitive.multiple_results
|
||||
vals: Sequence[TfValOrUnit] = [t.val for t in tracers]
|
||||
f = _interpret_subtrace(f, self.master)
|
||||
f = _interpret_subtrace(f, self.main)
|
||||
vals_out: Sequence[TfValOrUnit] = f.call_wrapped(*vals)
|
||||
return [TensorFlowTracer(self, v) for v in vals_out]
|
||||
|
||||
@ -342,7 +342,7 @@ class TensorFlowTrace(core.Trace):
|
||||
# (out_tracers) include TensorFlowTracer that were not passed through
|
||||
# its arguments (captured from the environment).
|
||||
vals = tuple(t.val for t in out_tracers)
|
||||
master = self.master
|
||||
master = self.main
|
||||
def todo(vals: Sequence[TfValOrUnit]):
|
||||
trace = TensorFlowTrace(master, core.cur_sublevel())
|
||||
return map(functools.partial(TensorFlowTracer, trace), vals)
|
||||
|
@ -59,10 +59,10 @@ def jet(fun, primals, series):
|
||||
|
||||
@lu.transformation
|
||||
def jet_fun(order, primals, series):
|
||||
with core.new_master(JetTrace) as master:
|
||||
master.order = order
|
||||
out_primals, out_terms = yield (master, primals, series), {}
|
||||
del master
|
||||
with core.new_main(JetTrace) as main:
|
||||
main.order = order
|
||||
out_primals, out_terms = yield (main, primals, series), {}
|
||||
del main
|
||||
out_terms = [[np.zeros_like(p)] * order if s is zero_series else s
|
||||
for p, s in zip(out_primals, out_terms)]
|
||||
yield out_primals, out_terms
|
||||
@ -115,7 +115,7 @@ class JetTrace(core.Trace):
|
||||
return JetTracer(self, val.primal, val.terms)
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
order = self.master.order # pytype: disable=attribute-error
|
||||
order = self.main.order # pytype: disable=attribute-error
|
||||
primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
|
||||
series_in = [[zero_term] * order if s is zero_series else s
|
||||
for s in series_in]
|
||||
@ -133,7 +133,7 @@ class JetTrace(core.Trace):
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
|
||||
primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
|
||||
f_jet, out_tree_def = traceable(jet_subtrace(f, self.master), in_tree_def)
|
||||
f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
|
||||
update_params = call_param_updaters.get(call_primitive)
|
||||
new_params = (update_params(params, len(primals_and_series))
|
||||
if update_params else params)
|
||||
@ -145,7 +145,7 @@ class JetTrace(core.Trace):
|
||||
primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
|
||||
out, treedef = tree_flatten((primals, series))
|
||||
del primals, series
|
||||
master = self.master
|
||||
master = self.main
|
||||
def todo(x):
|
||||
primals, series = tree_unflatten(treedef, x)
|
||||
trace = JetTrace(master, core.cur_sublevel())
|
||||
|
@ -277,22 +277,22 @@ class Scope(object):
|
||||
|
||||
def start_subtrace(self):
|
||||
"""Starts a nested trace, returns the Trace object."""
|
||||
# TODO: This follows the __enter__ part of core.new_master.
|
||||
# TODO: This follows the __enter__ part of core.new_main.
|
||||
if config.omnistaging_enabled:
|
||||
level = core.thread_local_state.trace_state.trace_stack.next_level()
|
||||
master = core.MasterTrace(level, pe.JaxprTrace)
|
||||
master = core.MainTrace(level, pe.JaxprTrace)
|
||||
core.thread_local_state.trace_state.trace_stack.push(master)
|
||||
self._count_subtraces += 1
|
||||
return pe.JaxprTrace(master, core.cur_sublevel())
|
||||
else:
|
||||
level = core.thread_local_state.trace_state.trace_stack.next_level(False)
|
||||
master = core.MasterTrace(level, pe.JaxprTrace)
|
||||
master = core.MainTrace(level, pe.JaxprTrace)
|
||||
core.thread_local_state.trace_state.trace_stack.push(master, False)
|
||||
self._count_subtraces += 1
|
||||
return pe.JaxprTrace(master, core.cur_sublevel())
|
||||
|
||||
def end_subtrace(self):
|
||||
# TODO: This follows the __exit__ part of core.new_master
|
||||
# TODO: This follows the __exit__ part of core.new_main
|
||||
if config.omnistaging_enabled:
|
||||
core.thread_local_state.trace_state.trace_stack.pop()
|
||||
else:
|
||||
|
@ -46,7 +46,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
|
||||
|
||||
@lu.transformation
|
||||
def jvpfun(instantiate, primals, tangents):
|
||||
with core.new_master(JVPTrace) as master:
|
||||
with core.new_main(JVPTrace) as master:
|
||||
out_primals, out_tangents = yield (master, primals, tangents), {}
|
||||
del master
|
||||
if type(instantiate) is bool:
|
||||
@ -262,7 +262,7 @@ class JVPTrace(Trace):
|
||||
assert call_primitive.multiple_results
|
||||
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
|
||||
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
|
||||
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.main),
|
||||
len(primals), tangent_tree_def)
|
||||
nz_tangents = [type(t) is not Zero for t in tangents]
|
||||
params = dict(params, name=wrap_name(params['name'], 'jvp'))
|
||||
@ -280,7 +280,7 @@ class JVPTrace(Trace):
|
||||
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
|
||||
out, treedef = tree_flatten((primals, tangents))
|
||||
del primals, tangents
|
||||
master = self.master
|
||||
master = self.main
|
||||
def todo(x):
|
||||
primals, tangents = tree_unflatten(treedef, x)
|
||||
trace = JVPTrace(master, core.cur_sublevel())
|
||||
|
@ -60,7 +60,7 @@ def _batch_fun(axis_name, sum_match, in_dims, out_dims_thunk, out_dim_dests,
|
||||
canonicalize_axis(dim, np.ndim(val)) if isinstance(dim, int) else dim
|
||||
for val, dim in zip(in_vals, in_dims)]
|
||||
size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
|
||||
with core.new_master(BatchTrace) as master:
|
||||
with core.new_main(BatchTrace) as master:
|
||||
with core.extend_axis_env(axis_name, size, master):
|
||||
out_vals = yield (master, in_dims,) + in_vals, params
|
||||
del master
|
||||
@ -80,7 +80,7 @@ def batch_fun2(fun : lu.WrappedFun, in_dims):
|
||||
|
||||
@lu.transformation
|
||||
def _batch_fun2(in_dims, *in_vals, **params):
|
||||
with core.new_master(BatchTrace) as master:
|
||||
with core.new_main(BatchTrace) as master:
|
||||
out_vals = yield (master, in_dims,) + in_vals, params
|
||||
del master
|
||||
yield out_vals
|
||||
@ -142,7 +142,7 @@ class BatchTrace(Trace):
|
||||
axis_names = (axis_names,)
|
||||
for i, axis_name in enumerate(axis_names):
|
||||
frame = core.axis_frame(axis_name)
|
||||
if frame.master_trace is not self.master:
|
||||
if frame.tag is not self.main:
|
||||
continue
|
||||
# We run the split_axis rule with tracers, which is supposed to never
|
||||
# mix this axis name with another one. We will handle any invocations
|
||||
@ -168,13 +168,13 @@ class BatchTrace(Trace):
|
||||
if all(bdim is not_mapped for bdim in dims):
|
||||
return call_primitive.bind(f, *vals, **params)
|
||||
else:
|
||||
f, dims_out = batch_subtrace(f, self.master, dims)
|
||||
f, dims_out = batch_subtrace(f, self.main, dims)
|
||||
vals_out = call_primitive.bind(f, *vals, **params)
|
||||
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracers, params):
|
||||
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
||||
master = self.master
|
||||
master = self.main
|
||||
def todo(vals):
|
||||
trace = BatchTrace(master, core.cur_sublevel())
|
||||
return map(partial(BatchTracer, trace), vals, dims)
|
||||
@ -191,14 +191,14 @@ class BatchTrace(Trace):
|
||||
for x, d, mapped_invar in zip(vals, dims, mapped_invars)]
|
||||
dims = tuple(not_mapped if d is not_mapped else max(0, d - mapped_invar)
|
||||
for d, mapped_invar in zip(dims, mapped_invars))
|
||||
f, dims_out = batch_subtrace(f, self.master, dims)
|
||||
f, dims_out = batch_subtrace(f, self.main, dims)
|
||||
vals_out = map_primitive.bind(f, *vals, **params)
|
||||
dims_out = tuple(d + 1 if d is not not_mapped else d for d in dims_out())
|
||||
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]
|
||||
|
||||
def post_process_map(self, call_primitive, out_tracers, params):
|
||||
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
||||
master = self.master
|
||||
master = self.main
|
||||
def todo(vals):
|
||||
trace = BatchTrace(master, core.cur_sublevel())
|
||||
return [BatchTracer(trace, v, d + 1 if d is not not_mapped else d)
|
||||
@ -207,8 +207,8 @@ class BatchTrace(Trace):
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
||||
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
fun, out_dims1 = batch_subtrace(fun, self.master, in_dims)
|
||||
jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.master, in_dims)
|
||||
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
|
||||
jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
|
||||
out_vals = prim.bind(fun, jvp, *in_vals)
|
||||
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
|
||||
if not fst:
|
||||
@ -218,8 +218,8 @@ class BatchTrace(Trace):
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
|
||||
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
fun, out_dims1 = batch_subtrace(fun, self.master, in_dims)
|
||||
fwd, out_dims2 = batch_subtrace(fwd, self.master, in_dims)
|
||||
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
|
||||
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
|
||||
# TODO: Support collectives in custom_vjp?
|
||||
bwd = batch_fun(bwd, out_dims2, in_dims,
|
||||
axis_name='__unused_axis_name', sum_match=True)
|
||||
@ -392,7 +392,7 @@ def batch_jaxpr(jaxpr, size, batched, instantiate):
|
||||
@lu.transformation_with_aux
|
||||
def batched_traceable(size, batched, instantiate, *vals):
|
||||
in_dims = [0 if b else None for b in batched]
|
||||
with core.new_master(BatchTrace) as master:
|
||||
with core.new_main(BatchTrace) as master:
|
||||
trace = BatchTrace(master, core.cur_sublevel())
|
||||
ans = yield map(partial(BatchTracer, trace), vals, in_dims), {}
|
||||
out_tracers = map(trace.full_raise, ans)
|
||||
|
@ -80,7 +80,7 @@ def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
|
||||
logical_env_vals = [logical_env[k] for k in env_keys]
|
||||
# Make padded_env hashable
|
||||
padded_env = (env_keys, padded_env_vals)
|
||||
with core.new_master(MaskTrace) as master:
|
||||
with core.new_main(MaskTrace) as master:
|
||||
fun, out_shapes = mask_subtrace(fun, master, polymorphic_shapes, padded_env)
|
||||
out_vals = fun.call_wrapped(*(logical_env_vals + in_vals))
|
||||
del master
|
||||
@ -421,7 +421,7 @@ class MaskTrace(Trace):
|
||||
logical_env_vals = tuple(logical_env[k] for k in env_keys)
|
||||
# Make padded_env hashable
|
||||
padded_env = (env_keys, padded_env_vals)
|
||||
f, shapes_out = mask_subtrace(f, self.master, shapes, padded_env)
|
||||
f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env)
|
||||
if 'donated_invars' in params:
|
||||
params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
|
||||
params['donated_invars']))
|
||||
@ -430,7 +430,7 @@ class MaskTrace(Trace):
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracers, params):
|
||||
vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
|
||||
master = self.master
|
||||
master = self.main
|
||||
def todo(vals):
|
||||
trace = MaskTrace(master, core.cur_sublevel())
|
||||
return map(partial(MaskTracer, trace), vals, shapes)
|
||||
|
@ -165,7 +165,7 @@ class JaxprTrace(Trace):
|
||||
# We use process_call to handle both call and map primitives.
|
||||
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
|
||||
if not config.omnistaging_enabled:
|
||||
if (self.master.trace_type is StagingJaxprTrace
|
||||
if (self.main.trace_type is StagingJaxprTrace
|
||||
and primitive in staged_out_calls):
|
||||
tracers = map(self.instantiate_const_abstracted, tracers)
|
||||
|
||||
@ -239,7 +239,7 @@ class JaxprTrace(Trace):
|
||||
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
|
||||
out = out_pv_consts + consts
|
||||
del consts, out_pv_consts
|
||||
master = self.master
|
||||
master = self.main
|
||||
|
||||
if primitive.map_primitive:
|
||||
sz = params['axis_size']
|
||||
@ -276,7 +276,7 @@ class JaxprTrace(Trace):
|
||||
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]]):
|
||||
"""Partially evaluate f on a sequence of PartialVals."""
|
||||
in_avals, in_consts = unzip2(pvals)
|
||||
f = trace_to_subjaxpr(f, self.master, False)
|
||||
f = trace_to_subjaxpr(f, self.main, False)
|
||||
f, aux = partial_eval_wrapper(f, tuple(in_avals))
|
||||
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
|
||||
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
||||
@ -288,12 +288,12 @@ class JaxprTrace(Trace):
|
||||
# See comment at top of `JaxprTrace`. This method should be reachable
|
||||
# only when we stage out, and in that case we drop the custom differentiation
|
||||
# rules, because we do not need them.
|
||||
assert self.master.trace_type is StagingJaxprTrace
|
||||
assert self.main.trace_type is StagingJaxprTrace
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
||||
# See comment in the above process_custom_jvp_call method.
|
||||
assert self.master.trace_type is StagingJaxprTrace
|
||||
assert self.main.trace_type is StagingJaxprTrace
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
|
||||
@ -417,7 +417,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
consts = [3, 6] # values for `ka` and `kb` constvars
|
||||
"""
|
||||
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
|
||||
with core.new_master(trace_type, bottom=bottom) as master:
|
||||
with core.new_main(trace_type, bottom=bottom) as master:
|
||||
fun = trace_to_subjaxpr(fun, master, instantiate)
|
||||
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
||||
assert not env
|
||||
@ -427,7 +427,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
|
||||
|
||||
@lu.transformation
|
||||
def trace_to_subjaxpr(master: core.MasterTrace, instantiate: Union[bool, Sequence[bool]],
|
||||
def trace_to_subjaxpr(master: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
|
||||
pvals: Sequence[PartialVal]):
|
||||
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
|
||||
trace = JaxprTrace(master, core.cur_sublevel())
|
||||
@ -710,7 +710,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
typed_jaxpr, in_unknowns, instantiate=False) # type: ignore
|
||||
else:
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
|
||||
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.master.trace_type)
|
||||
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)
|
||||
out_knowns = [not b for b in out_unknowns]
|
||||
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
|
||||
|
||||
@ -846,7 +846,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
msgs = self._progenitor_messages()
|
||||
msg = (f"Abstract tracer value passed to {name} for which a concrete value "
|
||||
"is required.\n"
|
||||
f"While tracing the function {self._trace.master.source_info}, "
|
||||
f"While tracing the function {self._trace.main.source_info}, "
|
||||
"this tracer originated from using JAX operations on these lines:"
|
||||
"\n\n" + "\n\n".join(msgs) + "\n\n"
|
||||
"See the above traceback for where this tracer was encountered.")
|
||||
@ -923,7 +923,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
__slots__ = [] # type: ignore
|
||||
|
||||
@property
|
||||
def frame(self): return self.master.jaxpr_stack[-1] # pytype: disable=attribute-error
|
||||
def frame(self): return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
|
||||
|
||||
def new_arg(self, aval):
|
||||
tracer = DynamicJaxprTracer(self, aval)
|
||||
@ -954,7 +954,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
return var
|
||||
|
||||
def instantiate_const(self, val):
|
||||
if (isinstance(val, Tracer) and val._trace.master is self.master
|
||||
if (isinstance(val, Tracer) and val._trace.main is self.main
|
||||
and val._trace.sublevel == self.sublevel):
|
||||
return val
|
||||
else:
|
||||
@ -974,7 +974,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.master, in_avals)
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
|
||||
if not jaxpr.eqns:
|
||||
return core.eval_jaxpr(jaxpr, consts, *tracers)
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
@ -1000,7 +1000,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
for m, a in zip(params['mapped_invars'], in_avals)]
|
||||
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
|
||||
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.master, reduced_in_avals)
|
||||
f, self.main, reduced_in_avals)
|
||||
out_avals = [core.unmapped_aval(params['axis_size'], a) for a in reduced_out_avals]
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
invars = map(self.getvar, tracers)
|
||||
@ -1022,18 +1022,18 @@ class DynamicJaxprTrace(core.Trace):
|
||||
|
||||
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
|
||||
assert config.omnistaging_enabled
|
||||
with core.new_master(DynamicJaxprTrace, dynamic=True) as master: # type: ignore
|
||||
master.source_info = fun_sourceinfo(fun.f) # type: ignore
|
||||
master.jaxpr_stack = () # type: ignore
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, master, in_avals)
|
||||
del master
|
||||
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f) # type: ignore
|
||||
main.jaxpr_stack = () # type: ignore
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
||||
del main
|
||||
return jaxpr, out_avals, consts
|
||||
|
||||
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, master: core.MasterTrace,
|
||||
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
|
||||
in_avals: Sequence[AbstractValue]):
|
||||
frame = JaxprStackFrame()
|
||||
with extend_jaxpr_stack(master, frame):
|
||||
trace = DynamicJaxprTrace(master, core.cur_sublevel())
|
||||
with extend_jaxpr_stack(main, frame):
|
||||
trace = DynamicJaxprTrace(main, core.cur_sublevel())
|
||||
in_tracers = map(trace.new_arg, in_avals)
|
||||
ans = fun.call_wrapped(*in_tracers)
|
||||
out_tracers = map(trace.full_raise, ans)
|
||||
@ -1041,21 +1041,21 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, master: core.MasterTrace,
|
||||
return jaxpr, out_avals, consts
|
||||
|
||||
@contextlib.contextmanager
|
||||
def extend_jaxpr_stack(master, frame):
|
||||
master.jaxpr_stack = master.jaxpr_stack + (frame,)
|
||||
def extend_jaxpr_stack(main, frame):
|
||||
main.jaxpr_stack = main.jaxpr_stack + (frame,)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
assert frame is master.jaxpr_stack[-1]
|
||||
master.jaxpr_stack = master.jaxpr_stack[:-1]
|
||||
assert frame is main.jaxpr_stack[-1]
|
||||
main.jaxpr_stack = main.jaxpr_stack[:-1]
|
||||
|
||||
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
|
||||
assert config.omnistaging_enabled
|
||||
with core.new_base_master(DynamicJaxprTrace) as master: # type: ignore
|
||||
master.source_info = fun_sourceinfo(fun.f) # type: ignore
|
||||
master.jaxpr_stack = () # type: ignore
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, master, in_avals)
|
||||
del master
|
||||
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f) # type: ignore
|
||||
main.jaxpr_stack = () # type: ignore
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
|
||||
del main
|
||||
return jaxpr, out_avals, consts
|
||||
|
||||
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
|
||||
@ -1064,7 +1064,7 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
|
||||
# custom_derivatives.py, which we work around by adding the EvalTrace.
|
||||
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
|
||||
assert config.omnistaging_enabled
|
||||
with core.new_master(core.EvalTrace, dynamic=True) as _: # type: ignore
|
||||
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
|
||||
return trace_to_jaxpr(fun, in_pvals)
|
||||
|
||||
def fun_sourceinfo(fun):
|
||||
@ -1090,7 +1090,7 @@ def omnistaging_enabler() -> None:
|
||||
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
|
||||
instantiate: Union[bool, Sequence[bool]] = False,
|
||||
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
|
||||
with core.new_master(JaxprTrace) as master:
|
||||
with core.new_main(JaxprTrace) as master:
|
||||
fun = trace_to_subjaxpr(fun, master, instantiate)
|
||||
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
|
||||
assert not env
|
||||
|
@ -490,7 +490,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
|
||||
if config.omnistaging_enabled:
|
||||
partial_eval_jaxpr = pe.partial_eval_jaxpr
|
||||
else:
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type)
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
|
||||
|
||||
cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
|
||||
# Fixpoint computation of unknown carry. Each iteration promotes
|
||||
@ -844,7 +844,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
|
||||
if config.omnistaging_enabled:
|
||||
partial_eval_jaxpr = pe.partial_eval_jaxpr
|
||||
else:
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type)
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
|
||||
|
||||
if index_uk:
|
||||
# When the branch index is unknown, we stage out the whole cond.
|
||||
@ -1517,7 +1517,7 @@ def _prune_zeros(ts):
|
||||
|
||||
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
jaxpr, linear, unroll):
|
||||
if not config.omnistaging_enabled and trace.master.trace_type is pe.StagingJaxprTrace:
|
||||
if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace:
|
||||
params = dict(reverse=reverse, length=length, num_consts=num_consts,
|
||||
num_carry=num_carry, jaxpr=jaxpr, linear=linear,
|
||||
unroll=unroll)
|
||||
@ -1531,7 +1531,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
if config.omnistaging_enabled:
|
||||
partial_eval_jaxpr = pe.partial_eval_jaxpr
|
||||
else:
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type)
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
|
||||
|
||||
# Fixpoint computation of which carry are unknown (not a constant): either
|
||||
# unknown from init, or the carry out is unknown. Each iteration promotes
|
||||
|
Loading…
x
Reference in New Issue
Block a user