diff --git a/docs/jax.stages.rst b/docs/jax.stages.rst
index 804019ee1..d26def14d 100644
--- a/docs/jax.stages.rst
+++ b/docs/jax.stages.rst
@@ -12,6 +12,9 @@ Classes
    :members: trace, lower
    :special-members: __call__
 
+.. autoclass:: Traced
+   :members: jaxpr, out_info, lower
+
 .. autoclass:: Lowered
    :members: in_tree, out_tree, compile, as_text, compiler_ir, cost_analysis
 
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index 9d6df18f3..8c2074524 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -733,6 +733,12 @@ class Lowered(Stage):
 
 
 class Traced(Stage):
+  """Traced form of a function specialized to argument types and values.
+
+  A traced computation is ready for lowering. This class carries the
+  traced representation with the remaining information needed to later
+  lower, compile, and execute it.
+  """
   __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
                "_args_flat", "_arg_names", "_num_consts"]
 
@@ -756,6 +762,7 @@ class Traced(Stage):
 
   def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
             _private_parameters: mlir.LoweringParameters | None = None):
+    """Lower to compiler input, returning a ``Lowered`` instance."""
     from jax._src.interpreters import pxla
     from jax._src import pjit
 
@@ -805,6 +812,8 @@ class Wrapped(Protocol):
   def lower(self, *args, **kwargs) -> Lowered:
     """Lower this function explicitly for the given arguments.
 
+    This is a shortcut for ``self.trace(*args, **kwargs).lower()``.
+
     A lowered function is staged out of Python and translated to a
     compiler's input language, possibly in a backend-dependent
     manner. It is ready for compilation but not yet compiled.