|
|
|
@ -100,11 +100,12 @@ class TracerSpy:
|
|
|
|
|
try:
|
|
|
|
|
# We plan to do boolean conversion and catch the exception, but this works
|
|
|
|
|
# only for scalars
|
|
|
|
|
if t.shape:
|
|
|
|
|
t = jnp.sum(t)
|
|
|
|
|
if t:
|
|
|
|
|
t_scalar = t
|
|
|
|
|
while t_scalar.shape:
|
|
|
|
|
t_scalar = t_scalar[0]
|
|
|
|
|
if t_scalar:
|
|
|
|
|
pass
|
|
|
|
|
assert False, t
|
|
|
|
|
assert False, t_scalar
|
|
|
|
|
except Exception as e:
|
|
|
|
|
self.tracers.append((t, e))
|
|
|
|
|
|
|
|
|
@ -118,7 +119,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
tracer_spy: TracerSpy | None = None,
|
|
|
|
|
expected_tracer_debug_infos: list[str | re.Pattern] = [],
|
|
|
|
|
check_lowering: bool = True,
|
|
|
|
|
check_tracer_arg_name: bool = False,
|
|
|
|
|
expected_lowering_lines: list[str | re.Pattern] = [],
|
|
|
|
|
**kwargs) -> None:
|
|
|
|
|
"""Checks the expected debug info in all jaxprs, in spied tracers, and StableHLO.
|
|
|
|
@ -172,11 +172,14 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
for t, exc in tracer_spy.tracers:
|
|
|
|
|
if hasattr(t, "_debug_info"):
|
|
|
|
|
t_debug_info = _debug_info_to_string(t._debug_info)
|
|
|
|
|
if check_tracer_arg_name:
|
|
|
|
|
msg = str(exc)
|
|
|
|
|
m = re.match(r".* while tracing the function (.+) for ([^.]+)\.",
|
|
|
|
|
msg,
|
|
|
|
|
re.DOTALL)
|
|
|
|
|
msg = str(exc)
|
|
|
|
|
m = re.match(r".* while tracing the function (.+) for ([^.]+)\.",
|
|
|
|
|
msg,
|
|
|
|
|
re.DOTALL)
|
|
|
|
|
if m is None:
|
|
|
|
|
found_tracer_debug_infos.append(
|
|
|
|
|
f"{t_debug_info}, from None")
|
|
|
|
|
else:
|
|
|
|
|
self.assertIsNotNone(m, msg)
|
|
|
|
|
self.assertEqual(t._debug_info.func_src_info, m.group(1))
|
|
|
|
|
self.assertEqual(t._debug_info.traced_for, m.group(2))
|
|
|
|
@ -185,8 +188,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
re.DOTALL)
|
|
|
|
|
found_tracer_debug_infos.append(
|
|
|
|
|
f"{t_debug_info}, from {m.group(1) if m else None}")
|
|
|
|
|
else:
|
|
|
|
|
found_tracer_debug_infos.append(t_debug_info)
|
|
|
|
|
else:
|
|
|
|
|
found_tracer_debug_infos.append("None")
|
|
|
|
|
|
|
|
|
@ -646,9 +647,8 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
dict(a=1, b=2), 3,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=['c'],['d']"
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=result['c'],result['d']"
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, from x_dict['a']",
|
|
|
|
|
])
|
|
|
|
@ -669,10 +669,10 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
3,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
'traced_for=jit, fun=my_f, arg_names=a, result_paths=',
|
|
|
|
|
'traced_for=jit, fun=my_g, arg_names=b, result_paths=',
|
|
|
|
|
# TODO(necula): result_paths?
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=a, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=b, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=b, from b",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=a, from a",
|
|
|
|
@ -690,10 +690,9 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jax.jit(f),
|
|
|
|
|
{"ho": 1.}, {"hi": 2.}, 3., 4., z=5., w=6.,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=",
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from kwargs['w']",
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from args[0]",
|
|
|
|
@ -703,7 +702,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['w'\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(\"kwargs\['z'\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -721,10 +720,9 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
(1.,), {"hi": 2.}, 3., 4., 5., 6., # x, y, z, args[0], args[1], args[2]
|
|
|
|
|
t=11., w=12., # kwargs
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']",
|
|
|
|
|
],
|
|
|
|
@ -732,7 +730,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"y\['hi'\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['t'\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_jit_arg_names_static_argnames(self):
|
|
|
|
@ -748,10 +746,9 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
(1.,), {'hi': 2.}, 3., 4., # x, y, args[0], args[1]
|
|
|
|
|
z=5., w=6., a=7., b=8., # kwargs
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=",
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], from x[0]",
|
|
|
|
|
],
|
|
|
|
@ -760,7 +757,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['b'\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(\"kwargs\['w'\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_jit_result_info(self):
|
|
|
|
@ -770,13 +767,13 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jax.jit(f),
|
|
|
|
|
1., (2.,), [3.],
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=['a'],['b'][0][0]",
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=result['a'],result['b'][0][0]",
|
|
|
|
|
],
|
|
|
|
|
expected_lowering_lines=[
|
|
|
|
|
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"x\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"y\[0\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\['a'\]\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\['b'\]\[0\]\[0\]\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['a'\]\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['b'\]\[0\]\[0\]\"\}"),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_nested_jit(self):
|
|
|
|
@ -795,14 +792,35 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
2, 3,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=[\'c\']"
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']"
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v"
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y, from x",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v, from u"
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_nested_jit_with_const_and_unused_args(self):
|
|
|
|
|
def my_f(x, y): # y is unused
|
|
|
|
|
def my_g(u, v): # v is unused
|
|
|
|
|
return v + np.ones(v.shape, v.dtype)
|
|
|
|
|
|
|
|
|
|
return x + jax.jit(my_g)(y, x)
|
|
|
|
|
|
|
|
|
|
x = y = np.ones((8,), dtype=np.float32)
|
|
|
|
|
self._check_tracers_and_jaxprs(
|
|
|
|
|
jax.jit(my_f),
|
|
|
|
|
x, y,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result"
|
|
|
|
|
],
|
|
|
|
|
expected_lowering_lines=[
|
|
|
|
|
re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> loc\(\"x\"\)\)"),
|
|
|
|
|
re.compile(r".*call @my_g\(%arg.\) : \(tensor<8xf..>\)"),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def test_jvp_of_jit(self):
|
|
|
|
|
tracer_spy = TracerSpy()
|
|
|
|
|
def f(x, y, z):
|
|
|
|
@ -813,11 +831,11 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
# TODO(necula): arg_names, result_paths
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=None,None,None,None, result_paths=,,,",
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=,,,, result_paths=,,,",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x,y[0],z[0]",
|
|
|
|
|
"traced_for=jit, fun=f, arg_names=x,y[0],z[0], from x",
|
|
|
|
|
],
|
|
|
|
|
expected_lowering_lines=[
|
|
|
|
|
# TODO(necula): missing arg_names
|
|
|
|
@ -840,13 +858,12 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])),
|
|
|
|
|
jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=",
|
|
|
|
|
re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths="),
|
|
|
|
|
# TODO(necula): arg_names?
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=None,None,None,None, result_paths=['a'],['b'][0][0]",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=result",
|
|
|
|
|
re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths=result"),
|
|
|
|
|
# TODO(necula): arg_names? result_paths?
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=,,,, result_paths=['a'],['b'][0][0]",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y[0],z[0], from y[0]",
|
|
|
|
|
],
|
|
|
|
@ -854,6 +871,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
# TODO(necula): missing arg_names
|
|
|
|
|
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(unknown\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(unknown\)"),
|
|
|
|
|
# TODO(necula): result_paths?
|
|
|
|
|
re.compile(r".*func.func public @main\(.*-> \(tensor<f..> {jax.result_info = \"\"}"),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
@ -873,13 +891,12 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
2., 3., 0.3,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=result[0],result[1]",
|
|
|
|
|
# TODO(necula): result_paths
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
|
|
|
|
|
# TODO(necula): arg_names
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=None,None,u,v, result_paths=['c'],['d']",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): missing debug info
|
|
|
|
|
"None",
|
|
|
|
@ -889,8 +906,8 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"x\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"y\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"res_ct\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*jax.result_info = \"\[0\]\"}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[0\]\"}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[1\]\"}"),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_vjp_remat(self):
|
|
|
|
@ -909,11 +926,10 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
# TODO(necula): what are these flat_index components?
|
|
|
|
|
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][0][0][0][0][0]",
|
|
|
|
|
re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
|
|
|
|
|
re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
|
|
|
|
|
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=result[0],result[1][0][0][0][0][0]",
|
|
|
|
|
re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"),
|
|
|
|
|
re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"),
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x",
|
|
|
|
|
"traced_for=jit, fun=apply_fn, arg_names=inp, from inp",
|
|
|
|
@ -942,11 +958,11 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
42.,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=a, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=a, result_paths=result[0],result[1]",
|
|
|
|
|
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): from None?
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=a, from None",
|
|
|
|
|
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y",
|
|
|
|
|
])
|
|
|
|
@ -976,10 +992,9 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
# TODO(necula): arg_names
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=None,a,b, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=custom_jvp fun, fun=my_g, arg_names=None,xy[0],xy[1], result_paths=",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=,a,b, result_paths=result[0],result[1]",
|
|
|
|
|
"traced_for=custom_jvp fun, fun=my_g, arg_names=,xy[0],xy[1], result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]",
|
|
|
|
|
# TODO(necula): from None
|
|
|
|
@ -1010,10 +1025,9 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
{"a" : 3.},
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=['a']",
|
|
|
|
|
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=['b']",
|
|
|
|
|
"traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=result['a']",
|
|
|
|
|
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=result['b']",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']",
|
|
|
|
|
# TODO(necula): from None?
|
|
|
|
@ -1041,10 +1055,9 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
(3., 3.),
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], result_paths=[0],[1]",
|
|
|
|
|
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], result_paths=result[0],result[1]",
|
|
|
|
|
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], from xy[0]",
|
|
|
|
|
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]",
|
|
|
|
@ -1097,11 +1110,10 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
x,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=",
|
|
|
|
|
"traced_for=cond, fun=<lambda>, arg_names=x['c'], result_paths=",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0][0][0],[0][0][1]",
|
|
|
|
|
"traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=result",
|
|
|
|
|
"traced_for=cond, fun=<lambda>, arg_names=x['c'], result_paths=result",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0][0][0],result[0][0][1]",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from r",
|
|
|
|
|
"traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from x['c']",
|
|
|
|
@ -1129,11 +1141,10 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
x,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0]['c']",
|
|
|
|
|
"traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=['b']",
|
|
|
|
|
"traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=['c']",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0]['c']",
|
|
|
|
|
"traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']",
|
|
|
|
|
"traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): from None?
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, from None",
|
|
|
|
@ -1167,11 +1178,11 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
xy,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=['a']",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=result['a']",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y']",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y']"
|
|
|
|
|
"traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']"
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_cond(self):
|
|
|
|
@ -1192,13 +1203,13 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
0,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=",
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result",
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=a,b",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=c,d"
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=a,b, from a",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=c,d, from c"
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_switch(self):
|
|
|
|
@ -1220,15 +1231,15 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
2,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=",
|
|
|
|
|
"traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=",
|
|
|
|
|
"traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
"traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=result",
|
|
|
|
|
"traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=result",
|
|
|
|
|
"traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=switch, fun=my_branch0, arg_names=x0",
|
|
|
|
|
"traced_for=switch, fun=my_branch1, arg_names=x1",
|
|
|
|
|
"traced_for=switch, fun=my_branch2, arg_names=x2"
|
|
|
|
|
"traced_for=switch, fun=my_branch0, arg_names=x0, from x0",
|
|
|
|
|
"traced_for=switch, fun=my_branch1, arg_names=x1, from x1",
|
|
|
|
|
"traced_for=switch, fun=my_branch2, arg_names=x2, from x2"
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_grad_cond_with_remat(self):
|
|
|
|
@ -1258,18 +1269,19 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
1., 2.,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
'traced_for=jit, fun=my_f, arg_names=x,y, result_paths=',
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
|
|
|
|
|
# TODO(necula): arg_names? result_paths?
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=None, result_paths=,",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=None, result_paths=,",
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=,",
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=, result_paths=,",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=, result_paths=,",
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result[0],result[1]",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result[0],result[1]",
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=,",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
'traced_for=cond, fun=my_true_branch, arg_names=a,b',
|
|
|
|
|
'traced_for=cond, fun=my_false_branch, arg_names=c,d',
|
|
|
|
|
'traced_for=checkpoint / remat, fun=my_g, arg_names=x,y',
|
|
|
|
|
"traced_for=cond, fun=my_true_branch, arg_names=a,b, from a",
|
|
|
|
|
"traced_for=cond, fun=my_false_branch, arg_names=c,d, from c",
|
|
|
|
|
# TODO(necula): from None
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None",
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_grad_scan(self):
|
|
|
|
@ -1302,30 +1314,29 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
c, as_,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]",
|
|
|
|
|
# TODO(necula): bad result paths
|
|
|
|
|
"traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]",
|
|
|
|
|
# TODO(necula): arg names, bad result paths
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,",
|
|
|
|
|
# TODO(necula): arg_names?
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=",
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=None,None,None, result_paths=,",
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None, result_paths=,",
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None, result_paths=",
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None,None,None,None,None, result_paths=,",
|
|
|
|
|
"traced_for=checkpoint / remat, fun=to_remat, arg_names=None,None,None, result_paths=,",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=None,None,x,as_, result_paths=",
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=,,, result_paths=,",
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,",
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=",
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,",
|
|
|
|
|
"traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=the_grad, arg_names=c,as_",
|
|
|
|
|
"traced_for=scan, fun=f, arg_names=c,a",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,as_",
|
|
|
|
|
# TODO(necula): arg_names
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2]",
|
|
|
|
|
"traced_for=jit, fun=the_grad, arg_names=c,as_, from c",
|
|
|
|
|
"traced_for=scan, fun=f, arg_names=c,a, from c",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,as_, from x",
|
|
|
|
|
# TODO(necula): arg_names, and "from x"
|
|
|
|
|
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], from refs[0]",
|
|
|
|
|
],
|
|
|
|
|
expected_lowering_lines=[
|
|
|
|
|
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"c\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.* -> .*tensor<f..> {jax.result_info = \"\[0\]\""),
|
|
|
|
|
re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"\[1\]\""),
|
|
|
|
|
re.compile(r".*func.func public @main\(.* -> .*tensor<f..> {jax.result_info = \"result\[0\]\""),
|
|
|
|
|
re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""),
|
|
|
|
|
# TODO(necula): unnamed function?
|
|
|
|
|
re.compile(r".*func.func private @None"),
|
|
|
|
|
])
|
|
|
|
@ -1348,11 +1359,10 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
0,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
'traced_for=while_body, fun=my_body, arg_names=b, result_paths=',
|
|
|
|
|
'traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=',
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
"traced_for=while_body, fun=my_body, arg_names=b, result_paths=result",
|
|
|
|
|
"traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=while_cond, fun=my_cond, arg_names=a, from a",
|
|
|
|
|
"traced_for=while_body, fun=my_body, arg_names=b, from b",
|
|
|
|
@ -1370,14 +1380,14 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
3.,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
|
|
|
|
|
# TODO(necula): bad arg_names, result_paths
|
|
|
|
|
'traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=[0][0],[0][1]',
|
|
|
|
|
"traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=result[0][0],result[0][1]",
|
|
|
|
|
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): the arg_names are not right
|
|
|
|
|
"traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1]",
|
|
|
|
|
"traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], from loop_carry[1]",
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -1389,14 +1399,14 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
5, 3.,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=ub,x, result_paths=",
|
|
|
|
|
re.compile(r'traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths='),
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=ub,x, result_paths=result",
|
|
|
|
|
re.compile(r"traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths="),
|
|
|
|
|
# TODO(necula): arg_names and result_paths are not right
|
|
|
|
|
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=[0],[1],[2]",
|
|
|
|
|
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=result[0],result[1],result[2]",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): the arg_names are not right
|
|
|
|
|
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2]",
|
|
|
|
|
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], from loop_carry[2]",
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_scan(self):
|
|
|
|
@ -1413,10 +1423,9 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
np.arange(8, dtype=np.int32),
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result[0],result[1]",
|
|
|
|
|
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=result[0],result[1]",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry"
|
|
|
|
|
])
|
|
|
|
@ -1433,7 +1442,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x"],
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, from x"],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def test_vmap_of_nested_jit(self):
|
|
|
|
@ -1452,13 +1461,13 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
np.ones((8,), dtype=np.float32), np.zeros((8,), dtype=np.float32),
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=['c']",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): missing debug info
|
|
|
|
|
"None",
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v"
|
|
|
|
|
"traced_for=jit, fun=my_g, arg_names=u,v, from u"
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_pmap(self):
|
|
|
|
@ -1471,11 +1480,11 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jax.pmap(my_f),
|
|
|
|
|
np.ones((jax.device_count(),), dtype=np.float32),
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x, result_paths="
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x, result_paths=result"
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x"
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x, from x"
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -1493,10 +1502,9 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
1., x, x, x, # x, y, args[0], args[1]
|
|
|
|
|
d=x, a=x, b=x, # kwargs
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=['u'],['v']",
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]",
|
|
|
|
|
],
|
|
|
|
@ -1508,8 +1516,8 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['u'\]\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['v'\]\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"),
|
|
|
|
|
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -1523,7 +1531,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jax.pmap(jax.grad(my_f)),
|
|
|
|
|
np.ones((jax.device_count(),), dtype=np.float32),
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
@ -1550,12 +1558,12 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
# TODO(necula): why this?
|
|
|
|
|
re.compile(r'traced_for=jit, fun=_multi_slice at .*array_methods.py:.*, arg_names=self, result_paths=.*'),
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=['u'],['v']",
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=result['u'],result['v']",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): missing debug_info
|
|
|
|
|
'None'
|
|
|
|
|
"None"
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -1574,13 +1582,13 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))),
|
|
|
|
|
x, x_tan,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
'traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=[0],[1]',
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=result[0],result[1]",
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): missing debug_info
|
|
|
|
|
'None'
|
|
|
|
|
"None"
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -1597,13 +1605,12 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jax.jit(jax.hessian(jax.jit(my_f))),
|
|
|
|
|
x,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
# TODO(necula): arg_names and result_paths?
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=None,x, result_paths=,",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,",
|
|
|
|
|
],
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, from x",
|
|
|
|
|
],
|
|
|
|
@ -1626,11 +1633,10 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
0.,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
# TODO(necula): missing result_paths
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=",
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y"
|
|
|
|
|
])
|
|
|
|
@ -1650,12 +1656,12 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
0.,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
# TODO(necula): arg_names?
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
# TODO(necula): arg_names? result_paths?
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=y"
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y",
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_remat_shard_map(self):
|
|
|
|
@ -1677,13 +1683,12 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jnp.arange(2, dtype=np.float32),
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
# TODO(necula): arg_names
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=",
|
|
|
|
|
"traced_for=shard_map, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=",
|
|
|
|
|
# TODO(necula): arg_names, result_paths
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
|
|
|
|
|
"traced_for=checkpoint / remat, fun=my_f, arg_names=,, result_paths=",
|
|
|
|
|
"traced_for=shard_map, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
"traced_for=shard_map, fun=my_f, arg_names=,, result_paths=",
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"None" # TODO(necula): missing
|
|
|
|
|
])
|
|
|
|
@ -1719,11 +1724,11 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
# TODO(necula): this should not be pointing into the JAX internals
|
|
|
|
|
re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"),
|
|
|
|
|
re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths="),
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=[0]",
|
|
|
|
|
re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths=result"),
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=result[0]",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=my_x",
|
|
|
|
|
"traced_for=pmap, fun=my_f, arg_names=my_x, from my_x",
|
|
|
|
|
],
|
|
|
|
|
check_lowering=False, # TODO(necula): warning during lowering
|
|
|
|
|
)
|
|
|
|
@ -1751,12 +1756,12 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
0.,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
|
|
|
|
|
"traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=result[0],result[1]",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): no leaked tracer from my_g_dce?
|
|
|
|
|
"traced_for=custom_dce, fun=my_g, arg_names=x",
|
|
|
|
|
"traced_for=custom_dce, fun=my_g, arg_names=x, from x",
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def test_custom_dce_consts(self):
|
|
|
|
@ -1779,11 +1784,10 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
np.array(1.1234),
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
|
|
|
|
|
# TODO(necula): bad arg_names (why None), bad result_paths
|
|
|
|
|
'traced_for=custom_dce, fun=my_f, arg_names=None,x, result_paths=[0],[1]',
|
|
|
|
|
'traced_for=custom_dce, fun=my_f, arg_names=,x, result_paths=result[0],result[1]',
|
|
|
|
|
],
|
|
|
|
|
check_tracer_arg_name=True,
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
# TODO(necula): no leaked tracer from my_rule?
|
|
|
|
|
"traced_for=custom_dce, fun=my_f, arg_names=x, from x",
|
|
|
|
@ -1816,23 +1820,22 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
a, b,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=a,b, result_paths=[0],[1]",
|
|
|
|
|
re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths="),
|
|
|
|
|
re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths="),
|
|
|
|
|
re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths="),
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=a,b, result_paths=result[0],result[1]",
|
|
|
|
|
re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths=result"),
|
|
|
|
|
re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths=result"),
|
|
|
|
|
re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths=result"),
|
|
|
|
|
# TODO(necula): why pointers to internal functions, arg_names, result_paths?
|
|
|
|
|
re.compile(r'traced_for=custom_linear_solve solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,None,x, result_paths='),
|
|
|
|
|
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,None,x, result_paths='),
|
|
|
|
|
re.compile(r'traced_for=custom_linear_solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,x, result_paths='),
|
|
|
|
|
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,x, result_paths='),
|
|
|
|
|
'traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=None,b, result_paths=',
|
|
|
|
|
'traced_for=custom_linear_solve solve, fun=my_solve, arg_names=None,x, result_paths=',
|
|
|
|
|
'traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=None,x, result_paths=',
|
|
|
|
|
re.compile(r'traced_for=custom_linear_solve solve, fun=<lambda> at .*linalg.py:.*, arg_names=,,x, result_paths='),
|
|
|
|
|
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,,x, result_paths='),
|
|
|
|
|
re.compile(r'traced_for=custom_linear_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,x, result_paths='),
|
|
|
|
|
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,x, result_paths='),
|
|
|
|
|
"traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=,b, result_paths=result",
|
|
|
|
|
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=,x, result_paths=result",
|
|
|
|
|
"traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=,x, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x",
|
|
|
|
|
"traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x",
|
|
|
|
|
"traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=b",
|
|
|
|
|
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x, from x",
|
|
|
|
|
"traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x, from x",
|
|
|
|
|
"None", # TODO(necula): there are missing debug info
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
@ -1856,14 +1859,15 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
0.,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0],result[1]",
|
|
|
|
|
# TODO(necula): internal function?
|
|
|
|
|
re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=\[0\]"),
|
|
|
|
|
re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=result\[0\]"),
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=custom_root, fun=my_f, arg_names=x",
|
|
|
|
|
"traced_for=custom_root solve, fun=my_solve, arg_names=x",
|
|
|
|
|
"traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x",
|
|
|
|
|
"traced_for=custom_root, fun=my_f, arg_names=x, from x",
|
|
|
|
|
"traced_for=custom_root solve, fun=my_solve, arg_names=x, from x",
|
|
|
|
|
# TODO(necula): from None
|
|
|
|
|
"traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x, from None",
|
|
|
|
|
"None", # TODO(necula): there are missing debug info
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
@ -1891,13 +1895,13 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jax.jit(my_f), x,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=[0],[1]",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
|
|
|
|
|
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=result[0],result[1]",
|
|
|
|
|
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, result_paths=",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j",
|
|
|
|
|
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref",
|
|
|
|
|
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, from i",
|
|
|
|
|
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, from x_ref",
|
|
|
|
|
],
|
|
|
|
|
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
|
|
|
|
|
)
|
|
|
|
@ -1921,14 +1925,14 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jnp.arange(4, dtype=jnp.float32) - 2,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=input, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_f, arg_names=input, result_paths=result",
|
|
|
|
|
# TODO(necula): function source location points in JAX internals
|
|
|
|
|
# TODO(necula): arg_names and result_paths are wrong
|
|
|
|
|
re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="),
|
|
|
|
|
re.compile(r"traced_for=pallas_call index_map, fun=<lambda> at .*pallas.core.py:.*, arg_names=, result_paths="),
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref",
|
|
|
|
|
"traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref, from x_ref",
|
|
|
|
|
],
|
|
|
|
|
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
|
|
|
|
|
)
|
|
|
|
@ -1948,11 +1952,11 @@ class DebugInfoTest(jtu.JaxTestCase):
|
|
|
|
|
jax.jit(my_consts), x,
|
|
|
|
|
tracer_spy=tracer_spy,
|
|
|
|
|
expected_jaxpr_debug_infos=[
|
|
|
|
|
"traced_for=jit, fun=my_consts, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=composite, fun=my_consts, arg_names=x, result_paths=",
|
|
|
|
|
"traced_for=jit, fun=my_consts, arg_names=x, result_paths=result",
|
|
|
|
|
"traced_for=composite, fun=my_consts, arg_names=x, result_paths=result",
|
|
|
|
|
],
|
|
|
|
|
expected_tracer_debug_infos=[
|
|
|
|
|
"traced_for=composite, fun=my_consts, arg_names=x"])
|
|
|
|
|
"traced_for=composite, fun=my_consts, arg_names=x, from x"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|