mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
A few more fixes for debug_info tests with direct_linearize.
This commit is contained in:
parent
04696b4d7b
commit
36d515ed2c
@ -29,7 +29,6 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
|
@ -876,8 +876,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
re.compile(r".*func.func public @main\(.*-> \(tensor<f..> {jax.result_info = \"\"}"),
|
||||
])
|
||||
|
||||
@unittest.skipIf(config.use_direct_linearize.value,
|
||||
'broken with direct-linearize') # TODO(necula)
|
||||
def test_vjp_of_nested_jit(self):
|
||||
tracer_spy = TracerSpy()
|
||||
def my_f(x, y):
|
||||
@ -898,6 +896,8 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
# TODO(necula): result_paths
|
||||
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
|
||||
# TODO(necula): arg_names
|
||||
"traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=,"
|
||||
if config.use_direct_linearize.value else
|
||||
"traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']",
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
@ -1287,8 +1287,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
"traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None",
|
||||
])
|
||||
|
||||
@unittest.skipIf(config.use_direct_linearize.value,
|
||||
'broken with direct-linearize') # TODO(necula)
|
||||
def test_grad_scan(self):
|
||||
# Based on control_flow_test:testScanHigherOrderDifferentiation
|
||||
tracer_spy = TracerSpy()
|
||||
@ -1328,6 +1326,8 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
"traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=",
|
||||
"traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,",
|
||||
"traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,",
|
||||
"traced_for=jit, fun=my_f, arg_names=as_,,, result_paths="
|
||||
if config.use_direct_linearize.value else
|
||||
"traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=",
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
@ -1597,8 +1597,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.skipIf(config.use_direct_linearize.value,
|
||||
'broken with direct-linearize') # TODO(necula)
|
||||
def test_hessian(self):
|
||||
tracer_spy = TracerSpy()
|
||||
|
||||
@ -1614,8 +1612,10 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
|
||||
# TODO(necula): arg_names and result_paths?
|
||||
"traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
|
||||
"traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,",
|
||||
"traced_for=jit, fun=my_f, arg_names=x,, result_paths=,"
|
||||
if config.use_direct_linearize.value else
|
||||
"traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
|
||||
],
|
||||
tracer_spy=tracer_spy,
|
||||
expected_tracer_debug_infos=[
|
||||
|
@ -51,9 +51,6 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.lax import parallel
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax._src import util
|
||||
from jax.api_util import flatten_fun_nokwargs, debug_info
|
||||
from jax._src import linear_util as lu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(8)
|
||||
|
Loading…
x
Reference in New Issue
Block a user