A few more fixes for debug_info tests with direct_linearize.

This commit is contained in:
Dan Foreman-Mackey 2025-03-07 19:24:31 -05:00
parent 04696b4d7b
commit 36d515ed2c
3 changed files with 7 additions and 11 deletions

View File

@ -29,7 +29,6 @@ import warnings
import numpy as np
from jax._src import api
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core

View File

@ -876,8 +876,6 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(.*-> \(tensor<f..> {jax.result_info = \"\"}"),
])
@unittest.skipIf(config.use_direct_linearize.value,
'broken with direct-linearize') # TODO(necula)
def test_vjp_of_nested_jit(self):
tracer_spy = TracerSpy()
def my_f(x, y):
@ -898,6 +896,8 @@ class DebugInfoTest(jtu.JaxTestCase):
# TODO(necula): result_paths
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
# TODO(necula): arg_names
"traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=,"
if config.use_direct_linearize.value else
"traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']",
],
expected_tracer_debug_infos=[
@ -1287,8 +1287,6 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None",
])
@unittest.skipIf(config.use_direct_linearize.value,
'broken with direct-linearize') # TODO(necula)
def test_grad_scan(self):
# Based on control_flow_test:testScanHigherOrderDifferentiation
tracer_spy = TracerSpy()
@ -1328,6 +1326,8 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=",
"traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,",
"traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=as_,,, result_paths="
if config.use_direct_linearize.value else
"traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=",
],
expected_tracer_debug_infos=[
@ -1597,8 +1597,6 @@ class DebugInfoTest(jtu.JaxTestCase):
],
)
@unittest.skipIf(config.use_direct_linearize.value,
'broken with direct-linearize') # TODO(necula)
def test_hessian(self):
tracer_spy = TracerSpy()
@ -1614,8 +1612,10 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
# TODO(necula): arg_names and result_paths?
"traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,",
"traced_for=jit, fun=my_f, arg_names=x,, result_paths=,"
if config.use_direct_linearize.value else
"traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[

View File

@ -51,9 +51,6 @@ from jax._src.interpreters import pxla
from jax._src.lax import parallel
from jax._src.lib import xla_extension
from jax._src.util import safe_map, safe_zip
from jax._src import util
from jax.api_util import flatten_fun_nokwargs, debug_info
from jax._src import linear_util as lu
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)