applied simple find+sed for 'master' -> 'main' (#4174)

* applied simple find+sed for 'master' -> 'main'

* Rename master->main in JAX API and internals (#4178)

* Started with #4174 
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master

Co-authored-by: George Necula <gcnecula@gmail.com>
This commit is contained in:
Matthew Johnson 2020-08-30 01:16:51 -07:00 committed by GitHub
parent 1a87fd3bc1
commit 6b6789a53b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 133 additions and 133 deletions

View File

@ -85,7 +85,7 @@ templates_path = ['_templates']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
main_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@ -256,7 +256,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'JAX.tex', 'JAX Documentation',
(main_doc, 'JAX.tex', 'JAX Documentation',
'The JAX authors', 'manual'),
]
@ -266,7 +266,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'jax', 'JAX Documentation',
(main_doc, 'jax', 'JAX Documentation',
[author], 1)
]
@ -277,7 +277,7 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'JAX', 'JAX Documentation',
(main_doc, 'JAX', 'JAX Documentation',
author, 'JAX', 'One line description of project.',
'Miscellaneous'),
]

View File

@ -356,15 +356,15 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
class Trace:
__slots__ = ['master', 'level', 'sublevel']
__slots__ = ['main', 'level', 'sublevel']
master: 'MasterTrace'
main: 'MainTrace'
level: int
sublevel: 'Sublevel'
def __init__(self, master: 'MasterTrace', sublevel: 'Sublevel') -> None:
self.master = master
self.level = master.level
def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None:
self.main = main
self.level = main.level
self.sublevel = sublevel
def full_raise(self, val) -> 'Tracer':
@ -372,7 +372,7 @@ class Trace:
return self.pure(val)
level = self.level
sublevel = self.sublevel
if val._trace.master is self.master:
if val._trace.main is self.main:
if val._trace.sublevel == sublevel:
return val
elif val._trace.sublevel < sublevel:
@ -589,7 +589,7 @@ class EvalTrace(Trace):
process_map = process_call
class MasterTrace:
class MainTrace:
level: int
trace_type: Type[Trace]
@ -598,18 +598,18 @@ class MasterTrace:
self.trace_type = trace_type
def __repr__(self) -> str:
return "MasterTrace({},{})".format(self.level, self.trace_type.__name__)
return "MainTrace({},{})".format(self.level, self.trace_type.__name__)
def __hash__(self) -> int:
return hash((self.level, self.trace_type))
def __eq__(self, other: object) -> bool:
return (isinstance(other, MasterTrace) and
return (isinstance(other, MainTrace) and
self.level == other.level and self.trace_type == other.trace_type)
class TraceStack:
upward: List[MasterTrace]
downward: List[MasterTrace]
upward: List[MainTrace]
downward: List[MainTrace]
def __init__(self):
self.upward = []
@ -621,11 +621,11 @@ class TraceStack:
else:
return len(self.upward)
def push(self, master_trace: MasterTrace, bottom: bool) -> None:
def push(self, main_trace: MainTrace, bottom: bool) -> None:
if bottom:
self.downward.append(master_trace)
self.downward.append(main_trace)
else:
self.upward.append(master_trace)
self.upward.append(main_trace)
def pop(self, bottom: bool) -> None:
if bottom:
@ -687,19 +687,19 @@ def cur_sublevel() -> Sublevel:
return thread_local_state.trace_state.substack[-1]
@contextmanager
def new_master(trace_type: Type[Trace], bottom=False) -> Generator[MasterTrace, None, None]:
def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]:
level = thread_local_state.trace_state.trace_stack.next_level(bottom)
master = MasterTrace(level, trace_type)
thread_local_state.trace_state.trace_stack.push(master, bottom)
main = MainTrace(level, trace_type)
thread_local_state.trace_state.trace_stack.push(main, bottom)
try:
yield master
yield main
finally:
thread_local_state.trace_state.trace_stack.pop(bottom)
if check_leaks:
t = ref(master)
del master
t = ref(main)
del main
if t() is not None:
print(thread_local_state.trace_state.trace_stack)
raise Exception('Leaked trace {}'.format(t()))
@ -728,7 +728,7 @@ def full_lower(val):
def find_top_trace(xs) -> Optional[Trace]:
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'), default=None)
return top_trace and type(top_trace)(top_trace.master, cur_sublevel())
return top_trace and type(top_trace)(top_trace.main, cur_sublevel())
@contextmanager
def initial_style_staging():
@ -1116,7 +1116,7 @@ def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'],
ans = max(tracers, key=lambda x: x._trace.level)
else:
break
trace = type(ans._trace)(ans._trace.master, cur_sublevel())
trace = type(ans._trace)(ans._trace.main, cur_sublevel())
outs = map(trace.full_raise, outs)
outs, cur_todo = primitive.post_process(trace, outs, params)
todo.append(cur_todo)
@ -1436,24 +1436,24 @@ axis_frame = None
def omnistaging_enabler() -> None:
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
new_master, reset_trace_state, extend_axis_env, axis_frame, \
new_base_master, eval_context, \
new_base_main, eval_context, \
TraceStack, TraceState
del initial_style_staging
class TraceStack:
stack: List[MasterTrace]
dynamic: MasterTrace
stack: List[MainTrace]
dynamic: MainTrace
def __init__(self):
eval_trace = MasterTrace(0, EvalTrace)
eval_trace = MainTrace(0, EvalTrace)
self.stack = [eval_trace]
self.dynamic = eval_trace
def next_level(self) -> int:
return len(self.stack)
def push(self, master_trace: MasterTrace) -> None:
self.stack.append(master_trace)
def push(self, main_trace: MainTrace) -> None:
self.stack.append(main_trace)
def pop(self) -> None:
self.stack.pop()
@ -1491,8 +1491,8 @@ def omnistaging_enabler() -> None:
"Reset the global trace state and return True if it was already clean."
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
thread_local_state.trace_state.axis_env != [] or
thread_local_state.trace_state.trace_stack.stack != [MasterTrace(0, EvalTrace)] or
thread_local_state.trace_state.trace_stack.dynamic != MasterTrace(0, EvalTrace)):
thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or
thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)):
thread_local_state.trace_state.__init__() # type: ignore
return False
else:
@ -1512,55 +1512,55 @@ def omnistaging_enabler() -> None:
def maybe_new_sublevel(trace):
# dynamic traces run the WrappedFun, so we raise the sublevel for them
dynamic = thread_local_state.trace_state.trace_stack.dynamic
return new_sublevel() if trace.master is dynamic else suppress()
return new_sublevel() if trace.main is dynamic else suppress()
def find_top_trace(xs) -> Trace:
top_master = max((x._trace.master for x in xs if isinstance(x, Tracer)),
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=None, key=attrgetter('level'))
dynamic = thread_local_state.trace_state.trace_stack.dynamic
top_master = (dynamic if top_master is None or dynamic.level > top_master.level
else top_master)
return top_master and top_master.trace_type(top_master, cur_sublevel()) # type: ignore
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
else top_main)
return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore
@contextmanager
def new_master(trace_type: Type[Trace], dynamic: bool = False,
) -> Generator[MasterTrace, None, None]:
def new_main(trace_type: Type[Trace], dynamic: bool = False,
) -> Generator[MainTrace, None, None]:
stack = thread_local_state.trace_state.trace_stack
level = stack.next_level()
master = MasterTrace(level, trace_type)
stack.push(master)
main = MainTrace(level, trace_type)
stack.push(main)
if dynamic:
prev_dynamic, stack.dynamic = stack.dynamic, master
prev_dynamic, stack.dynamic = stack.dynamic, main
try:
yield master
yield main
finally:
thread_local_state.trace_state.trace_stack.pop()
if dynamic:
stack.dynamic = prev_dynamic
if check_leaks:
t = ref(master)
del master
t = ref(main)
del main
if t() is not None:
print(thread_local_state.trace_state.trace_stack)
raise Exception('Leaked trace {}'.format(t()))
@contextmanager
def new_base_master(trace_type: Type[Trace]) -> Generator[MasterTrace, None, None]:
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
stack = thread_local_state.trace_state.trace_stack
master = MasterTrace(0, trace_type)
prev_dynamic, stack.dynamic = stack.dynamic, master
prev_base, stack.stack[0] = stack.stack[0], master
main = MainTrace(0, trace_type)
prev_dynamic, stack.dynamic = stack.dynamic, main
prev_base, stack.stack[0] = stack.stack[0], main
try:
yield master
yield main
finally:
stack.dynamic = prev_dynamic
stack.stack[0] = prev_base
@contextmanager
def eval_context():
with new_base_master(EvalTrace):
with new_base_main(EvalTrace):
yield
def bind(self, *args, **params):

View File

@ -103,11 +103,11 @@ def callback_subtrace(master, *in_vals, **params):
@lu.transformation
def _callback_fun(callback, strip_calls, *in_vals, **params):
with core.new_master(CallbackTrace) as master:
master.callback = callback # NOTE: Is this OK?
master.strip_calls = strip_calls
out_vals = yield (master,) + in_vals, params
del master
with core.new_main(CallbackTrace) as main:
main.callback = callback # NOTE: Is this OK?
main.strip_calls = strip_calls
out_vals = yield (main,) + in_vals, params
del main
yield out_vals
def _check_callable(fun):
@ -144,15 +144,15 @@ class CallbackTrace(Trace):
def process_primitive(self, primitive, tracers, params):
vals_in = [t.val for t in tracers]
vals_out = self.master.callback(primitive, vals_in, params) # type: ignore
vals_out = self.main.callback(primitive, vals_in, params) # type: ignore
if primitive.multiple_results:
return [CallbackTracer(self, val) for val in vals_out]
return CallbackTracer(self, vals_out)
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
if self.master.strip_calls: # type: ignore
if self.main.strip_calls: # type: ignore
return f.call_wrapped(*tracers)
vals_in = [t.val for t in tracers]
f = callback_subtrace(f, self.master)
f = callback_subtrace(f, self.main)
vals_out = call_primitive.bind(f, *vals_in, **params)
return [CallbackTracer(self, val) for val in vals_out]

View File

@ -76,7 +76,7 @@ class DoublingTrace(core.Trace):
assert call_primitive.multiple_results
heads, tails = unzip2((t.head, t.tail) for t in tracers)
nonzero_tails, in_tree_def = tree_flatten(tails)
f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.master),
f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.main),
len(heads), in_tree_def)
name = params.get('name', f.__name__)
new_params = dict(params, name=wrap_name(name, 'doubledouble'),
@ -109,7 +109,7 @@ def screen_nones(num_heads, in_tree_def, *heads_and_tails):
@lu.transformation
def doubling_transform(*args):
with core.new_master(DoublingTrace) as master:
with core.new_main(DoublingTrace) as master:
trace = DoublingTrace(master, core.cur_sublevel())
in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args]
outputs = yield in_tracers, {}

View File

@ -212,7 +212,7 @@ def convert(fun, with_gradient=False):
def _interpret_fun(fun: lu.WrappedFun,
in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]:
with core.new_master(TensorFlowTrace) as master:
with core.new_main(TensorFlowTrace) as master:
fun = _interpret_subtrace(fun, master)
out_vals: Sequence[TfValOrUnit] = fun.call_wrapped(*in_vals)
del master
@ -220,7 +220,7 @@ def _interpret_fun(fun: lu.WrappedFun,
@lu.transformation
def _interpret_subtrace(master: core.MasterTrace, *in_vals: TfValOrUnit):
def _interpret_subtrace(master: core.MainTrace, *in_vals: TfValOrUnit):
trace = TensorFlowTrace(master, core.cur_sublevel())
in_tracers = tuple(TensorFlowTracer(trace, val) for val in in_vals)
outs = yield in_tracers, {} # type: Sequence[TfValOrUnit]
@ -332,7 +332,7 @@ class TensorFlowTrace(core.Trace):
tracers: Sequence[TensorFlowTracer], params):
assert call_primitive.multiple_results
vals: Sequence[TfValOrUnit] = [t.val for t in tracers]
f = _interpret_subtrace(f, self.master)
f = _interpret_subtrace(f, self.main)
vals_out: Sequence[TfValOrUnit] = f.call_wrapped(*vals)
return [TensorFlowTracer(self, v) for v in vals_out]
@ -342,7 +342,7 @@ class TensorFlowTrace(core.Trace):
# (out_tracers) include TensorFlowTracer that were not passed through
# its arguments (captured from the environment).
vals = tuple(t.val for t in out_tracers)
master = self.master
master = self.main
def todo(vals: Sequence[TfValOrUnit]):
trace = TensorFlowTrace(master, core.cur_sublevel())
return map(functools.partial(TensorFlowTracer, trace), vals)

View File

@ -59,10 +59,10 @@ def jet(fun, primals, series):
@lu.transformation
def jet_fun(order, primals, series):
with core.new_master(JetTrace) as master:
master.order = order
out_primals, out_terms = yield (master, primals, series), {}
del master
with core.new_main(JetTrace) as main:
main.order = order
out_primals, out_terms = yield (main, primals, series), {}
del main
out_terms = [[np.zeros_like(p)] * order if s is zero_series else s
for p, s in zip(out_primals, out_terms)]
yield out_primals, out_terms
@ -115,7 +115,7 @@ class JetTrace(core.Trace):
return JetTracer(self, val.primal, val.terms)
def process_primitive(self, primitive, tracers, params):
order = self.master.order # pytype: disable=attribute-error
order = self.main.order # pytype: disable=attribute-error
primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
series_in = [[zero_term] * order if s is zero_series else s
for s in series_in]
@ -133,7 +133,7 @@ class JetTrace(core.Trace):
def process_call(self, call_primitive, f, tracers, params):
primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
f_jet, out_tree_def = traceable(jet_subtrace(f, self.master), in_tree_def)
f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
update_params = call_param_updaters.get(call_primitive)
new_params = (update_params(params, len(primals_and_series))
if update_params else params)
@ -145,7 +145,7 @@ class JetTrace(core.Trace):
primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
out, treedef = tree_flatten((primals, series))
del primals, series
master = self.master
master = self.main
def todo(x):
primals, series = tree_unflatten(treedef, x)
trace = JetTrace(master, core.cur_sublevel())

View File

@ -277,22 +277,22 @@ class Scope(object):
def start_subtrace(self):
"""Starts a nested trace, returns the Trace object."""
# TODO: This follows the __enter__ part of core.new_master.
# TODO: This follows the __enter__ part of core.new_main.
if config.omnistaging_enabled:
level = core.thread_local_state.trace_state.trace_stack.next_level()
master = core.MasterTrace(level, pe.JaxprTrace)
master = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(master)
self._count_subtraces += 1
return pe.JaxprTrace(master, core.cur_sublevel())
else:
level = core.thread_local_state.trace_state.trace_stack.next_level(False)
master = core.MasterTrace(level, pe.JaxprTrace)
master = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(master, False)
self._count_subtraces += 1
return pe.JaxprTrace(master, core.cur_sublevel())
def end_subtrace(self):
# TODO: This follows the __exit__ part of core.new_master
# TODO: This follows the __exit__ part of core.new_main
if config.omnistaging_enabled:
core.thread_local_state.trace_state.trace_stack.pop()
else:

View File

@ -46,7 +46,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
@lu.transformation
def jvpfun(instantiate, primals, tangents):
with core.new_master(JVPTrace) as master:
with core.new_main(JVPTrace) as master:
out_primals, out_tangents = yield (master, primals, tangents), {}
del master
if type(instantiate) is bool:
@ -262,7 +262,7 @@ class JVPTrace(Trace):
assert call_primitive.multiple_results
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master),
f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.main),
len(primals), tangent_tree_def)
nz_tangents = [type(t) is not Zero for t in tangents]
params = dict(params, name=wrap_name(params['name'], 'jvp'))
@ -280,7 +280,7 @@ class JVPTrace(Trace):
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
out, treedef = tree_flatten((primals, tangents))
del primals, tangents
master = self.master
master = self.main
def todo(x):
primals, tangents = tree_unflatten(treedef, x)
trace = JVPTrace(master, core.cur_sublevel())

View File

@ -60,7 +60,7 @@ def _batch_fun(axis_name, sum_match, in_dims, out_dims_thunk, out_dim_dests,
canonicalize_axis(dim, np.ndim(val)) if isinstance(dim, int) else dim
for val, dim in zip(in_vals, in_dims)]
size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
with core.new_master(BatchTrace) as master:
with core.new_main(BatchTrace) as master:
with core.extend_axis_env(axis_name, size, master):
out_vals = yield (master, in_dims,) + in_vals, params
del master
@ -80,7 +80,7 @@ def batch_fun2(fun : lu.WrappedFun, in_dims):
@lu.transformation
def _batch_fun2(in_dims, *in_vals, **params):
with core.new_master(BatchTrace) as master:
with core.new_main(BatchTrace) as master:
out_vals = yield (master, in_dims,) + in_vals, params
del master
yield out_vals
@ -142,7 +142,7 @@ class BatchTrace(Trace):
axis_names = (axis_names,)
for i, axis_name in enumerate(axis_names):
frame = core.axis_frame(axis_name)
if frame.master_trace is not self.master:
if frame.tag is not self.main:
continue
# We run the split_axis rule with tracers, which is supposed to never
# mix this axis name with another one. We will handle any invocations
@ -168,13 +168,13 @@ class BatchTrace(Trace):
if all(bdim is not_mapped for bdim in dims):
return call_primitive.bind(f, *vals, **params)
else:
f, dims_out = batch_subtrace(f, self.master, dims)
f, dims_out = batch_subtrace(f, self.main, dims)
vals_out = call_primitive.bind(f, *vals, **params)
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
def post_process_call(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
master = self.master
master = self.main
def todo(vals):
trace = BatchTrace(master, core.cur_sublevel())
return map(partial(BatchTracer, trace), vals, dims)
@ -191,14 +191,14 @@ class BatchTrace(Trace):
for x, d, mapped_invar in zip(vals, dims, mapped_invars)]
dims = tuple(not_mapped if d is not_mapped else max(0, d - mapped_invar)
for d, mapped_invar in zip(dims, mapped_invars))
f, dims_out = batch_subtrace(f, self.master, dims)
f, dims_out = batch_subtrace(f, self.main, dims)
vals_out = map_primitive.bind(f, *vals, **params)
dims_out = tuple(d + 1 if d is not not_mapped else d for d in dims_out())
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]
def post_process_map(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
master = self.master
master = self.main
def todo(vals):
trace = BatchTrace(master, core.cur_sublevel())
return [BatchTracer(trace, v, d + 1 if d is not not_mapped else d)
@ -207,8 +207,8 @@ class BatchTrace(Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
fun, out_dims1 = batch_subtrace(fun, self.master, in_dims)
jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.master, in_dims)
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
out_vals = prim.bind(fun, jvp, *in_vals)
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
if not fst:
@ -218,8 +218,8 @@ class BatchTrace(Trace):
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
fun, out_dims1 = batch_subtrace(fun, self.master, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.master, in_dims)
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
# TODO: Support collectives in custom_vjp?
bwd = batch_fun(bwd, out_dims2, in_dims,
axis_name='__unused_axis_name', sum_match=True)
@ -392,7 +392,7 @@ def batch_jaxpr(jaxpr, size, batched, instantiate):
@lu.transformation_with_aux
def batched_traceable(size, batched, instantiate, *vals):
in_dims = [0 if b else None for b in batched]
with core.new_master(BatchTrace) as master:
with core.new_main(BatchTrace) as master:
trace = BatchTrace(master, core.cur_sublevel())
ans = yield map(partial(BatchTracer, trace), vals, in_dims), {}
out_tracers = map(trace.full_raise, ans)

View File

@ -80,7 +80,7 @@ def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
logical_env_vals = [logical_env[k] for k in env_keys]
# Make padded_env hashable
padded_env = (env_keys, padded_env_vals)
with core.new_master(MaskTrace) as master:
with core.new_main(MaskTrace) as master:
fun, out_shapes = mask_subtrace(fun, master, polymorphic_shapes, padded_env)
out_vals = fun.call_wrapped(*(logical_env_vals + in_vals))
del master
@ -421,7 +421,7 @@ class MaskTrace(Trace):
logical_env_vals = tuple(logical_env[k] for k in env_keys)
# Make padded_env hashable
padded_env = (env_keys, padded_env_vals)
f, shapes_out = mask_subtrace(f, self.master, shapes, padded_env)
f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env)
if 'donated_invars' in params:
params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
params['donated_invars']))
@ -430,7 +430,7 @@ class MaskTrace(Trace):
def post_process_call(self, call_primitive, out_tracers, params):
vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
master = self.master
master = self.main
def todo(vals):
trace = MaskTrace(master, core.cur_sublevel())
return map(partial(MaskTracer, trace), vals, shapes)

View File

@ -165,7 +165,7 @@ class JaxprTrace(Trace):
# We use process_call to handle both call and map primitives.
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
if not config.omnistaging_enabled:
if (self.master.trace_type is StagingJaxprTrace
if (self.main.trace_type is StagingJaxprTrace
and primitive in staged_out_calls):
tracers = map(self.instantiate_const_abstracted, tracers)
@ -239,7 +239,7 @@ class JaxprTrace(Trace):
out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
out = out_pv_consts + consts
del consts, out_pv_consts
master = self.master
master = self.main
if primitive.map_primitive:
sz = params['axis_size']
@ -276,7 +276,7 @@ class JaxprTrace(Trace):
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]]):
"""Partially evaluate f on a sequence of PartialVals."""
in_avals, in_consts = unzip2(pvals)
f = trace_to_subjaxpr(f, self.master, False)
f = trace_to_subjaxpr(f, self.main, False)
f, aux = partial_eval_wrapper(f, tuple(in_avals))
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
@ -288,12 +288,12 @@ class JaxprTrace(Trace):
# See comment at top of `JaxprTrace`. This method should be reachable
# only when we stage out, and in that case we drop the custom differentiation
# rules, because we do not need them.
assert self.master.trace_type is StagingJaxprTrace
assert self.main.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
# See comment in the above process_custom_jvp_call method.
assert self.master.trace_type is StagingJaxprTrace
assert self.main.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)
@ -417,7 +417,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
consts = [3, 6] # values for `ka` and `kb` constvars
"""
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
with core.new_master(trace_type, bottom=bottom) as master:
with core.new_main(trace_type, bottom=bottom) as master:
fun = trace_to_subjaxpr(fun, master, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
@ -427,7 +427,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
@lu.transformation
def trace_to_subjaxpr(master: core.MasterTrace, instantiate: Union[bool, Sequence[bool]],
def trace_to_subjaxpr(master: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
pvals: Sequence[PartialVal]):
assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals
trace = JaxprTrace(master, core.cur_sublevel())
@ -710,7 +710,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
typed_jaxpr, in_unknowns, instantiate=False) # type: ignore
else:
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.master.trace_type)
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)
out_knowns = [not b for b in out_unknowns]
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
@ -846,7 +846,7 @@ class DynamicJaxprTracer(core.Tracer):
msgs = self._progenitor_messages()
msg = (f"Abstract tracer value passed to {name} for which a concrete value "
"is required.\n"
f"While tracing the function {self._trace.master.source_info}, "
f"While tracing the function {self._trace.main.source_info}, "
"this tracer originated from using JAX operations on these lines:"
"\n\n" + "\n\n".join(msgs) + "\n\n"
"See the above traceback for where this tracer was encountered.")
@ -923,7 +923,7 @@ class DynamicJaxprTrace(core.Trace):
__slots__ = [] # type: ignore
@property
def frame(self): return self.master.jaxpr_stack[-1] # pytype: disable=attribute-error
def frame(self): return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
def new_arg(self, aval):
tracer = DynamicJaxprTracer(self, aval)
@ -954,7 +954,7 @@ class DynamicJaxprTrace(core.Trace):
return var
def instantiate_const(self, val):
if (isinstance(val, Tracer) and val._trace.master is self.master
if (isinstance(val, Tracer) and val._trace.main is self.main
and val._trace.sublevel == self.sublevel):
return val
else:
@ -974,7 +974,7 @@ class DynamicJaxprTrace(core.Trace):
def process_call(self, call_primitive, f, tracers, params):
in_avals = [t.aval for t in tracers]
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.master, in_avals)
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
if not jaxpr.eqns:
return core.eval_jaxpr(jaxpr, consts, *tracers)
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
@ -1000,7 +1000,7 @@ class DynamicJaxprTrace(core.Trace):
for m, a in zip(params['mapped_invars'], in_avals)]
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.master, reduced_in_avals)
f, self.main, reduced_in_avals)
out_avals = [core.unmapped_aval(params['axis_size'], a) for a in reduced_out_avals]
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
@ -1022,18 +1022,18 @@ class DynamicJaxprTrace(core.Trace):
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
assert config.omnistaging_enabled
with core.new_master(DynamicJaxprTrace, dynamic=True) as master: # type: ignore
master.source_info = fun_sourceinfo(fun.f) # type: ignore
master.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, master, in_avals)
del master
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f) # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del main
return jaxpr, out_avals, consts
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, master: core.MasterTrace,
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue]):
frame = JaxprStackFrame()
with extend_jaxpr_stack(master, frame):
trace = DynamicJaxprTrace(master, core.cur_sublevel())
with extend_jaxpr_stack(main, frame):
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = map(trace.new_arg, in_avals)
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.full_raise, ans)
@ -1041,21 +1041,21 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, master: core.MasterTrace,
return jaxpr, out_avals, consts
@contextlib.contextmanager
def extend_jaxpr_stack(master, frame):
master.jaxpr_stack = master.jaxpr_stack + (frame,)
def extend_jaxpr_stack(main, frame):
main.jaxpr_stack = main.jaxpr_stack + (frame,)
try:
yield
finally:
assert frame is master.jaxpr_stack[-1]
master.jaxpr_stack = master.jaxpr_stack[:-1]
assert frame is main.jaxpr_stack[-1]
main.jaxpr_stack = main.jaxpr_stack[:-1]
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
assert config.omnistaging_enabled
with core.new_base_master(DynamicJaxprTrace) as master: # type: ignore
master.source_info = fun_sourceinfo(fun.f) # type: ignore
master.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, master, in_avals)
del master
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f) # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del main
return jaxpr, out_avals, consts
def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]):
@ -1064,7 +1064,7 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
# custom_derivatives.py, which we work around by adding the EvalTrace.
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
assert config.omnistaging_enabled
with core.new_master(core.EvalTrace, dynamic=True) as _: # type: ignore
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
return trace_to_jaxpr(fun, in_pvals)
def fun_sourceinfo(fun):
@ -1090,7 +1090,7 @@ def omnistaging_enabler() -> None:
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
with core.new_master(JaxprTrace) as master:
with core.new_main(JaxprTrace) as master:
fun = trace_to_subjaxpr(fun, master, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env

View File

@ -490,7 +490,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
if config.omnistaging_enabled:
partial_eval_jaxpr = pe.partial_eval_jaxpr
else:
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type)
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
# Fixpoint computation of unknown carry. Each iteration promotes
@ -844,7 +844,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
if config.omnistaging_enabled:
partial_eval_jaxpr = pe.partial_eval_jaxpr
else:
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type)
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
if index_uk:
# When the branch index is unknown, we stage out the whole cond.
@ -1517,7 +1517,7 @@ def _prune_zeros(ts):
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
if not config.omnistaging_enabled and trace.master.trace_type is pe.StagingJaxprTrace:
if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace:
params = dict(reverse=reverse, length=length, num_consts=num_consts,
num_carry=num_carry, jaxpr=jaxpr, linear=linear,
unroll=unroll)
@ -1531,7 +1531,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
if config.omnistaging_enabled:
partial_eval_jaxpr = pe.partial_eval_jaxpr
else:
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type)
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
# Fixpoint computation of which carry are unknown (not a constant): either
# unknown from init, or the carry out is unknown. Each iteration promotes