[better_errors] Add more debug info test coverage

Try to cover the tracing of almost all JAX higher-order
primitives. Some of the added tests reveal missing debug info;
those cases are marked with TODO. Fixes will come separately.

Had to expand the helper function _check_tracers_and_jaxprs to
accept regular expressions for matching, because some debug info
still contains non-deterministic elements (e.g., source file paths
and line numbers).
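
A minimal illustrative sketch (not part of this commit) of how the expected
lists can now mix exact strings with compiled patterns; the regex below is
only an example:

    import re

    expected_jaxpr_debug_infos = [
        # Exact strings are compared with ==.
        "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
        # re.Pattern entries are matched with .match(), which tolerates
        # non-deterministic parts such as file paths and line numbers.
        re.compile(r"traced_for=jit, fun=checked_fun at .*checkify\.py:.*"),
    ]

    def _matches(expected, found: str) -> bool:
        # Mirrors the new _check_matches logic: patterns use .match(),
        # plain strings use equality.
        return (expected.match(found) is not None
                if isinstance(expected, re.Pattern) else expected == found)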
Author: George Necula  2025-01-25 07:34:26 +02:00
parent 55efd4b225
commit e4d5427d13
3 changed files with 409 additions and 47 deletions

tests/BUILD

@@ -46,6 +46,7 @@ jax_multiplatform_test(
srcs = ["debug_info_test.py"],
enable_configs = ["tpu_v3_2x2"],
deps = [
"//jax:experimental",
"//jax:pallas",
"//jax:pallas_gpu",
"//jax:pallas_gpu_ops",

tests/api_test.py

@@ -5830,7 +5830,7 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' sin '), 3)
self.assertEqual(jaxpr_text.count(' cos '), 3)
# Six calls to dot_general in the backward pass because we save the primal
# matmuls and only compure the backward pass ones (two for each primal one).
# matmuls and only compute the backward pass ones (two for each primal one).
self.assertEqual(jaxpr_text.count(' dot_'), 6)
jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,

tests/debug_info_test.py

@@ -23,17 +23,27 @@ from typing import Any
from absl.testing import absltest, parameterized
import jax
from jax import ad_checkpoint
from jax import lax
import jax.custom_batching
import jax.custom_derivatives
import jax.custom_transpose
from jax.experimental import checkify
import jax.experimental.custom_dce
from jax.experimental import pallas as pl
import jax.numpy as jnp
import jax.scipy as jsp
from jax._src import api_util
from jax._src.ad_checkpoint import saved_residuals
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src.compilation_cache import is_persistent_cache_enabled
import jax.custom_batching
import jax.custom_derivatives
import jax.custom_transpose
from jax.experimental import pallas as pl
import jax.numpy as jnp
from jax._src.lax.control_flow import for_loop
import numpy as np
config.parse_flags_with_absl()
@@ -68,19 +78,25 @@ def _debug_info_to_string(dbg: api_util.TracingDebugInfo | core.JaxprDebugInfo |
return res
@jtu.with_config(jax_mutable_array_checks=True)
class DebugInfoTest(jtu.JaxTestCase):
def _check_tracers_and_jaxprs(self, traceable: Any,
*args,
expected_jaxpr_debug_infos: list[str],
expected_jaxpr_debug_infos: list[str | re.Pattern],
leaked_tracers: list[core.Tracer] = [],
expected_tracer_debug_infos: list[str] = [],
expected_tracer_debug_infos: list[str | re.Pattern] = [],
check_lowering: bool = True,
**kwargs):
"""Checks for expected debug info in all jaxprs, and in leaked tracers.
The `traceable.trace(*args, **kwargs)` is traced to a Jaxpr, and the
debug infos in the nested Jaxprs are first converted to strings using
`_debug_info_to_string` and then compared against `expected_jaxpr_debug_infos`.
An element of `expected_jaxpr_debug_infos` can be a string, in which case
it is looked up by equality, or a `re.Pattern` (the result of `re.compile`)
in which case it is looked up by `.match()`. All elements of
`expected_jaxpr_debug_infos` must appear, and all Jaxprs must be matched.
One way in which the debug info is used in JAX is for leaked tracer
description, or for ConcretizationErrors. Optionally,
@@ -91,17 +107,39 @@ class DebugInfoTest(jtu.JaxTestCase):
traced = traceable.trace(*args, **kwargs)
all_jaxprs = _collect_jaxprs(traced.jaxpr.jaxpr)
jaxprs_debug_infos = [_debug_info_to_string(j.debug_info)
for j in all_jaxprs]
self.assertEqual(expected_jaxpr_debug_infos, jaxprs_debug_infos)
found_jaxprs_debug_infos = [_debug_info_to_string(j.debug_info)
for j in all_jaxprs]
self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos)
self._check_tracers(leaked_tracers, expected_tracer_debug_infos)
# Run the lowering because this one exercises more code with debug_info
# TODO(necula): check the lowering
if check_lowering:
traced.lower()
def _check_tracers(self,
leaked_tracers: list[core.Tracer],
expected_tracer_debug_infos: list[str]):
leaked_tracer_debug_infos = [_debug_info_to_string(t._debug_info) if hasattr(t, "_debug_info") else "None"
for t in leaked_tracers]
self.assertEqual(expected_tracer_debug_infos, leaked_tracer_debug_infos)
expected_tracer_debug_infos: list[str | re.Pattern]):
found_leaked_tracer_debug_infos = [
_debug_info_to_string(t._debug_info) if hasattr(t, "_debug_info") else "None"
for t in leaked_tracers]
self._check_matches(expected_tracer_debug_infos, found_leaked_tracer_debug_infos)
def _check_matches(self,
expected: list[str | re.Pattern],
found: list[str]):
expected_and_found = set()
unexpected: set[str] = set()
for debug_info in found:
for exp_re in expected:
ok = exp_re.match(debug_info) if isinstance(exp_re, re.Pattern) else exp_re == debug_info
if ok:
expected_and_found.add(exp_re)
break
else:
unexpected.add(debug_info)
self.assertEmpty(unexpected) # found unexpected debug_info
self.assertEmpty([e for e in expected if e not in expected_and_found]) # expected element that was not found
def test_debug_info_basic(self):
def my_f(x, y, z, w):
@@ -556,7 +594,6 @@ class DebugInfoTest(jtu.JaxTestCase):
def test_simple_jit(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x_dict, y):
leaked_tracers.append(x_dict["a"])
return dict(c=x_dict["a"] + x_dict["b"], d=y)
@@ -572,9 +609,34 @@ class DebugInfoTest(jtu.JaxTestCase):
'traced_for=jit, fun=my_f, arg_names=("x_dict[\'a\']", "x_dict[\'b\']", \'y\')',
])
def test_jit_with_static_argnums(self):
leaked_tracers: list[core.Tracer] = []
@functools.partial(jax.jit, static_argnums=(1,))
def my_f(a, d):
leaked_tracers.append(a)
return a
def my_g(a, d=1):
leaked_tracers.append(a)
return my_f(a, d)
self._check_tracers_and_jaxprs(
jax.jit(my_g),
3,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_g, arg_names=('a',), result_paths=('',)",
"traced_for=jit, fun=my_f, arg_names=('a',), result_paths=()"
],
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_g, arg_names=('a',)",
# TODO(necula): bad arg name
"traced_for=jit, fun=my_f, arg_names=('args[0]',)"
])
def test_nested_jit(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x, y):
leaked_tracers.append(x)
@@ -599,7 +661,6 @@ class DebugInfoTest(jtu.JaxTestCase):
def test_vjp_of_nested_jit(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x, y):
leaked_tracers.append(x)
@@ -616,8 +677,7 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=('x', 'y', 'res_ct'), result_paths=('[0]', '[1]')",
# TODO(necula): missing debug info
'None',
'None'
"None",
],
expected_tracer_debug_infos=[
# TODO(necula): missing debug info
@@ -627,7 +687,6 @@ class DebugInfoTest(jtu.JaxTestCase):
def test_vmap_of_nested_jit(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x, y):
leaked_tracers.append(x)
@@ -652,9 +711,38 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=jit, fun=my_g, arg_names=('u', 'v')"
])
def test_custom_vmap(self):
leaked_tracers: list[core.Tracer] = []
@jax.custom_batching.custom_vmap
def my_f(xdict):
x = xdict["x"]
leaked_tracers.append(x)
return dict(a=jnp.sin(x))
@my_f.def_vmap
def my_rule(axis_size, in_batched, xys):
xs = xys["x"]
leaked_tracers.append(xs)
xs_batched, = in_batched
self.assertEqual(xs_batched["x"], True)
self.assertEqual(axis_size, xs.shape[0])
return dict(a=jnp.cos(xs)), dict(a=xs_batched["x"])
xy = dict(x=np.ones((8,), dtype=np.float32), y=np.zeros((8,), dtype=np.float32))
self._check_tracers_and_jaxprs(
jax.jit(jax.vmap(my_f)),
xy,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\"), result_paths=(\"[\'a\']\",)",
],
expected_tracer_debug_infos=[
"traced_for=custom_vmap, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\")",
"traced_for=jit, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\")"
])
def test_cond(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
def my_true_branch(a, b):
leaked_tracers.append(a)
@@ -673,16 +761,42 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
# TODO(necula): some Jaxprs without debug info
'None',
'None'],
"None"],
expected_tracer_debug_infos=[
"traced_for=cond, fun=my_true_branch, arg_names=('a', 'b')",
"traced_for=cond, fun=my_false_branch, arg_names=('c', 'd')"
])
def test_switch(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
def my_branch0(x0):
leaked_tracers.append(x0)
return x0
def my_branch1(x1):
leaked_tracers.append(x1)
return x1
def my_branch2(x2):
leaked_tracers.append(x2)
return x2
return lax.switch(x, [my_branch0, my_branch1, my_branch2], x)
self._check_tracers_and_jaxprs(
jax.jit(my_f),
2,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
# TODO(necula): some Jaxprs without debug info
"None"],
expected_tracer_debug_infos=[
"traced_for=switch, fun=my_branch0, arg_names=('x0',)",
"traced_for=switch, fun=my_branch1, arg_names=('x1',)",
"traced_for=switch, fun=my_branch2, arg_names=('x2',)"
])
def test_grad_cond_with_remat(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x, y):
# The cond branches return two things, and only the first is needed
# in the residuals.
@@ -711,12 +825,6 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=jit, fun=my_f, arg_names=('x', 'y'), result_paths=('',)",
# TODO(necula): some Jaxprs without debug info
'None',
'None',
'None',
'None',
'None',
'None',
'None'
],
expected_tracer_debug_infos=[
"traced_for=cond, fun=my_true_branch, arg_names=('a', 'b')",
@@ -724,9 +832,46 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=checkpoint / remat, fun=my_g, arg_names=('x', 'y')"
])
def test_grad_scan(self):
# Based on control_flow_test:testScanHigherOrderDifferentiation
leaked_tracers: list[core.Tracer] = []
def f(c, a):
leaked_tracers.append(c)
d = 0.75
b = jnp.sin(c * jnp.sum(jnp.cos(d * a)))
c = 0.9 * jnp.cos(d * jnp.sum(jnp.sin(c * a)))
return c, b
as_ = jnp.arange(6.).reshape((3, 2))
c = jnp.array(1, dtype=as_.dtype)
@jax.jit
def my_f(x, as_):
leaked_tracers.append(x)
return jax.remat(lambda *args: for_loop.scan(f, *args))(c, as_)
def the_grad(c, as_):
leaked_tracers.append(c)
_, pullback = jax.vjp(my_f, c, as_)
return pullback((c, np.arange(3, dtype=c.dtype)))
self._check_tracers_and_jaxprs(
jax.jit(the_grad),
c, as_,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=the_grad, arg_names=('c', 'as_'), result_paths=('[0]', '[1]')",
'None', # TODO(necula): some Jaxprs without debug info
],
expected_tracer_debug_infos=[
"traced_for=jit, fun=the_grad, arg_names=('c', 'as_')",
"traced_for=scan, fun=f, arg_names=('c', 'a')",
"traced_for=jit, fun=my_f, arg_names=('x', 'as_')",
'None', # TODO(necula): some missing debug info
])
def test_while_loop(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
def my_cond(a):
leaked_tracers.append(a)
@@ -745,16 +890,35 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
# TODO(necula): some Jaxprs without debug info
'None',
'None'],
expected_tracer_debug_infos=[
"traced_for=while_cond, fun=my_cond, arg_names=('a',)",
"traced_for=while_loop, fun=my_body, arg_names=('b',)"
])
def test_scan(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
def my_scan_body(carry, inp):
leaked_tracers.append(carry)
return (carry + inp, carry)
return lax.scan(my_scan_body, 0, x)
self._check_tracers_and_jaxprs(
jax.jit(my_f),
np.arange(8, dtype=np.int32),
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('[0]', '[1]')",
# TODO(necula): some Jaxprs without debug info
'None'],
expected_tracer_debug_infos=[
"traced_for=scan, fun=my_scan_body, arg_names=('carry', 'inp')"
])
def test_eval_shape(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
leaked_tracers.append(x)
return x
@@ -766,7 +930,6 @@ class DebugInfoTest(jtu.JaxTestCase):
def test_pmap(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
leaked_tracers.append(x)
return jnp.sin(x)
@@ -785,18 +948,15 @@ class DebugInfoTest(jtu.JaxTestCase):
def test_pmap_of_grad(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
leaked_tracers.append(x)
return jnp.sin(x)
self._check_tracers_and_jaxprs(
jax.jit(jax.pmap(jax.grad(my_f))),
jax.pmap(jax.grad(my_f)),
np.ones((jax.device_count(),), dtype=np.float32),
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
# TODO(necula): missing debug_info
'None'
"traced_for=pmap, fun=my_f, arg_names=('x',), result_paths=('',)",
],
leaked_tracers=leaked_tracers,
expected_tracer_debug_infos=[
@@ -807,7 +967,6 @@ class DebugInfoTest(jtu.JaxTestCase):
def test_remat(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
@jax.remat
def my_g(y):
@@ -830,7 +989,6 @@ class DebugInfoTest(jtu.JaxTestCase):
def test_grad_remat(self):
leaked_tracers: list[core.Tracer] = []
def my_f(x):
@jax.remat
def my_g(y):
@@ -846,15 +1004,167 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
# TODO(necula): some Jaxprs without debug info
'None',
'None'],
"None"],
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=my_g, arg_names=('y',)"
])
def test_remat_saved_residuals(self):
@functools.partial(jax.remat,
static_argnums=(1,),
policy=lambda p, *_, **__: "mul" in str(p))
def my_f(x, y):
x = ad_checkpoint.checkpoint_name(x * x, "foo")
x = x * x
return x + y
res = saved_residuals(my_f, 3., 4.)
self.assertEqual(res[0][1], "from the argument x")
self.assertRegex(res[1][1], r"named 'foo' from .*debug_info_test.py:.*my_f")
def test_checkify_pmap_basic(self):
if len(jax.devices()) < 2:
self.skipTest("requires at least 2 devices")
leaked_tracers: list[core.Tracer] = []
@jax.pmap
def my_f(my_x):
leaked_tracers.append(my_x)
y1 = jnp.sin(1./my_x)
y2 = jnp.sin(my_x)
return (y1 + y2,)
self._check_tracers_and_jaxprs(
jax.jit(checkify.checkify(my_f, errors=checkify.nan_checks)),
np.arange(len(jax.devices()), dtype=np.float32),
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
# TODO(necula): this should not be pointing into the JAX internals
re.compile(r"traced_for=jit, fun=checked_fun at .*jax/_src/checkify.py:.*, arg_names=\(\'args\[0\]\',\)"),
re.compile(r"traced_for=jit, fun=argsort at .*numpy/lax_numpy.py:.*, arg_names=\('a',\), result_paths=\('',\)"),
"None", # TODO(necula): missing tracer debug info
],
expected_tracer_debug_infos=[
"traced_for=xla_pmap, fun=my_f, arg_names=('my_x',)",
],
check_lowering=False, # TODO(necula): warning during lowering
)
def test_custom_dce_static_argnums(self):
leaked_tracers: list[core.Tracer] = []
@functools.partial(jax.experimental.custom_dce.custom_dce,
static_argnums=(0,))
def my_g(f, x):
leaked_tracers.append(x)
return f(x), 10 * f(x)
@my_g.def_dce
def my_g_dce(f, used_outs, x): # note: static_argnums are always passed first
leaked_tracers.append(x)
self.assertTrue(callable(f))
return [2 * v if used else None
for used, v in zip(used_outs, my_g(f, x))]
def my_f(x):
return jnp.exp(x)
self._check_tracers_and_jaxprs(
jax.jit(lambda x: my_g(my_f, x)[0]),
0.,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=('x',), result_paths=('',)",
# TODO(necula): some Jaxprs without debug info
'None'],
expected_tracer_debug_infos=[
# TODO(necula): bad arg_names
"traced_for=custom_dce, fun=my_g, arg_names=('args[0]',)"
])
def test_custom_dce_consts(self):
leaked_tracers: list[core.Tracer] = []
@jax.experimental.custom_dce.custom_dce
def f(x):
return np.eye(1) * jnp.sin(x), jnp.cos(x)
@f.def_dce
def rule(used_outs, x):
return (
np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None,
jnp.sqrt(x) if used_outs[1] else None,
)
self._check_tracers_and_jaxprs(
jax.jit(lambda x: f(x)[0]),
np.array(1.1234),
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=('x',), result_paths=('',)",
# TODO(necula): some Jaxprs without debug info
'None'],
expected_tracer_debug_infos=[
# TODO(necula): bad arg_names
# "traced_for=custom_dce, fun=my_g, arg_names=('args[0]',)"
])
def test_custom_linear_solve_complex(self):
leaked_tracers: list[core.Tracer] = []
def solve(a, b):
leaked_tracers.append(a)
def solve(matvec, x):
leaked_tracers.append(x)
return jsp.linalg.solve(a, x)
def high_precision_dot(a, b):
return lax.dot(a, b, precision=lax.Precision.HIGHEST)
def tr_solve(matvec, x):
return jsp.linalg.solve(a.T, x)
matvec = functools.partial(high_precision_dot, a)
return lax.custom_linear_solve(matvec, b, solve, tr_solve)
rng = self.rng()
a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2)
b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2)
self._check_tracers_and_jaxprs(
jax.jit(lambda a, b: jax.jvp(solve, (a, b), (a, b))),
a, b,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=('a', 'b'), result_paths=('[0]', '[1]')",
re.compile(r"traced_for=jit, fun=_solve at .*scipy/linalg.py:.*, arg_names=\('a', 'b'\), result_paths=\('',\)"),
re.compile(r"traced_for=jit, fun=solve at .*/linalg.py:.*, arg_names=\('a', 'b'\), result_paths=\('',\)"),
re.compile(r"traced_for=jit, fun=_lu_solve at .*/linalg.py:.*, arg_names=\('lu', 'permutation', 'b'\), result_paths=\('',\)"),
"None", # TODO(necula): there are missing jaxpr debug info
],
expected_tracer_debug_infos=[
# TODO(necula): we don't see any leaks from tr_solve?
"None", # TODO(necula): there are missing debug info
re.compile(r"traced_for=custom_linear_solve, fun=f at .*control_flow/solves.py:.*, arg_names=\('x',\)"),
])
def test_custom_root_errors(self):
leaked_tracers: list[core.Tracer] = []
def dummy_root_usage(x):
leaked_tracers.append(x)
def my_f(x):
return x - 3.
return lax.custom_root(my_f, 0., lambda my_f, x: x, lambda my_f, x: x)
self._check_tracers_and_jaxprs(
jax.jit(lambda x: jax.jvp(dummy_root_usage, (x,), (0.0,))),
0.,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=('x',), result_paths=('[0]', '[1]')",
"None", # TODO(necula): there are missing debug info
],
expected_tracer_debug_infos=[
"None", # TODO(necula): there are missing debug info
])
def test_pallas_call(self):
leaked_tracers: list[core.Tracer] = []
def my_kernel(x_ref, y_ref, o_ref):
leaked_tracers.append(x_ref)
o_ref[...] = x_ref[...] + y_ref[...]
@@ -879,7 +1189,7 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
# TODO(necula): missing Jaxpr debug info
"None", "None", "None", "None"],
"None"],
expected_tracer_debug_infos=[
# TODO(necula): arg_names seem to be wrong
# One tracer from every index map
@@ -887,7 +1197,58 @@ class DebugInfoTest(jtu.JaxTestCase):
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
"traced_for=pallas_call, fun=my_kernel, arg_names=('x_ref', 'y_ref', 'o_ref')",
])
],
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
)
def test_checkify_pallas_call(self):
leaked_tracers: list[core.Tracer] = []
def kernel(x_ref, y_ref):
leaked_tracers.append(x_ref)
y_ref[...] = jnp.log(x_ref[...])
def my_f(input):
out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
pallas_call = pl.pallas_call(kernel,
out_shape=out_shape)
checked_call = checkify.checkify(pallas_call,
errors=checkify.nan_checks)
return checked_call(input)[1]
self._check_tracers_and_jaxprs(
jax.jit(my_f),
jnp.arange(4, dtype=jnp.float32) - 2,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=('input',), result_paths=('',)",
"None", # TODO(necula): missing tracer debug info
],
expected_tracer_debug_infos=[
"traced_for=pallas_call, fun=kernel, arg_names=('x_ref', 'y_ref')",
],
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
)
def test_composite(self):
leaked_tracers: list[core.Tracer] = []
scale = np.array([0.5, 0.4, 0.3], dtype=np.float32)
@functools.partial(lax.composite, name="my.consts")
def my_consts(x):
leaked_tracers.append(x)
return x / scale
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
self._check_tracers_and_jaxprs(
jax.jit(my_consts), x,
leaked_tracers=leaked_tracers,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_consts, arg_names=('x',), result_paths=('',)",
"None"
],
expected_tracer_debug_infos=[
"traced_for=composite, fun=my_consts, arg_names=('x',)"])
class EagerPmapMixin: