diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 041b8a07c..69cd8e809 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -29,7 +29,6 @@ import warnings
 import numpy as np
 
 from jax._src import api
-from jax._src import ad_util
 from jax._src import api_util
 from jax._src import config
 from jax._src import core
diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py
index 0e10b255f..a39b53c3a 100644
--- a/tests/debug_info_test.py
+++ b/tests/debug_info_test.py
@@ -876,8 +876,6 @@ class DebugInfoTest(jtu.JaxTestCase):
             re.compile(r".*func.func public @main\(.*-> \(tensor<f..> {jax.result_info = \"\"}"),
         ])
 
-  @unittest.skipIf(config.use_direct_linearize.value,
-                   'broken with direct-linearize')  # TODO(necula)
   def test_vjp_of_nested_jit(self):
     tracer_spy = TracerSpy()
     def my_f(x, y):
@@ -898,6 +896,8 @@ class DebugInfoTest(jtu.JaxTestCase):
             # TODO(necula): result_paths
             "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
             # TODO(necula): arg_names
+            "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=,"
+            if config.use_direct_linearize.value else
             "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']",
         ],
         expected_tracer_debug_infos=[
@@ -1287,8 +1287,6 @@ class DebugInfoTest(jtu.JaxTestCase):
             "traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None",
         ])
 
-  @unittest.skipIf(config.use_direct_linearize.value,
-                   'broken with direct-linearize')  # TODO(necula)
   def test_grad_scan(self):
       # Based on control_flow_test:testScanHigherOrderDifferentiation
     tracer_spy = TracerSpy()
@@ -1328,6 +1326,8 @@ class DebugInfoTest(jtu.JaxTestCase):
             "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=",
             "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,",
             "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,",
+            "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths="
+            if config.use_direct_linearize.value else
             "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=",
         ],
         expected_tracer_debug_infos=[
@@ -1597,8 +1597,6 @@ class DebugInfoTest(jtu.JaxTestCase):
         ],
     )
 
-  @unittest.skipIf(config.use_direct_linearize.value,
-                   'broken with direct-linearize')  # TODO(necula)
   def test_hessian(self):
     tracer_spy = TracerSpy()
 
@@ -1614,8 +1612,10 @@ class DebugInfoTest(jtu.JaxTestCase):
         expected_jaxpr_debug_infos=[
             "traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
             # TODO(necula): arg_names and result_paths?
-            "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
             "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,",
+            "traced_for=jit, fun=my_f, arg_names=x,, result_paths=,"
+            if config.use_direct_linearize.value else
+            "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
         ],
         tracer_spy=tracer_spy,
         expected_tracer_debug_infos=[
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 0bddcaa78..af2d03e29 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -51,9 +51,6 @@ from jax._src.interpreters import pxla
 from jax._src.lax import parallel
 from jax._src.lib import xla_extension
 from jax._src.util import safe_map, safe_zip
-from jax._src import util
-from jax.api_util import flatten_fun_nokwargs, debug_info
-from jax._src import linear_util as lu
 
 config.parse_flags_with_absl()
 jtu.request_cpu_devices(8)