From 36d515ed2c2bc552182dd6bd87f2144e50b0c773 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 7 Mar 2025 19:24:31 -0500 Subject: [PATCH] A few more fixes for debug_info tests with direct_linearize. --- jax/_src/pjit.py | 1 - tests/debug_info_test.py | 14 +++++++------- tests/pmap_test.py | 3 --- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 041b8a07c..69cd8e809 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -29,7 +29,6 @@ import warnings import numpy as np from jax._src import api -from jax._src import ad_util from jax._src import api_util from jax._src import config from jax._src import core diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 0e10b255f..a39b53c3a 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -876,8 +876,6 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*-> \(tensor {jax.result_info = \"\"}"), ]) - @unittest.skipIf(config.use_direct_linearize.value, - 'broken with direct-linearize') # TODO(necula) def test_vjp_of_nested_jit(self): tracer_spy = TracerSpy() def my_f(x, y): @@ -898,6 +896,8 @@ class DebugInfoTest(jtu.JaxTestCase): # TODO(necula): result_paths "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", # TODO(necula): arg_names + "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=," + if config.use_direct_linearize.value else "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ @@ -1287,8 +1287,6 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None", ]) - @unittest.skipIf(config.use_direct_linearize.value, - 'broken with direct-linearize') # TODO(necula) def test_grad_scan(self): # Based on control_flow_test:testScanHigherOrderDifferentiation tracer_spy = TracerSpy() @@ -1328,6 +1326,8 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" + if config.use_direct_linearize.value else "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", ], expected_tracer_debug_infos=[ @@ -1597,8 +1597,6 @@ class DebugInfoTest(jtu.JaxTestCase): ], ) - @unittest.skipIf(config.use_direct_linearize.value, - 'broken with direct-linearize') # TODO(necula) def test_hessian(self): tracer_spy = TracerSpy() @@ -1614,8 +1612,10 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): arg_names and result_paths? - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", + "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," + if config.use_direct_linearize.value else + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 0bddcaa78..af2d03e29 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -51,9 +51,6 @@ from jax._src.interpreters import pxla from jax._src.lax import parallel from jax._src.lib import xla_extension from jax._src.util import safe_map, safe_zip -from jax._src import util -from jax.api_util import flatten_fun_nokwargs, debug_info -from jax._src import linear_util as lu config.parse_flags_with_absl() jtu.request_cpu_devices(8)