Merge pull request #26044 from hawkinsp:slots

PiperOrigin-RevId: 718537271
This commit is contained in:
jax authors 2025-01-22 14:50:07 -08:00
commit 1b6080d943
7 changed files with 14 additions and 0 deletions

View File

@ -560,6 +560,7 @@ def check_avals_context_mesh(avals, prim_name):
TracerType = TypeVar('TracerType', bound='Tracer')
class Trace(Generic[TracerType]):
__slots__ = ("__weakref__", "_invalidated")
def process_primitive(self, primitive, tracers, params):
raise NotImplementedError("must override")

View File

@ -1808,6 +1808,8 @@ def _inline_literals(
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)
def __init__(self, debug_info: lu.TracingDebugInfo | None):
self.frame = JaxprStackFrame(debug_info)

View File

@ -461,6 +461,7 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName],
FakePrimitive = namedtuple("FakePrimitive", ["multiple_results", "bind"])
class MapTrace(core.Trace):
__slots__ = ("axis_name", "emap_info")
def __init__(self, axis_name, emap_info):
self.emap_info = emap_info

View File

@ -1325,6 +1325,9 @@ class TensorFlowTrace(core.Trace):
those will introduce their own MainTrace, and any operations involving those
will be done on those traces, i.e., not a concern for TFT.
"""
__slots__ = ()
def to_tf_tracer(self, val: TfVal) -> TensorFlowTracer:
"""Lifts a non-Tracer into the TensorFlowTracer.
"""

View File

@ -199,6 +199,7 @@ class JetTracer(core.Tracer):
return self
class JetTrace(core.Trace):
__slots__ = ("tag", "parent_trace", "order")
def __init__(self, tag, parent_trace, order):
self.tag = tag

View File

@ -854,6 +854,8 @@ def _rem_singleton(x): return x.reshape(x.shape[1:])
def _add_singleton(x): return x.reshape(1, *x.shape)
class ShardMapTrace(core.Trace):
__slots__ = ("mesh", "check")
mesh: Mesh
check: bool
@ -1897,6 +1899,8 @@ class RewriteTracer(core.Tracer):
__repr__ = __str__ # for debuggers, like `p x`
class RewriteTrace(core.Trace):
__slots__ = ("parent_trace", "tag", "mesh")
parent_trace : core.Trace
tag : core.TraceTag
mesh: Mesh

View File

@ -297,6 +297,8 @@ class SparseTracer(core.Tracer):
class SparseTrace(core.Trace):
__slots__ = ("parent_trace", "tag", "spenv")
def __init__(self, parent_trace, tag, spenv):
self.parent_trace = parent_trace
self.tag = tag