mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #26044 from hawkinsp:slots
PiperOrigin-RevId: 718537271
This commit is contained in:
commit
1b6080d943
@ -560,6 +560,7 @@ def check_avals_context_mesh(avals, prim_name):
|
||||
TracerType = TypeVar('TracerType', bound='Tracer')
|
||||
|
||||
class Trace(Generic[TracerType]):
|
||||
__slots__ = ("__weakref__", "_invalidated")
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
raise NotImplementedError("must override")
|
||||
|
@ -1808,6 +1808,8 @@ def _inline_literals(
|
||||
|
||||
|
||||
class DynamicJaxprTrace(core.Trace):
|
||||
__slots__ = ("frame",)
|
||||
|
||||
def __init__(self, debug_info: lu.TracingDebugInfo | None):
|
||||
self.frame = JaxprStackFrame(debug_info)
|
||||
|
||||
|
@ -461,6 +461,7 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName],
|
||||
FakePrimitive = namedtuple("FakePrimitive", ["multiple_results", "bind"])
|
||||
|
||||
class MapTrace(core.Trace):
|
||||
__slots__ = ("axis_name", "emap_info")
|
||||
|
||||
def __init__(self, axis_name, emap_info):
|
||||
self.emap_info = emap_info
|
||||
|
@ -1325,6 +1325,9 @@ class TensorFlowTrace(core.Trace):
|
||||
those will introduce their own MainTrace, and any operations involving those
|
||||
will be done on those traces, i.e., not a concern for TFT.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def to_tf_tracer(self, val: TfVal) -> TensorFlowTracer:
|
||||
"""Lifts a non-Tracer into the TensorFlowTracer.
|
||||
"""
|
||||
|
@ -199,6 +199,7 @@ class JetTracer(core.Tracer):
|
||||
return self
|
||||
|
||||
class JetTrace(core.Trace):
|
||||
__slots__ = ("tag", "parent_trace", "order")
|
||||
|
||||
def __init__(self, tag, parent_trace, order):
|
||||
self.tag = tag
|
||||
|
@ -854,6 +854,8 @@ def _rem_singleton(x): return x.reshape(x.shape[1:])
|
||||
def _add_singleton(x): return x.reshape(1, *x.shape)
|
||||
|
||||
class ShardMapTrace(core.Trace):
|
||||
__slots__ = ("mesh", "check")
|
||||
|
||||
mesh: Mesh
|
||||
check: bool
|
||||
|
||||
@ -1897,6 +1899,8 @@ class RewriteTracer(core.Tracer):
|
||||
__repr__ = __str__ # for debuggers, like `p x`
|
||||
|
||||
class RewriteTrace(core.Trace):
|
||||
__slots__ = ("parent_trace", "tag", "mesh")
|
||||
|
||||
parent_trace : core.Trace
|
||||
tag : core.TraceTag
|
||||
mesh: Mesh
|
||||
|
@ -297,6 +297,8 @@ class SparseTracer(core.Tracer):
|
||||
|
||||
class SparseTrace(core.Trace):
|
||||
|
||||
__slots__ = ("parent_trace", "tag", "spenv")
|
||||
|
||||
def __init__(self, parent_trace, tag, spenv):
|
||||
self.parent_trace = parent_trace
|
||||
self.tag = tag
|
||||
|
Loading…
x
Reference in New Issue
Block a user