From f4adcc650f023442b9eff8dbbf641a9e0601ccd7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 22 Jan 2025 16:17:54 -0500 Subject: [PATCH] Set __slots__ on core.Trace subclasses. This is easy to do and makes field accesses on Trace classes slightly faster. --- jax/_src/core.py | 1 + jax/_src/interpreters/partial_eval.py | 2 ++ jax/_src/interpreters/pxla.py | 1 + jax/experimental/jax2tf/jax2tf.py | 3 +++ jax/experimental/jet.py | 1 + jax/experimental/shard_map.py | 4 ++++ jax/experimental/sparse/transform.py | 2 ++ 7 files changed, 14 insertions(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index 0c7ad8479..4ac53378c 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -560,6 +560,7 @@ def check_avals_context_mesh(avals, prim_name): TracerType = TypeVar('TracerType', bound='Tracer') class Trace(Generic[TracerType]): + __slots__ = ("__weakref__", "_invalidated") def process_primitive(self, primitive, tracers, params): raise NotImplementedError("must override") diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ea1df444c..bd9ba286d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1808,6 +1808,8 @@ def _inline_literals( class DynamicJaxprTrace(core.Trace): + __slots__ = ("frame",) + def __init__(self, debug_info: lu.TracingDebugInfo | None): self.frame = JaxprStackFrame(debug_info) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 0c7c01762..b918d3ed5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -461,6 +461,7 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName], FakePrimitive = namedtuple("FakePrimitive", ["multiple_results", "bind"]) class MapTrace(core.Trace): + __slots__ = ("axis_name", "emap_info") def __init__(self, axis_name, emap_info): self.emap_info = emap_info diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index ad3f4ff15..cf7091c4d 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1325,6 +1325,9 @@ class TensorFlowTrace(core.Trace): those will introduce their own MainTrace, and any operations involving those will be done on those traces, i.e., not a concern for TFT. """ + + __slots__ = () + def to_tf_tracer(self, val: TfVal) -> TensorFlowTracer: """Lifts a non-Tracer into the TensorFlowTracer. """ diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 2da1ba1d4..9b3ce9ec8 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -199,6 +199,7 @@ class JetTracer(core.Tracer): return self class JetTrace(core.Trace): + __slots__ = ("tag", "parent_trace", "order") def __init__(self, tag, parent_trace, order): self.tag = tag diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 64c6b2b81..131b8a964 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -854,6 +854,8 @@ def _rem_singleton(x): return x.reshape(x.shape[1:]) def _add_singleton(x): return x.reshape(1, *x.shape) class ShardMapTrace(core.Trace): + __slots__ = ("mesh", "check") + mesh: Mesh check: bool @@ -1897,6 +1899,8 @@ class RewriteTracer(core.Tracer): __repr__ = __str__ # for debuggers, like `p x` class RewriteTrace(core.Trace): + __slots__ = ("parent_trace", "tag", "mesh") + parent_trace : core.Trace tag : core.TraceTag mesh: Mesh diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 0f255747e..76d4d957e 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -297,6 +297,8 @@ class SparseTracer(core.Tracer): class SparseTrace(core.Trace): + __slots__ = ("parent_trace", "tag", "spenv") + def __init__(self, parent_trace, tag, spenv): self.parent_trace = parent_trace self.tag = tag