diff --git a/docs/conf.py b/docs/conf.py index 5b9cf01c6..de7eb7175 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -85,7 +85,7 @@ templates_path = ['_templates'] source_suffix = '.rst' # The master toctree document. -master_doc = 'index' +master_doc = main_doc = 'index' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -256,7 +256,7 @@ latex_elements = { # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'JAX.tex', 'JAX Documentation', + (main_doc, 'JAX.tex', 'JAX Documentation', 'The JAX authors', 'manual'), ] @@ -266,7 +266,7 @@ latex_documents = [ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'jax', 'JAX Documentation', + (main_doc, 'jax', 'JAX Documentation', [author], 1) ] @@ -277,7 +277,7 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'JAX', 'JAX Documentation', + (main_doc, 'JAX', 'JAX Documentation', author, 'JAX', 'One line description of project.', 'Miscellaneous'), ] diff --git a/jax/core.py b/jax/core.py index f919f5a3e..466a014d8 100644 --- a/jax/core.py +++ b/jax/core.py @@ -356,15 +356,15 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args): class Trace: - __slots__ = ['master', 'level', 'sublevel'] + __slots__ = ['main', 'level', 'sublevel'] - master: 'MasterTrace' + main: 'MainTrace' level: int sublevel: 'Sublevel' - def __init__(self, master: 'MasterTrace', sublevel: 'Sublevel') -> None: - self.master = master - self.level = master.level + def __init__(self, main: 'MainTrace', sublevel: 'Sublevel') -> None: + self.main = main + self.level = main.level self.sublevel = sublevel def full_raise(self, val) -> 'Tracer': @@ -372,7 +372,7 @@ class Trace: return self.pure(val) level = self.level sublevel = self.sublevel - if
val._trace.master is self.master: + if val._trace.main is self.main: if val._trace.sublevel == sublevel: return val elif val._trace.sublevel < sublevel: @@ -589,7 +589,7 @@ class EvalTrace(Trace): process_map = process_call -class MasterTrace: +class MainTrace: level: int trace_type: Type[Trace] @@ -598,18 +598,18 @@ class MasterTrace: self.trace_type = trace_type def __repr__(self) -> str: - return "MasterTrace({},{})".format(self.level, self.trace_type.__name__) + return "MainTrace({},{})".format(self.level, self.trace_type.__name__) def __hash__(self) -> int: return hash((self.level, self.trace_type)) def __eq__(self, other: object) -> bool: - return (isinstance(other, MasterTrace) and + return (isinstance(other, MainTrace) and self.level == other.level and self.trace_type == other.trace_type) class TraceStack: - upward: List[MasterTrace] - downward: List[MasterTrace] + upward: List[MainTrace] + downward: List[MainTrace] def __init__(self): self.upward = [] @@ -621,11 +621,11 @@ class TraceStack: else: return len(self.upward) - def push(self, master_trace: MasterTrace, bottom: bool) -> None: + def push(self, main_trace: MainTrace, bottom: bool) -> None: if bottom: - self.downward.append(master_trace) + self.downward.append(main_trace) else: - self.upward.append(master_trace) + self.upward.append(main_trace) def pop(self, bottom: bool) -> None: if bottom: @@ -687,19 +687,19 @@ def cur_sublevel() -> Sublevel: return thread_local_state.trace_state.substack[-1] @contextmanager -def new_master(trace_type: Type[Trace], bottom=False) -> Generator[MasterTrace, None, None]: +def new_main(trace_type: Type[Trace], bottom=False) -> Generator[MainTrace, None, None]: level = thread_local_state.trace_state.trace_stack.next_level(bottom) - master = MasterTrace(level, trace_type) - thread_local_state.trace_state.trace_stack.push(master, bottom) + main = MainTrace(level, trace_type) + thread_local_state.trace_state.trace_stack.push(main, bottom) try: - yield master + yield main 
finally: thread_local_state.trace_state.trace_stack.pop(bottom) if check_leaks: - t = ref(master) - del master + t = ref(main) + del main if t() is not None: print(thread_local_state.trace_state.trace_stack) raise Exception('Leaked trace {}'.format(t())) @@ -728,7 +728,7 @@ def full_lower(val): def find_top_trace(xs) -> Optional[Trace]: top_trace = max((x._trace for x in xs if isinstance(x, Tracer)), key=attrgetter('level'), default=None) - return top_trace and type(top_trace)(top_trace.master, cur_sublevel()) + return top_trace and type(top_trace)(top_trace.main, cur_sublevel()) @contextmanager def initial_style_staging(): @@ -1116,7 +1116,7 @@ def process_env_traces(primitive: Union['CallPrimitive', 'MapPrimitive'], ans = max(tracers, key=lambda x: x._trace.level) else: break - trace = type(ans._trace)(ans._trace.master, cur_sublevel()) + trace = type(ans._trace)(ans._trace.main, cur_sublevel()) outs = map(trace.full_raise, outs) outs, cur_todo = primitive.post_process(trace, outs, params) todo.append(cur_todo) @@ -1436,24 +1436,24 @@ axis_frame = None def omnistaging_enabler() -> None: global thread_local_state, call_bind, find_top_trace, initial_style_staging, \ - new_master, reset_trace_state, extend_axis_env, axis_frame, \ + new_main, reset_trace_state, extend_axis_env, axis_frame, \ - new_base_master, eval_context, \ + new_base_main, eval_context, \ TraceStack, TraceState del initial_style_staging class TraceStack: - stack: List[MasterTrace] - dynamic: MasterTrace + stack: List[MainTrace] + dynamic: MainTrace def __init__(self): - eval_trace = MasterTrace(0, EvalTrace) + eval_trace = MainTrace(0, EvalTrace) self.stack = [eval_trace] self.dynamic = eval_trace def next_level(self) -> int: return len(self.stack) - def push(self, master_trace: MasterTrace) -> None: - self.stack.append(master_trace) + def push(self, main_trace: MainTrace) -> None: + self.stack.append(main_trace) def pop(self) -> None: self.stack.pop() @@ -1491,8 +1491,8 @@ def omnistaging_enabler() -> None: "Reset the global trace state and return True if it
was already clean." if (thread_local_state.trace_state.substack != [Sublevel(0)] or thread_local_state.trace_state.axis_env != [] or - thread_local_state.trace_state.trace_stack.stack != [MasterTrace(0, EvalTrace)] or - thread_local_state.trace_state.trace_stack.dynamic != MasterTrace(0, EvalTrace)): + thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or + thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)): thread_local_state.trace_state.__init__() # type: ignore return False else: @@ -1512,55 +1512,55 @@ def omnistaging_enabler() -> None: def maybe_new_sublevel(trace): # dynamic traces run the WrappedFun, so we raise the sublevel for them dynamic = thread_local_state.trace_state.trace_stack.dynamic - return new_sublevel() if trace.master is dynamic else suppress() + return new_sublevel() if trace.main is dynamic else suppress() def find_top_trace(xs) -> Trace: - top_master = max((x._trace.master for x in xs if isinstance(x, Tracer)), + top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)), default=None, key=attrgetter('level')) dynamic = thread_local_state.trace_state.trace_stack.dynamic - top_master = (dynamic if top_master is None or dynamic.level > top_master.level - else top_master) - return top_master and top_master.trace_type(top_master, cur_sublevel()) # type: ignore + top_main = (dynamic if top_main is None or dynamic.level > top_main.level + else top_main) + return top_main and top_main.trace_type(top_main, cur_sublevel()) # type: ignore @contextmanager - def new_master(trace_type: Type[Trace], dynamic: bool = False, - ) -> Generator[MasterTrace, None, None]: + def new_main(trace_type: Type[Trace], dynamic: bool = False, + ) -> Generator[MainTrace, None, None]: stack = thread_local_state.trace_state.trace_stack level = stack.next_level() - master = MasterTrace(level, trace_type) - stack.push(master) + main = MainTrace(level, trace_type) + stack.push(main) if dynamic: - prev_dynamic, 
stack.dynamic = stack.dynamic, master + prev_dynamic, stack.dynamic = stack.dynamic, main try: - yield master + yield main finally: thread_local_state.trace_state.trace_stack.pop() if dynamic: stack.dynamic = prev_dynamic if check_leaks: - t = ref(master) - del master + t = ref(main) + del main if t() is not None: print(thread_local_state.trace_state.trace_stack) raise Exception('Leaked trace {}'.format(t())) @contextmanager - def new_base_master(trace_type: Type[Trace]) -> Generator[MasterTrace, None, None]: + def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]: stack = thread_local_state.trace_state.trace_stack - master = MasterTrace(0, trace_type) - prev_dynamic, stack.dynamic = stack.dynamic, master - prev_base, stack.stack[0] = stack.stack[0], master + main = MainTrace(0, trace_type) + prev_dynamic, stack.dynamic = stack.dynamic, main + prev_base, stack.stack[0] = stack.stack[0], main try: - yield master + yield main finally: stack.dynamic = prev_dynamic stack.stack[0] = prev_base @contextmanager def eval_context(): - with new_base_master(EvalTrace): + with new_base_main(EvalTrace): yield def bind(self, *args, **params): diff --git a/jax/experimental/callback.py b/jax/experimental/callback.py index e1d948dd4..82704cf7e 100644 --- a/jax/experimental/callback.py +++ b/jax/experimental/callback.py @@ -103,11 +103,11 @@ def callback_subtrace(master, *in_vals, **params): @lu.transformation def _callback_fun(callback, strip_calls, *in_vals, **params): - with core.new_master(CallbackTrace) as master: - master.callback = callback # NOTE: Is this OK? - master.strip_calls = strip_calls - out_vals = yield (master,) + in_vals, params - del master + with core.new_main(CallbackTrace) as main: + main.callback = callback # NOTE: Is this OK? 
+ main.strip_calls = strip_calls + out_vals = yield (main,) + in_vals, params + del main yield out_vals def _check_callable(fun): @@ -144,15 +144,15 @@ class CallbackTrace(Trace): def process_primitive(self, primitive, tracers, params): vals_in = [t.val for t in tracers] - vals_out = self.master.callback(primitive, vals_in, params) # type: ignore + vals_out = self.main.callback(primitive, vals_in, params) # type: ignore if primitive.multiple_results: return [CallbackTracer(self, val) for val in vals_out] return CallbackTracer(self, vals_out) def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): - if self.master.strip_calls: # type: ignore + if self.main.strip_calls: # type: ignore return f.call_wrapped(*tracers) vals_in = [t.val for t in tracers] - f = callback_subtrace(f, self.master) + f = callback_subtrace(f, self.main) vals_out = call_primitive.bind(f, *vals_in, **params) return [CallbackTracer(self, val) for val in vals_out] diff --git a/jax/experimental/doubledouble.py b/jax/experimental/doubledouble.py index 770844270..be8e969dd 100644 --- a/jax/experimental/doubledouble.py +++ b/jax/experimental/doubledouble.py @@ -76,7 +76,7 @@ class DoublingTrace(core.Trace): assert call_primitive.multiple_results heads, tails = unzip2((t.head, t.tail) for t in tracers) nonzero_tails, in_tree_def = tree_flatten(tails) - f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.master), + f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.main), len(heads), in_tree_def) name = params.get('name', f.__name__) new_params = dict(params, name=wrap_name(name, 'doubledouble'), @@ -109,7 +109,7 @@ def screen_nones(num_heads, in_tree_def, *heads_and_tails): @lu.transformation def doubling_transform(*args): - with core.new_master(DoublingTrace) as master: + with core.new_main(DoublingTrace) as master: trace = DoublingTrace(master, core.cur_sublevel()) in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args] outputs = yield 
in_tracers, {} diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 5d6774cd3..6437c78c5 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -212,7 +212,7 @@ def convert(fun, with_gradient=False): def _interpret_fun(fun: lu.WrappedFun, in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]: - with core.new_master(TensorFlowTrace) as master: + with core.new_main(TensorFlowTrace) as master: fun = _interpret_subtrace(fun, master) out_vals: Sequence[TfValOrUnit] = fun.call_wrapped(*in_vals) del master @@ -220,7 +220,7 @@ def _interpret_fun(fun: lu.WrappedFun, @lu.transformation -def _interpret_subtrace(master: core.MasterTrace, *in_vals: TfValOrUnit): +def _interpret_subtrace(master: core.MainTrace, *in_vals: TfValOrUnit): trace = TensorFlowTrace(master, core.cur_sublevel()) in_tracers = tuple(TensorFlowTracer(trace, val) for val in in_vals) outs = yield in_tracers, {} # type: Sequence[TfValOrUnit] @@ -332,7 +332,7 @@ class TensorFlowTrace(core.Trace): tracers: Sequence[TensorFlowTracer], params): assert call_primitive.multiple_results vals: Sequence[TfValOrUnit] = [t.val for t in tracers] - f = _interpret_subtrace(f, self.master) + f = _interpret_subtrace(f, self.main) vals_out: Sequence[TfValOrUnit] = f.call_wrapped(*vals) return [TensorFlowTracer(self, v) for v in vals_out] @@ -342,7 +342,7 @@ class TensorFlowTrace(core.Trace): # (out_tracers) include TensorFlowTracer that were not passed through # its arguments (captured from the environment). 
vals = tuple(t.val for t in out_tracers) - master = self.master + master = self.main def todo(vals: Sequence[TfValOrUnit]): trace = TensorFlowTrace(master, core.cur_sublevel()) return map(functools.partial(TensorFlowTracer, trace), vals) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 48e8cfdbd..410dfd548 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -59,10 +59,10 @@ def jet(fun, primals, series): @lu.transformation def jet_fun(order, primals, series): - with core.new_master(JetTrace) as master: - master.order = order - out_primals, out_terms = yield (master, primals, series), {} - del master + with core.new_main(JetTrace) as main: + main.order = order + out_primals, out_terms = yield (main, primals, series), {} + del main out_terms = [[np.zeros_like(p)] * order if s is zero_series else s for p, s in zip(out_primals, out_terms)] yield out_primals, out_terms @@ -115,7 +115,7 @@ class JetTrace(core.Trace): return JetTracer(self, val.primal, val.terms) def process_primitive(self, primitive, tracers, params): - order = self.master.order # pytype: disable=attribute-error + order = self.main.order # pytype: disable=attribute-error primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers) series_in = [[zero_term] * order if s is zero_series else s for s in series_in] @@ -133,7 +133,7 @@ class JetTrace(core.Trace): def process_call(self, call_primitive, f, tracers, params): primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers) primals_and_series, in_tree_def = tree_flatten((primals_in, series_in)) - f_jet, out_tree_def = traceable(jet_subtrace(f, self.master), in_tree_def) + f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def) update_params = call_param_updaters.get(call_primitive) new_params = (update_params(params, len(primals_and_series)) if update_params else params) @@ -145,7 +145,7 @@ class JetTrace(core.Trace): primals, series = unzip2((t.primal, t.terms) for t in out_tracers) 
out, treedef = tree_flatten((primals, series)) del primals, series - master = self.master + master = self.main def todo(x): primals, series = tree_unflatten(treedef, x) trace = JetTrace(master, core.cur_sublevel()) diff --git a/jax/experimental/loops.py b/jax/experimental/loops.py index 98d9aa91b..f4fb91061 100644 --- a/jax/experimental/loops.py +++ b/jax/experimental/loops.py @@ -277,22 +277,22 @@ class Scope(object): def start_subtrace(self): """Starts a nested trace, returns the Trace object.""" - # TODO: This follows the __enter__ part of core.new_master. + # TODO: This follows the __enter__ part of core.new_main. if config.omnistaging_enabled: level = core.thread_local_state.trace_state.trace_stack.next_level() - master = core.MasterTrace(level, pe.JaxprTrace) + master = core.MainTrace(level, pe.JaxprTrace) core.thread_local_state.trace_state.trace_stack.push(master) self._count_subtraces += 1 return pe.JaxprTrace(master, core.cur_sublevel()) else: level = core.thread_local_state.trace_state.trace_stack.next_level(False) - master = core.MasterTrace(level, pe.JaxprTrace) + master = core.MainTrace(level, pe.JaxprTrace) core.thread_local_state.trace_state.trace_stack.push(master, False) self._count_subtraces += 1 return pe.JaxprTrace(master, core.cur_sublevel()) def end_subtrace(self): - # TODO: This follows the __exit__ part of core.new_master + # TODO: This follows the __exit__ part of core.new_main if config.omnistaging_enabled: core.thread_local_state.trace_state.trace_stack.pop() else: diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 24070d2b9..c3676f2b4 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -46,7 +46,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any: @lu.transformation def jvpfun(instantiate, primals, tangents): - with core.new_master(JVPTrace) as master: + with core.new_main(JVPTrace) as master: out_primals, out_tangents = yield (master, primals, tangents), {} del master if 
type(instantiate) is bool: @@ -262,7 +262,7 @@ class JVPTrace(Trace): assert call_primitive.multiple_results primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) nonzero_tangents, tangent_tree_def = tree_flatten(tangents) - f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master), + f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.main), len(primals), tangent_tree_def) nz_tangents = [type(t) is not Zero for t in tangents] params = dict(params, name=wrap_name(params['name'], 'jvp')) @@ -280,7 +280,7 @@ class JVPTrace(Trace): primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) out, treedef = tree_flatten((primals, tangents)) del primals, tangents - master = self.master + master = self.main def todo(x): primals, tangents = tree_unflatten(treedef, x) trace = JVPTrace(master, core.cur_sublevel()) diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 0ba2489f2..06e7217bc 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -60,7 +60,7 @@ def _batch_fun(axis_name, sum_match, in_dims, out_dims_thunk, out_dim_dests, canonicalize_axis(dim, np.ndim(val)) if isinstance(dim, int) else dim for val, dim in zip(in_vals, in_dims)] size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped} - with core.new_master(BatchTrace) as master: + with core.new_main(BatchTrace) as master: with core.extend_axis_env(axis_name, size, master): out_vals = yield (master, in_dims,) + in_vals, params del master @@ -80,7 +80,7 @@ def batch_fun2(fun : lu.WrappedFun, in_dims): @lu.transformation def _batch_fun2(in_dims, *in_vals, **params): - with core.new_master(BatchTrace) as master: + with core.new_main(BatchTrace) as master: out_vals = yield (master, in_dims,) + in_vals, params del master yield out_vals @@ -142,7 +142,7 @@ class BatchTrace(Trace): axis_names = (axis_names,) for i, axis_name in enumerate(axis_names): frame = core.axis_frame(axis_name) - if frame.master_trace is not 
self.master: + if frame.main_trace is not self.main: continue # We run the split_axis rule with tracers, which is supposed to never # mix this axis name with another one. We will handle any invocations @@ -168,13 +168,13 @@ if all(bdim is not_mapped for bdim in dims): return call_primitive.bind(f, *vals, **params) else: - f, dims_out = batch_subtrace(f, self.master, dims) + f, dims_out = batch_subtrace(f, self.main, dims) vals_out = call_primitive.bind(f, *vals, **params) return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())] def post_process_call(self, call_primitive, out_tracers, params): vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers) - master = self.master + master = self.main def todo(vals): trace = BatchTrace(master, core.cur_sublevel()) return map(partial(BatchTracer, trace), vals, dims) @@ -191,14 +191,14 @@ for x, d, mapped_invar in zip(vals, dims, mapped_invars)] dims = tuple(not_mapped if d is not_mapped else max(0, d - mapped_invar) for d, mapped_invar in zip(dims, mapped_invars)) - f, dims_out = batch_subtrace(f, self.master, dims) + f, dims_out = batch_subtrace(f, self.main, dims) vals_out = map_primitive.bind(f, *vals, **params) dims_out = tuple(d + 1 if d is not not_mapped else d for d in dims_out()) return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)] def post_process_map(self, call_primitive, out_tracers, params): vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers) - master = self.master + master = self.main def todo(vals): trace = BatchTrace(master, core.cur_sublevel()) return [BatchTracer(trace, v, d + 1 if d is not not_mapped else d) @@ -207,8 +207,8 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers): in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - fun, out_dims1 = batch_subtrace(fun, self.master, in_dims) - jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.master, in_dims) + fun,
out_dims1 = batch_subtrace(fun, self.main, in_dims) + jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims) out_vals = prim.bind(fun, jvp, *in_vals) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: @@ -218,8 +218,8 @@ class BatchTrace(Trace): def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees): in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - fun, out_dims1 = batch_subtrace(fun, self.master, in_dims) - fwd, out_dims2 = batch_subtrace(fwd, self.master, in_dims) + fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) + fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims) # TODO: Support collectives in custom_vjp? bwd = batch_fun(bwd, out_dims2, in_dims, axis_name='__unused_axis_name', sum_match=True) @@ -392,7 +392,7 @@ def batch_jaxpr(jaxpr, size, batched, instantiate): @lu.transformation_with_aux def batched_traceable(size, batched, instantiate, *vals): in_dims = [0 if b else None for b in batched] - with core.new_master(BatchTrace) as master: + with core.new_main(BatchTrace) as master: trace = BatchTrace(master, core.cur_sublevel()) ans = yield map(partial(BatchTracer, trace), vals, in_dims), {} out_tracers = map(trace.full_raise, ans) diff --git a/jax/interpreters/masking.py b/jax/interpreters/masking.py index e894783af..cb88af7c2 100644 --- a/jax/interpreters/masking.py +++ b/jax/interpreters/masking.py @@ -80,7 +80,7 @@ def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes): logical_env_vals = [logical_env[k] for k in env_keys] # Make padded_env hashable padded_env = (env_keys, padded_env_vals) - with core.new_master(MaskTrace) as master: + with core.new_main(MaskTrace) as master: fun, out_shapes = mask_subtrace(fun, master, polymorphic_shapes, padded_env) out_vals = fun.call_wrapped(*(logical_env_vals + in_vals)) del master @@ -421,7 +421,7 @@ class MaskTrace(Trace): logical_env_vals = tuple(logical_env[k] for k in env_keys) # Make padded_env hashable 
padded_env = (env_keys, padded_env_vals) - f, shapes_out = mask_subtrace(f, self.master, shapes, padded_env) + f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env) if 'donated_invars' in params: params = dict(params, donated_invars=((False,) * len(logical_env_vals) + params['donated_invars'])) @@ -430,7 +430,7 @@ class MaskTrace(Trace): def post_process_call(self, call_primitive, out_tracers, params): vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers) - master = self.master + master = self.main def todo(vals): trace = MaskTrace(master, core.cur_sublevel()) return map(partial(MaskTracer, trace), vals, shapes) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b8c8f7cce..79289fdc1 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -165,7 +165,7 @@ class JaxprTrace(Trace): # We use process_call to handle both call and map primitives. def process_call(self, primitive, f: lu.WrappedFun, tracers, params): if not config.omnistaging_enabled: - if (self.master.trace_type is StagingJaxprTrace + if (self.main.trace_type is StagingJaxprTrace and primitive in staged_out_calls): tracers = map(self.instantiate_const_abstracted, tracers) @@ -239,7 +239,7 @@ class JaxprTrace(Trace): out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers) out = out_pv_consts + consts del consts, out_pv_consts - master = self.master + master = self.main if primitive.map_primitive: sz = params['axis_size'] @@ -276,7 +276,7 @@ class JaxprTrace(Trace): app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]]): """Partially evaluate f on a sequence of PartialVals.""" in_avals, in_consts = unzip2(pvals) - f = trace_to_subjaxpr(f, self.master, False) + f = trace_to_subjaxpr(f, self.main, False) f, aux = partial_eval_wrapper(f, tuple(in_avals)) out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux() out_consts, consts = split_list(out_flat, 
[len(out_flat)-len(jaxpr.constvars)]) @@ -288,12 +288,12 @@ class JaxprTrace(Trace): # See comment at top of `JaxprTrace`. This method should be reachable # only when we stage out, and in that case we drop the custom differentiation # rules, because we do not need them. - assert self.master.trace_type is StagingJaxprTrace + assert self.main.trace_type is StagingJaxprTrace return fun.call_wrapped(*tracers) def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): # See comment in the above process_custom_jvp_call method. - assert self.master.trace_type is StagingJaxprTrace + assert self.main.trace_type is StagingJaxprTrace return fun.call_wrapped(*tracers) @@ -417,7 +417,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], consts = [3, 6] # values for `ka` and `kb` constvars """ trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace) - with core.new_master(trace_type, bottom=bottom) as master: + with core.new_main(trace_type, bottom=bottom) as master: fun = trace_to_subjaxpr(fun, master, instantiate) jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) assert not env @@ -427,7 +427,7 @@ def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], @lu.transformation -def trace_to_subjaxpr(master: core.MasterTrace, instantiate: Union[bool, Sequence[bool]], +def trace_to_subjaxpr(master: core.MainTrace, instantiate: Union[bool, Sequence[bool]], pvals: Sequence[PartialVal]): assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals trace = JaxprTrace(master, core.cur_sublevel()) @@ -710,7 +710,7 @@ def _remat_partial_eval(trace, _, f, tracers, params): typed_jaxpr, in_unknowns, instantiate=False) # type: ignore else: jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr( - typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.master.trace_type) + typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type) out_knowns = [not b for b in out_unknowns] 
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns) @@ -846,7 +846,7 @@ class DynamicJaxprTracer(core.Tracer): msgs = self._progenitor_messages() msg = (f"Abstract tracer value passed to {name} for which a concrete value " "is required.\n" - f"While tracing the function {self._trace.master.source_info}, " + f"While tracing the function {self._trace.main.source_info}, " "this tracer originated from using JAX operations on these lines:" "\n\n" + "\n\n".join(msgs) + "\n\n" "See the above traceback for where this tracer was encountered.") @@ -923,7 +923,7 @@ class DynamicJaxprTrace(core.Trace): __slots__ = [] # type: ignore @property - def frame(self): return self.master.jaxpr_stack[-1] # pytype: disable=attribute-error + def frame(self): return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error def new_arg(self, aval): tracer = DynamicJaxprTracer(self, aval) @@ -954,7 +954,7 @@ class DynamicJaxprTrace(core.Trace): return var def instantiate_const(self, val): - if (isinstance(val, Tracer) and val._trace.master is self.master + if (isinstance(val, Tracer) and val._trace.main is self.main and val._trace.sublevel == self.sublevel): return val else: @@ -974,7 +974,7 @@ class DynamicJaxprTrace(core.Trace): def process_call(self, call_primitive, f, tracers, params): in_avals = [t.aval for t in tracers] - jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.master, in_avals) + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals) if not jaxpr.eqns: return core.eval_jaxpr(jaxpr, consts, *tracers) out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] @@ -1000,7 +1000,7 @@ class DynamicJaxprTrace(core.Trace): for m, a in zip(params['mapped_invars'], in_avals)] with core.extend_axis_env(axis_name, axis_size, None): # type: ignore jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic( - f, self.master, reduced_in_avals) + f, self.main, reduced_in_avals) out_avals = 
[core.unmapped_aval(params['axis_size'], a) for a in reduced_out_avals] out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) @@ -1022,18 +1022,18 @@ class DynamicJaxprTrace(core.Trace): def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): assert config.omnistaging_enabled - with core.new_master(DynamicJaxprTrace, dynamic=True) as master: # type: ignore - master.source_info = fun_sourceinfo(fun.f) # type: ignore - master.jaxpr_stack = () # type: ignore - jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, master, in_avals) - del master + with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore + main.source_info = fun_sourceinfo(fun.f) # type: ignore + main.jaxpr_stack = () # type: ignore + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals) + del main return jaxpr, out_avals, consts -def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, master: core.MasterTrace, - in_avals: Sequence[AbstractValue]): +def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace, + in_avals: Sequence[AbstractValue]): frame = JaxprStackFrame() - with extend_jaxpr_stack(master, frame): - trace = DynamicJaxprTrace(master, core.cur_sublevel()) + with extend_jaxpr_stack(main, frame): + trace = DynamicJaxprTrace(main, core.cur_sublevel()) in_tracers = map(trace.new_arg, in_avals) ans = fun.call_wrapped(*in_tracers) out_tracers = map(trace.full_raise, ans) @@ -1041,21 +1041,21 @@ def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, master: core.MasterTrace, return jaxpr, out_avals, consts @contextlib.contextmanager -def extend_jaxpr_stack(master, frame): - master.jaxpr_stack = master.jaxpr_stack + (frame,) +def extend_jaxpr_stack(main, frame): + main.jaxpr_stack = main.jaxpr_stack + (frame,) try: yield finally: - assert frame is master.jaxpr_stack[-1] - master.jaxpr_stack = master.jaxpr_stack[:-1] + assert frame is main.jaxpr_stack[-1] + main.jaxpr_stack = 
main.jaxpr_stack[:-1] def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): assert config.omnistaging_enabled - with core.new_base_master(DynamicJaxprTrace) as master: # type: ignore - master.source_info = fun_sourceinfo(fun.f) # type: ignore - master.jaxpr_stack = () # type: ignore - jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, master, in_avals) - del master + with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore + main.source_info = fun_sourceinfo(fun.f) # type: ignore + main.jaxpr_stack = () # type: ignore + jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals) + del main return jaxpr, out_avals, consts def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[PartialVal]): @@ -1064,7 +1064,7 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial # custom_derivatives.py, which we work around by adding the EvalTrace. # TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py assert config.omnistaging_enabled - with core.new_master(core.EvalTrace, dynamic=True) as _: # type: ignore + with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore return trace_to_jaxpr(fun, in_pvals) def fun_sourceinfo(fun): @@ -1090,7 +1090,7 @@ def omnistaging_enabler() -> None: def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], instantiate: Union[bool, Sequence[bool]] = False, ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]: - with core.new_master(JaxprTrace) as master: + with core.new_main(JaxprTrace) as master: fun = trace_to_subjaxpr(fun, master, instantiate) jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) assert not env diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py index 667bb5266..cb39d1735 100644 --- a/jax/lax/lax_control_flow.py +++ b/jax/lax/lax_control_flow.py @@ -490,7 +490,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: if 
config.omnistaging_enabled: partial_eval_jaxpr = pe.partial_eval_jaxpr else: - partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type) + partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type) cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts]) # Fixpoint computation of unknown carry. Each iteration promotes @@ -844,7 +844,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear): if config.omnistaging_enabled: partial_eval_jaxpr = pe.partial_eval_jaxpr else: - partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type) + partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type) if index_uk: # When the branch index is unknown, we stage out the whole cond. @@ -1517,7 +1517,7 @@ def _prune_zeros(ts): def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): - if not config.omnistaging_enabled and trace.master.trace_type is pe.StagingJaxprTrace: + if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: params = dict(reverse=reverse, length=length, num_consts=num_consts, num_carry=num_carry, jaxpr=jaxpr, linear=linear, unroll=unroll) @@ -1531,7 +1531,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, if config.omnistaging_enabled: partial_eval_jaxpr = pe.partial_eval_jaxpr else: - partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.master.trace_type) + partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type) # Fixpoint computation of which carry are unknown (not a constant): either # unknown from init, or the carry out is unknown. Each iteration promotes