mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Change jax.core.Trace subclasses to call super().__init__().
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__(). hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?) PiperOrigin-RevId: 736555016
This commit is contained in:
parent
14b9f48535
commit
8effa19734
@ -619,6 +619,9 @@ TracerType = TypeVar('TracerType', bound='Tracer')
|
||||
class Trace(Generic[TracerType]):
|
||||
__slots__ = ("__weakref__", "_invalidated")
|
||||
|
||||
def __init__(self):
|
||||
self._invalidated = False
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
@ -626,7 +629,7 @@ class Trace(Generic[TracerType]):
|
||||
self._invalidated = True
|
||||
|
||||
def is_valid(self):
|
||||
return not hasattr(self, "_invalidated")
|
||||
return not self._invalidated
|
||||
|
||||
def __repr__(self):
|
||||
return '{}'.format(self.__class__.__name__)
|
||||
|
@ -459,6 +459,7 @@ def nonzero_tangent_outputs(f, store, *args, **kwargs):
|
||||
|
||||
class JVPTrace(Trace):
|
||||
def __init__(self, parent_trace, tag):
|
||||
super().__init__()
|
||||
self.tag = tag
|
||||
self.parent_trace = parent_trace
|
||||
|
||||
@ -640,6 +641,7 @@ call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
|
||||
class LinearizeTrace(Trace):
|
||||
|
||||
def __init__(self, parent_trace, tangent_trace, tag=None):
|
||||
super().__init__()
|
||||
self.tag = core.TraceTag() if tag is None else tag
|
||||
self.parent_trace = parent_trace
|
||||
self.tangent_trace = tangent_trace
|
||||
|
@ -460,6 +460,7 @@ def get_sharding_for_vmap(axis_data, orig_sharding, axis):
|
||||
class BatchTrace(Trace):
|
||||
|
||||
def __init__(self, parent_trace, tag, axis_data):
|
||||
super().__init__()
|
||||
self.parent_trace = parent_trace
|
||||
assert isinstance(axis_data, AxisData)
|
||||
self.axis_data = axis_data
|
||||
|
@ -141,6 +141,7 @@ class PartialVal(tuple):
|
||||
class JaxprTrace(Trace['JaxprTracer']):
|
||||
|
||||
def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, tag:TraceTag):
|
||||
super().__init__()
|
||||
self.name_stack = name_stack
|
||||
self.tag = tag
|
||||
self.parent_trace = parent_trace
|
||||
@ -1849,6 +1850,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
__slots__ = ("frame", "tag")
|
||||
|
||||
def __init__(self, debug_info: core.DebugInfo):
|
||||
super().__init__()
|
||||
self.frame = JaxprStackFrame(debug_info)
|
||||
|
||||
def invalidate(self):
|
||||
|
@ -463,6 +463,7 @@ class MapTrace(core.Trace):
|
||||
__slots__ = ("axis_name", "emap_info")
|
||||
|
||||
def __init__(self, axis_name, emap_info):
|
||||
super().__init__()
|
||||
self.emap_info = emap_info
|
||||
self.axis_name = axis_name
|
||||
|
||||
|
@ -205,6 +205,7 @@ class JetTrace(core.Trace):
|
||||
__slots__ = ("tag", "parent_trace", "order")
|
||||
|
||||
def __init__(self, tag, parent_trace, order):
|
||||
super().__init__()
|
||||
self.tag = tag
|
||||
self.parent_trace = parent_trace
|
||||
self.order = order
|
||||
|
@ -926,6 +926,7 @@ class ShardMapTrace(core.Trace):
|
||||
context_mesh: AbstractMesh
|
||||
|
||||
def __init__(self, mesh, auto, check, context_mesh):
|
||||
super().__init__()
|
||||
self.mesh = mesh
|
||||
self.auto = auto
|
||||
self.check = check
|
||||
@ -2042,6 +2043,7 @@ class RewriteTrace(core.Trace):
|
||||
mesh: Mesh
|
||||
|
||||
def __init__(self, parent_trace, tag, mesh):
|
||||
super().__init__()
|
||||
self.parent_trace = parent_trace
|
||||
self.tag = tag
|
||||
self.mesh = mesh
|
||||
|
@ -301,6 +301,7 @@ class SparseTrace(core.Trace):
|
||||
__slots__ = ("parent_trace", "tag", "spenv")
|
||||
|
||||
def __init__(self, parent_trace, tag, spenv):
|
||||
super().__init__()
|
||||
self.parent_trace = parent_trace
|
||||
self.tag = tag
|
||||
self.spenv = spenv
|
||||
|
Loading…
x
Reference in New Issue
Block a user