[JAX] Change jax.core.Trace subclasses to call super().__init__().

Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__().

hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)

PiperOrigin-RevId: 736555016
This commit is contained in:
Peter Hawkins 2025-03-13 10:27:00 -07:00 committed by jax authors
parent 14b9f48535
commit 8effa19734
8 changed files with 14 additions and 1 deletions

View File

@ -619,6 +619,9 @@ TracerType = TypeVar('TracerType', bound='Tracer')
class Trace(Generic[TracerType]):
__slots__ = ("__weakref__", "_invalidated")
def __init__(self):
self._invalidated = False
def process_primitive(self, primitive, tracers, params):
raise NotImplementedError("must override")
@ -626,7 +629,7 @@ class Trace(Generic[TracerType]):
self._invalidated = True
def is_valid(self):
return not hasattr(self, "_invalidated")
return not self._invalidated
def __repr__(self):
return '{}'.format(self.__class__.__name__)

View File

@ -459,6 +459,7 @@ def nonzero_tangent_outputs(f, store, *args, **kwargs):
class JVPTrace(Trace):
def __init__(self, parent_trace, tag):
super().__init__()
self.tag = tag
self.parent_trace = parent_trace
@ -640,6 +641,7 @@ call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
class LinearizeTrace(Trace):
def __init__(self, parent_trace, tangent_trace, tag=None):
super().__init__()
self.tag = core.TraceTag() if tag is None else tag
self.parent_trace = parent_trace
self.tangent_trace = tangent_trace

View File

@ -460,6 +460,7 @@ def get_sharding_for_vmap(axis_data, orig_sharding, axis):
class BatchTrace(Trace):
def __init__(self, parent_trace, tag, axis_data):
super().__init__()
self.parent_trace = parent_trace
assert isinstance(axis_data, AxisData)
self.axis_data = axis_data

View File

@ -141,6 +141,7 @@ class PartialVal(tuple):
class JaxprTrace(Trace['JaxprTracer']):
def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, tag:TraceTag):
super().__init__()
self.name_stack = name_stack
self.tag = tag
self.parent_trace = parent_trace
@ -1849,6 +1850,7 @@ class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame", "tag")
def __init__(self, debug_info: core.DebugInfo):
super().__init__()
self.frame = JaxprStackFrame(debug_info)
def invalidate(self):

View File

@ -463,6 +463,7 @@ class MapTrace(core.Trace):
__slots__ = ("axis_name", "emap_info")
def __init__(self, axis_name, emap_info):
super().__init__()
self.emap_info = emap_info
self.axis_name = axis_name

View File

@ -205,6 +205,7 @@ class JetTrace(core.Trace):
__slots__ = ("tag", "parent_trace", "order")
def __init__(self, tag, parent_trace, order):
super().__init__()
self.tag = tag
self.parent_trace = parent_trace
self.order = order

View File

@ -926,6 +926,7 @@ class ShardMapTrace(core.Trace):
context_mesh: AbstractMesh
def __init__(self, mesh, auto, check, context_mesh):
super().__init__()
self.mesh = mesh
self.auto = auto
self.check = check
@ -2042,6 +2043,7 @@ class RewriteTrace(core.Trace):
mesh: Mesh
def __init__(self, parent_trace, tag, mesh):
super().__init__()
self.parent_trace = parent_trace
self.tag = tag
self.mesh = mesh

View File

@ -301,6 +301,7 @@ class SparseTrace(core.Trace):
__slots__ = ("parent_trace", "tag", "spenv")
def __init__(self, parent_trace, tag, spenv):
super().__init__()
self.parent_trace = parent_trace
self.tag = tag
self.spenv = spenv