diff --git a/CHANGELOG.md b/CHANGELOG.md
index e351f64d0..d8bb1478a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 
 ## jax 0.4.38
 
+* Deprecations
+  * a number of APIs in the internal `jax.core` namespace have been deprecated, including
+    `ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
+    `Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
+    APIs of the same name in {mod}`jax.extend.core`; see the documentation for
+    {mod}`jax.extend` for information on the compatibility guarantees of these
+    semi-public extensions.
+
 ## jax 0.4.37 (Dec 9, 2024)
 
 This is a patch release of jax 0.4.36. Only "jax" was released at this version.
diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst
index 55094fc88..f89122f94 100644
--- a/docs/contributor_guide.rst
+++ b/docs/contributor_guide.rst
@@ -25,4 +25,3 @@ some of JAX's (extensible) internals.
 
    autodidax
    jep/index
-   jax_internal_api
diff --git a/docs/jax.extend.core.rst b/docs/jax.extend.core.rst
new file mode 100644
index 000000000..5f3ff0558
--- /dev/null
+++ b/docs/jax.extend.core.rst
@@ -0,0 +1,18 @@
+``jax.extend.core`` module
+==========================
+
+.. automodule:: jax.extend.core
+
+.. autosummary::
+  :toctree: _autosummary
+
+  ClosedJaxpr
+  Jaxpr
+  JaxprEqn
+  Literal
+  Primitive
+  Token
+  Var
+  array_types
+  jaxpr_as_fun
+  primitives
diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst
index 9cbee08e8..0d68013c9 100644
--- a/docs/jax.extend.rst
+++ b/docs/jax.extend.rst
@@ -11,6 +11,7 @@ Modules
 .. toctree::
   :maxdepth: 1
 
+  jax.extend.core
   jax.extend.ffi
   jax.extend.linear_util
   jax.extend.mlir
diff --git a/docs/jax_internal_api.rst b/docs/jax_internal_api.rst
deleted file mode 100644
index 1ece596d8..000000000
--- a/docs/jax_internal_api.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-Internal API reference
-======================
-
-core
-----
-
-.. currentmodule:: jax.core
-.. automodule:: jax.core
-
-.. autosummary::
-  :toctree: _autosummary
-
-   Jaxpr
-   ClosedJaxpr
diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py
index ef4e33ad0..c45bb8a9e 100644
--- a/jax/_src/cudnn/fused_attention_stablehlo.py
+++ b/jax/_src/cudnn/fused_attention_stablehlo.py
@@ -18,8 +18,8 @@ import json
 import math
 
 import jax
-from jax import core
 from jax import dtypes
+from jax._src import core
 from jax._src import dispatch
 from jax._src.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py
index 8a13399e3..f32067246 100644
--- a/jax/_src/cudnn/fusion.py
+++ b/jax/_src/cudnn/fusion.py
@@ -14,7 +14,7 @@
 
 import functools
 import jax
-from jax import core as jax_core
+from jax._src import core as jax_core
 from jax.interpreters import mlir
 from jax.interpreters.mlir import hlo
 from jax.interpreters.mlir import ir
diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py
index 4382cea91..ec9500c67 100644
--- a/jax/_src/pallas/mosaic/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -21,10 +21,9 @@ import tempfile
 from typing import Any
 
 import jax
-from jax import core as jax_core
 from jax import dtypes
 from jax._src import config
-from jax._src import core as jax_src_core
+from jax._src import core as jax_core
 from jax._src import sharding_impls
 from jax._src import tpu_custom_call
 from jax._src.interpreters import mlir
@@ -189,7 +188,7 @@ def pallas_call_tpu_lowering_rule(
   # Replace in_avals to physical avals.
   # This step is required for mapping logical types to physical types.
   # (e.g. PRNG key -> uint32[2])
-  physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in]
+  physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in]
   ctx = ctx.replace(avals_in=physical_avals)
 
   # Booleans are loaded into the kernel as integers.
diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD
index 3d6e82d44..e9461a5ce 100644
--- a/jax/_src/pallas/mosaic_gpu/BUILD
+++ b/jax/_src/pallas/mosaic_gpu/BUILD
@@ -44,6 +44,7 @@ pytype_strict_library(
     deps = [
         ":lowering",
         "//jax",
+        "//jax:core",
         "//jax:mlir",
         "//jax:mosaic_gpu",
         "//jax/_src/pallas",
diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
index 05785cb51..18d8baf6e 100644
--- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
@@ -23,7 +23,7 @@ from typing import Any
 import warnings
 
 import jax
-from jax import core as jax_core
+from jax._src import core as jax_core
 from jax._src.interpreters import mlir
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import lowering
diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD
index a9babcba0..84fae3913 100644
--- a/jax/_src/pallas/triton/BUILD
+++ b/jax/_src/pallas/triton/BUILD
@@ -76,6 +76,7 @@ pytype_strict_library(
         ":lowering",
         "//jax",
         "//jax:config",
+        "//jax:core",
         "//jax:mlir",
         "//jax:util",
         "//jax/_src/lib",
diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index 67b0bd326..1805f8c09 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 import io
 from typing import Any
 
-from jax import core as jax_core
+import jax._src.core as jax_core
 from jax._src.interpreters import mlir
 from jax._src.lib.mlir import ir
 from jax._src.pallas import core as pallas_core
diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py
index 23fce50dc..b845a4079 100644
--- a/jax/_src/pallas/triton/primitives.py
+++ b/jax/_src/pallas/triton/primitives.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 from collections.abc import Sequence
 
 import jax
-from jax import core as jax_core
+from jax._src import core as jax_core
 from jax._src.lib.mlir.dialects import gpu as gpu_dialect
 from jax._src.lib.triton import dialect as tt_dialect
 from jax._src.pallas.triton import lowering
diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index ccd77af5b..9e54f62d9 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -29,7 +29,7 @@ import time
 from typing import Any
 
 import jax
-from jax import core
+from jax._src import core
 from jax._src import config
 from jax._src import sharding_impls
 from jax._src.interpreters import mlir
diff --git a/jax/core.py b/jax/core.py
index 8d7c546f0..4d1742bc2 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -23,7 +23,6 @@ from jax._src.core import (
   AxisSize as AxisSize,
   AxisName as AxisName,
   CallPrimitive as CallPrimitive,
-  ClosedJaxpr as ClosedJaxpr,
   ConcretizationTypeError as ConcretizationTypeError,
   DShapedArray as DShapedArray,
   DropVar as DropVar,
@@ -34,23 +33,18 @@ from jax._src.core import (
   InDBIdx as InDBIdx,
   InconclusiveDimensionOperation as InconclusiveDimensionOperation,
   InputType as InputType,
-  Jaxpr as Jaxpr,
   JaxprDebugInfo as JaxprDebugInfo,
-  JaxprEqn as JaxprEqn,
   JaxprPpContext as JaxprPpContext,
   JaxprPpSettings as JaxprPpSettings,
   JaxprTypeError as JaxprTypeError,
-  Literal as Literal,
   MapPrimitive as MapPrimitive,
   nonempty_axis_env as nonempty_axis_env_DO_NOT_USE,  # noqa: F401
   OpaqueTraceState as OpaqueTraceState,
   OutDBIdx as OutDBIdx,
   OutputType as OutputType,
   ParamDict as ParamDict,
-  Primitive as Primitive,
   ShapedArray as ShapedArray,
   TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING,
-  Token as Token,
   Trace as Trace,
   Tracer as Tracer,
   unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE,  # noqa: F401
@@ -59,7 +53,6 @@ from jax._src.core import (
   unsafe_get_current_trace as unsafe_get_current_trace_DO_NOT_USE,  # noqa: F401
   UnshapedArray as UnshapedArray,
   Value as Value,
-  Var as Var,
   abstract_token as abstract_token,
   aval_mapping_handlers as aval_mapping_handlers,
   call as call,
@@ -78,7 +71,6 @@ from jax._src.core import (
   eval_jaxpr as eval_jaxpr,
   extend_axis_env_nd as extend_axis_env_nd,
   find_top_trace as find_top_trace,
-  full_lower as full_lower,
   gensym as gensym,
   get_aval as get_aval,
   get_type as get_type,
@@ -86,10 +78,8 @@ from jax._src.core import (
   is_concrete as is_concrete,
   is_constant_dim as is_constant_dim,
   is_constant_shape as is_constant_shape,
-  jaxpr_as_fun as jaxpr_as_fun,
   jaxprs_in_params as jaxprs_in_params,
   join_effects as join_effects,
-  lattice_join as lattice_join,
   leaked_tracer_error as leaked_tracer_error,
   literalable_types as literalable_types,
   mapped_aval as mapped_aval,
@@ -101,7 +91,6 @@ from jax._src.core import (
   no_effects as no_effects,
   primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
   pytype_aval_mappings as pytype_aval_mappings,
-  raise_to_shaped as raise_to_shaped,
   raise_to_shaped_mappings as raise_to_shaped_mappings,
   reset_trace_state as reset_trace_state,
   set_current_trace as set_current_trace,
@@ -124,6 +113,37 @@ from jax._src.core import (
 
 from jax._src import core as _src_core
 _deprecations = {
+    # Added 2024-12-10
+    "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, "
+                    "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
+                    _src_core.ClosedJaxpr),
+    "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.",
+                   _src_core.full_lower),
+    "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, "
+              "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
+              _src_core.Jaxpr),
+    "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, "
+                 "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
+                 _src_core.JaxprEqn),
+    "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, "
+                     "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
+                     _src_core.jaxpr_as_fun),
+    "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.",
+                     _src_core.lattice_join),
+    "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, "
+                "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
+                _src_core.Literal),
+    "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, "
+                  "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
+                  _src_core.Primitive),
+    "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.",
+                        _src_core.raise_to_shaped),
+    "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, "
+              "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
+              _src_core.Token),
+    "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
+            "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
+            _src_core.Var),
     # Added 2024-08-14
     "check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn),
     "check_type": ("jax.core.check_type is deprecated.", _src_core.check_type),
@@ -152,10 +172,21 @@ _deprecations = {
 
 import typing
 if typing.TYPE_CHECKING:
+  ClosedJaxpr = _src_core.ClosedJaxpr
+  Jaxpr = _src_core.Jaxpr
+  JaxprEqn = _src_core.JaxprEqn
+  Literal = _src_core.Literal
+  Primitive = _src_core.Primitive
+  Token = _src_core.Token
+  Var = _src_core.Var
   check_eqn = _src_core.check_eqn
   check_type = _src_core.check_type
   check_valid_jaxtype = _src_core.check_valid_jaxtype
+  full_lower = _src_core.full_lower
+  jaxpr_as_fun = _src_core.jaxpr_as_fun
+  lattice_join = _src_core.lattice_join
   non_negative_dim = _src_core.non_negative_dim
+  raise_to_shaped = _src_core.raise_to_shaped
 else:
   from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
   __getattr__ = _deprecation_getattr(__name__, _deprecations)
diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index d8774e932..b03c3a5b5 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -77,7 +77,7 @@ if RUNTIME_PATH and RUNTIME_PATH.exists():
   os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)
 
 
-mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p")
+mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p")
 mosaic_gpu_p.multiple_results = True
 
 
diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py
index f4fe0b904..6962ef78b 100644
--- a/jax/experimental/sparse/_lowerings.py
+++ b/jax/experimental/sparse/_lowerings.py
@@ -19,7 +19,7 @@ are used internally in GPU translation rules of higher-level primitives.
 
 from functools import partial
 
-from jax import core
+from jax._src import core
 from jax._src import dispatch
 from jax._src.interpreters import mlir
 from jax._src.lib import gpu_sparse
diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py
index 6c827325b..f9d28f5ff 100644
--- a/jax/experimental/sparse/nm.py
+++ b/jax/experimental/sparse/nm.py
@@ -14,7 +14,7 @@
 
 """N:M-sparsity associated primitives."""
 
-from jax import core
+from jax._src import core
 from jax._src import dispatch
 from jax._src.lax.lax import DotDimensionNumbers
 from jax._src.lib import gpu_sparse
diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py
index 286088eeb..3364c9be9 100644
--- a/tests/key_reuse_test.py
+++ b/tests/key_reuse_test.py
@@ -18,8 +18,8 @@ import operator
 
 import numpy as np
 import jax
-from jax import core
 import jax.numpy as jnp
+from jax._src import core
 from jax._src import prng
 from jax._src import random
 from jax._src import test_util as jtu
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 5ca87aae6..5bb090435 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3837,7 +3837,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
               constant_values= ((0.0, 0.0), (0.0, 0.0)))
 
     jaxpr = jax.make_jaxpr(trace_to_jaxpr)(x)
-    jax.core.jaxpr_as_fun(jaxpr)(x)
+    jax._src.core.jaxpr_as_fun(jaxpr)(x)
 
     jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap')
     jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap')  # doesn't crash