diff --git a/tests/BUILD b/tests/BUILD index 053fa3457..cd18370fc 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -46,6 +46,7 @@ jax_multiplatform_test( srcs = ["debug_info_test.py"], enable_configs = ["tpu_v3_2x2"], deps = [ + "//jax:experimental", "//jax:pallas", "//jax:pallas_gpu", "//jax:pallas_gpu_ops", diff --git a/tests/api_test.py b/tests/api_test.py index 616552864..918f81367 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5830,7 +5830,7 @@ class RematTest(jtu.JaxTestCase): self.assertEqual(jaxpr_text.count(' sin '), 3) self.assertEqual(jaxpr_text.count(' cos '), 3) # Six calls to dot_general in the backward pass because we save the primal - # matmuls and only compure the backward pass ones (two for each primal one). + # matmuls and only compute the backward pass ones (two for each primal one). self.assertEqual(jaxpr_text.count(' dot_'), 6) jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2, diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 67c8e2ea4..cb0c9d6ca 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -23,17 +23,27 @@ from typing import Any from absl.testing import absltest, parameterized import jax +from jax import ad_checkpoint from jax import lax + +import jax.custom_batching +import jax.custom_derivatives +import jax.custom_transpose +from jax.experimental import checkify +import jax.experimental.custom_dce +from jax.experimental import pallas as pl +import jax.numpy as jnp +import jax.scipy as jsp + from jax._src import api_util +from jax._src.ad_checkpoint import saved_residuals from jax._src import config from jax._src import core from jax._src import test_util as jtu from jax._src.compilation_cache import is_persistent_cache_enabled -import jax.custom_batching -import jax.custom_derivatives -import jax.custom_transpose -from jax.experimental import pallas as pl -import jax.numpy as jnp +from jax._src.lax.control_flow import for_loop + + import numpy as np config.parse_flags_with_absl() @@ -68,19 +78,25 @@ def _debug_info_to_string(dbg: api_util.TracingDebugInfo | core.JaxprDebugInfo | return res +@jtu.with_config(jax_mutable_array_checks=True) class DebugInfoTest(jtu.JaxTestCase): def _check_tracers_and_jaxprs(self, traceable: Any, *args, - expected_jaxpr_debug_infos: list[str], + expected_jaxpr_debug_infos: list[str | re.Pattern], leaked_tracers: list[core.Tracer] = [], - expected_tracer_debug_infos: list[str] = [], + expected_tracer_debug_infos: list[str | re.Pattern] = [], + check_lowering: bool = True, **kwargs): """Checks for expected debug info in all jaxprs, and in leaked tracers. The `traceable.trace(*args, **kwargs)` is traced to a Jaxpr, and the debug infos in the nested Jaxprs are first converted to strings using `_debug_info_to_string` and then compared against `expected_jaxpr_debug_infos`. + An element of `expected_jaxpr_debug_infos` can be a string, in which case + it is looked up by equality, or a `re.Pattern` (the result of `re.compile`) + in which case it is looked up by `.match()`. All elements of + `expected_jaxpr_debug_infos` must appear, and all Jaxprs must be matched. One way in which the debug info is used in JAX is for leaked tracer description, or for ConcretizationErrors. Optionally, @@ -91,17 +107,39 @@ class DebugInfoTest(jtu.JaxTestCase): traced = traceable.trace(*args, **kwargs) all_jaxprs = _collect_jaxprs(traced.jaxpr.jaxpr) - jaxprs_debug_infos = [_debug_info_to_string(j.debug_info) - for j in all_jaxprs] - self.assertEqual(expected_jaxpr_debug_infos, jaxprs_debug_infos) + found_jaxprs_debug_infos = [_debug_info_to_string(j.debug_info) + for j in all_jaxprs] + + self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos) self._check_tracers(leaked_tracers, expected_tracer_debug_infos) + # Run the lowering because this one exercises more code with debug_info + # TODO(necula): check the lowering + if check_lowering: + traced.lower() def _check_tracers(self, leaked_tracers: list[core.Tracer], - expected_tracer_debug_infos: list[str]): - leaked_tracer_debug_infos = [_debug_info_to_string(t._debug_info) if hasattr(t, "_debug_info") else "None" - for t in leaked_tracers] - self.assertEqual(expected_tracer_debug_infos, leaked_tracer_debug_infos) + expected_tracer_debug_infos: list[str | re.Pattern]): + found_leaked_tracer_debug_infos = [ + _debug_info_to_string(t._debug_info) if hasattr(t, "_debug_info") else "None" + for t in leaked_tracers] + self._check_matches(expected_tracer_debug_infos, found_leaked_tracer_debug_infos) + + def _check_matches(self, + expected: list[str | re.Pattern], + found: list[str]): + expected_and_found = set() + unexpected: set[str] = set() + for debug_info in found: + for exp_re in expected: + ok = exp_re.match(debug_info) if isinstance(exp_re, re.Pattern) else exp_re == debug_info + if ok: + expected_and_found.add(exp_re) + break + else: + unexpected.add(debug_info) + self.assertEmpty(unexpected) # found unexpected debug_info + self.assertEmpty([e for e in expected if e not in expected_and_found]) # expected element that was not found def test_debug_info_basic(self): def my_f(x, y, z, w): @@ -556,7 +594,6 @@ class DebugInfoTest(jtu.JaxTestCase): def test_simple_jit(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x_dict, y): leaked_tracers.append(x_dict["a"]) return dict(c=x_dict["a"] + x_dict["b"], d=y) @@ -572,9 +609,34 @@ class DebugInfoTest(jtu.JaxTestCase): 'traced_for=jit, fun=my_f, arg_names=("x_dict[\'a\']", "x_dict[\'b\']", \'y\')', ]) + def test_jit_with_static_argnums(self): + leaked_tracers: list[core.Tracer] = [] + @functools.partial(jax.jit, static_argnums=(1,)) + def my_f(a, d): + leaked_tracers.append(a) + return a + + def my_g(a, d=1): + leaked_tracers.append(a) + return my_f(a, d) + + self._check_tracers_and_jaxprs( + jax.jit(my_g), + 3, + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_g, arg_names=('a',), result_paths=('',)", + "traced_for=jit, fun=my_f, arg_names=('a',), result_paths=()" + ], + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_g, arg_names=('a',)", + # TODO(necula): bad arg name + "traced_for=jit, fun=my_f, arg_names=('args[0]',)" + ]) + + def test_nested_jit(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x, y): leaked_tracers.append(x) @@ -599,7 +661,6 @@ class DebugInfoTest(jtu.JaxTestCase): def test_vjp_of_nested_jit(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x, y): leaked_tracers.append(x) @@ -616,8 +677,7 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=('x', 'y', 'res_ct'), result_paths=('[0]', '[1]')", # TODO(necula): missing debug info - 'None', - 'None' + "None", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info @@ -627,7 +687,6 @@ class DebugInfoTest(jtu.JaxTestCase): def test_vmap_of_nested_jit(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x, y): leaked_tracers.append(x) @@ -652,9 +711,38 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=jit, fun=my_g, arg_names=('u', 'v')" ]) + def test_custom_vmap(self): + leaked_tracers: list[core.Tracer] = [] + @jax.custom_batching.custom_vmap + def my_f(xdict): + x = xdict["x"] + leaked_tracers.append(x) + return dict(a=jnp.sin(x)) + + @my_f.def_vmap + def my_rule(axis_size, in_batched, xys): + xs = xys["x"] + leaked_tracers.append(xs) + xs_batched, = in_batched + self.assertEqual(xs_batched["x"], True) + self.assertEqual(axis_size, xs.shape[0]) + return dict(a=jnp.cos(xs)), dict(a=xs_batched["x"]) + + xy = dict(x=np.ones((8,), dtype=np.float32), y=np.zeros((8,), dtype=np.float32)) + self._check_tracers_and_jaxprs( + jax.jit(jax.vmap(my_f)), + xy, + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\"), result_paths=(\"[\'a\']\",)", + ], + expected_tracer_debug_infos=[ + "traced_for=custom_vmap, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\")", + "traced_for=jit, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\")" + ]) + def test_cond(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x): def my_true_branch(a, b): leaked_tracers.append(a) @@ -673,16 +761,42 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)", # TODO(necula): some Jaxprs without debug info - 'None', - 'None'], + "None"], expected_tracer_debug_infos=[ "traced_for=cond, fun=my_true_branch, arg_names=('a', 'b')", "traced_for=cond, fun=my_false_branch, arg_names=('c', 'd')" ]) + def test_switch(self): + leaked_tracers: list[core.Tracer] = [] + def my_f(x): + def my_branch0(x0): + leaked_tracers.append(x0) + return x0 + def my_branch1(x1): + leaked_tracers.append(x1) + return x1 + def my_branch2(x2): + leaked_tracers.append(x2) + return x2 + return lax.switch(x, [my_branch0, my_branch1, my_branch2], x) + + self._check_tracers_and_jaxprs( + jax.jit(my_f), + 2, + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)", + # TODO(necula): some Jaxprs without debug info + "None"], + expected_tracer_debug_infos=[ + "traced_for=switch, fun=my_branch0, arg_names=('x0',)", + "traced_for=switch, fun=my_branch1, arg_names=('x1',)", + "traced_for=switch, fun=my_branch2, arg_names=('x2',)" + ]) + def test_grad_cond_with_remat(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x, y): # The cond branches return two things, and only the first is needed # in the residuals. @@ -711,12 +825,6 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=jit, fun=my_f, arg_names=('x', 'y'), result_paths=('',)", # TODO(necula): some Jaxprs without debug info 'None', - 'None', - 'None', - 'None', - 'None', - 'None', - 'None' ], expected_tracer_debug_infos=[ "traced_for=cond, fun=my_true_branch, arg_names=('a', 'b')", @@ -724,9 +832,46 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=checkpoint / remat, fun=my_g, arg_names=('x', 'y')" ]) + def test_grad_scan(self): + # Based on control_flow_test:testScanHigherOrderDifferentiation + leaked_tracers: list[core.Tracer] = [] + def f(c, a): + leaked_tracers.append(c) + d = 0.75 + b = jnp.sin(c * jnp.sum(jnp.cos(d * a))) + c = 0.9 * jnp.cos(d * jnp.sum(jnp.sin(c * a))) + return c, b + + as_ = jnp.arange(6.).reshape((3, 2)) + c = jnp.array(1, dtype=as_.dtype) + + @jax.jit + def my_f(x, as_): + leaked_tracers.append(x) + return jax.remat(lambda *args: for_loop.scan(f, *args))(c, as_) + + def the_grad(c, as_): + leaked_tracers.append(c) + _, pullback = jax.vjp(my_f, c, as_) + return pullback((c, np.arange(3, dtype=c.dtype))) + + self._check_tracers_and_jaxprs( + jax.jit(the_grad), + c, as_, + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=the_grad, arg_names=('c', 'as_'), result_paths=('[0]', '[1]')", + 'None', # TODO(necula): some Jaxprs without debug info + ], + expected_tracer_debug_infos=[ + "traced_for=jit, fun=the_grad, arg_names=('c', 'as_')", + "traced_for=scan, fun=f, arg_names=('c', 'a')", + "traced_for=jit, fun=my_f, arg_names=('x', 'as_')", + 'None', # TODO(necula): some missing debug info + ]) + def test_while_loop(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x): def my_cond(a): leaked_tracers.append(a) @@ -745,16 +890,35 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)", # TODO(necula): some Jaxprs without debug info - 'None', 'None'], expected_tracer_debug_infos=[ "traced_for=while_cond, fun=my_cond, arg_names=('a',)", "traced_for=while_loop, fun=my_body, arg_names=('b',)" ]) + def test_scan(self): + leaked_tracers: list[core.Tracer] = [] + def my_f(x): + def my_scan_body(carry, inp): + leaked_tracers.append(carry) + return (carry + inp, carry) + + return lax.scan(my_scan_body, 0, x) + + self._check_tracers_and_jaxprs( + jax.jit(my_f), + np.arange(8, dtype=np.int32), + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('[0]', '[1]')", + # TODO(necula): some Jaxprs without debug info + 'None'], + expected_tracer_debug_infos=[ + "traced_for=scan, fun=my_scan_body, arg_names=('carry', 'inp')" + ]) + def test_eval_shape(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x): leaked_tracers.append(x) return x @@ -766,7 +930,6 @@ class DebugInfoTest(jtu.JaxTestCase): def test_pmap(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x): leaked_tracers.append(x) return jnp.sin(x) @@ -785,18 +948,15 @@ class DebugInfoTest(jtu.JaxTestCase): def test_pmap_of_grad(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x): leaked_tracers.append(x) return jnp.sin(x) self._check_tracers_and_jaxprs( - jax.jit(jax.pmap(jax.grad(my_f))), + jax.pmap(jax.grad(my_f)), np.ones((jax.device_count(),), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)", - # TODO(necula): missing debug_info - 'None' + "traced_for=pmap, fun=my_f, arg_names=('x',), result_paths=('',)", ], leaked_tracers=leaked_tracers, expected_tracer_debug_infos=[ @@ -807,7 +967,6 @@ class DebugInfoTest(jtu.JaxTestCase): def test_remat(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x): @jax.remat def my_g(y): @@ -830,7 +989,6 @@ class DebugInfoTest(jtu.JaxTestCase): def test_grad_remat(self): leaked_tracers: list[core.Tracer] = [] - def my_f(x): @jax.remat def my_g(y): @@ -846,15 +1004,167 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)", # TODO(necula): some Jaxprs without debug info - 'None', - 'None'], + "None"], expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=my_g, arg_names=('y',)" ]) + def test_remat_saved_residuals(self): + @functools.partial(jax.remat, + static_argnums=(1,), + policy=lambda p, *_, **__: "mul" in str(p)) + def my_f(x, y): + x = ad_checkpoint.checkpoint_name(x * x, "foo") + x = x * x + return x + y + + res = saved_residuals(my_f, 3., 4.) + self.assertEqual(res[0][1], "from the argument x") + self.assertRegex(res[1][1], r"named 'foo' from .*debug_info_test.py:.*my_f") + + def test_checkify_pmap_basic(self): + if len(jax.devices()) < 2: + self.skipTest("requires at least 2 devices") + leaked_tracers: list[core.Tracer] = [] + @jax.pmap + def my_f(my_x): + leaked_tracers.append(my_x) + y1 = jnp.sin(1./my_x) + y2 = jnp.sin(my_x) + return (y1 + y2,) + + self._check_tracers_and_jaxprs( + jax.jit(checkify.checkify(my_f, errors=checkify.nan_checks)), + np.arange(len(jax.devices()), dtype=np.float32), + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + # TODO(necula): this should not be pointing into the JAX internals + re.compile(r"traced_for=jit, fun=checked_fun at .*jax/_src/checkify.py:.*, arg_names=\(\'args\[0\]\',\)"), + re.compile(r"traced_for=jit, fun=argsort at .*numpy/lax_numpy.py:.*, arg_names=\('a',\), result_paths=\('',\)"), + "None", # TODO(necula): missing tracer debug info + ], + expected_tracer_debug_infos=[ + "traced_for=xla_pmap, fun=my_f, arg_names=('my_x',)", + ], + check_lowering=False, # TODO(necula): warning during lowering + ) + + def test_custom_dce_static_argnums(self): + leaked_tracers: list[core.Tracer] = [] + @functools.partial(jax.experimental.custom_dce.custom_dce, + static_argnums=(0,)) + def my_g(f, x): + leaked_tracers.append(x) + return f(x), 10 * f(x) + + @my_g.def_dce + def my_g_dce(f, used_outs, x): # note: static_argnums are always passed first + leaked_tracers.append(x) + self.assertTrue(callable(f)) + return [2 * v if used else None + for used, v in zip(used_outs, my_g(f, x))] + + def my_f(x): + return jnp.exp(x) + + self._check_tracers_and_jaxprs( + jax.jit(lambda x: my_g(my_f, x)[0]), + 0., + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=, arg_names=('x',), result_paths=('',)", + # TODO(necula): some Jaxprs without debug info + 'None'], + expected_tracer_debug_infos=[ + # TODO(necula): bad arg_names + "traced_for=custom_dce, fun=my_g, arg_names=('args[0]',)" + ]) + + def test_custom_dce_consts(self): + leaked_tracers: list[core.Tracer] = [] + @jax.experimental.custom_dce.custom_dce + def f(x): + return np.eye(1) * jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None, + jnp.sqrt(x) if used_outs[1] else None, + ) + + self._check_tracers_and_jaxprs( + jax.jit(lambda x: f(x)[0]), + np.array(1.1234), + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=, arg_names=('x',), result_paths=('',)", + # TODO(necula): some Jaxprs without debug info + 'None'], + expected_tracer_debug_infos=[ + # TODO(necula): bad arg_names + # "traced_for=custom_dce, fun=my_g, arg_names=('args[0]',)" + ]) + + def test_custom_linear_solve_complex(self): + leaked_tracers: list[core.Tracer] = [] + def solve(a, b): + leaked_tracers.append(a) + def solve(matvec, x): + leaked_tracers.append(x) + return jsp.linalg.solve(a, x) + + def high_precision_dot(a, b): + return lax.dot(a, b, precision=lax.Precision.HIGHEST) + + def tr_solve(matvec, x): + return jsp.linalg.solve(a.T, x) + matvec = functools.partial(high_precision_dot, a) + return lax.custom_linear_solve(matvec, b, solve, tr_solve) + + rng = self.rng() + a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2) + b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2) + + self._check_tracers_and_jaxprs( + jax.jit(lambda a, b: jax.jvp(solve, (a, b), (a, b))), + a, b, + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=, arg_names=('a', 'b'), result_paths=('[0]', '[1]')", + re.compile(r"traced_for=jit, fun=_solve at .*scipy/linalg.py:.*, arg_names=\('a', 'b'\), result_paths=\('',\)"), + re.compile(r"traced_for=jit, fun=solve at .*/linalg.py:.*, arg_names=\('a', 'b'\), result_paths=\('',\)"), + re.compile(r"traced_for=jit, fun=_lu_solve at .*/linalg.py:.*, arg_names=\('lu', 'permutation', 'b'\), result_paths=\('',\)"), + "None", # TODO(necula): there are missing jaxpr debug info + ], + expected_tracer_debug_infos=[ + # TODO(necula): we don't see any leaks from tr_solve? + "None", # TODO(necula): there are missing debug info + re.compile(r"traced_for=custom_linear_solve, fun=f at .*control_flow/solves.py:.*, arg_names=\('x',\)"), + ]) + + def test_custom_root_errors(self): + leaked_tracers: list[core.Tracer] = [] + def dummy_root_usage(x): + leaked_tracers.append(x) + def my_f(x): + return x - 3. + return lax.custom_root(my_f, 0., lambda my_f, x: x, lambda my_f, x: x) + + self._check_tracers_and_jaxprs( + jax.jit(lambda x: jax.jvp(dummy_root_usage, (x,), (0.0,))), + 0., + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=, arg_names=('x',), result_paths=('[0]', '[1]')", + "None", # TODO(necula): there are missing debug info + ], + expected_tracer_debug_infos=[ + "None", # TODO(necula): there are missing debug info + ]) + def test_pallas_call(self): leaked_tracers: list[core.Tracer] = [] - def my_kernel(x_ref, y_ref, o_ref): leaked_tracers.append(x_ref) o_ref[...] = x_ref[...] + y_ref[...] @@ -879,7 +1189,7 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)", # TODO(necula): missing Jaxpr debug info - "None", "None", "None", "None"], + "None"], expected_tracer_debug_infos=[ # TODO(necula): arg_names seem to be wrong # One tracer from every index map @@ -887,7 +1197,58 @@ class DebugInfoTest(jtu.JaxTestCase): "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')", "traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')", "traced_for=pallas_call, fun=my_kernel, arg_names=('x_ref', 'y_ref', 'o_ref')", - ]) + ], + check_lowering=False, # We need interpret mode on CPU. TODO(necula) + ) + + def test_checkify_pallas_call(self): + leaked_tracers: list[core.Tracer] = [] + def kernel(x_ref, y_ref): + leaked_tracers.append(x_ref) + y_ref[...] = jnp.log(x_ref[...]) + + def my_f(input): + out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype) + pallas_call = pl.pallas_call(kernel, + out_shape=out_shape) + checked_call = checkify.checkify(pallas_call, + errors=checkify.nan_checks) + return checked_call(input)[1] + + self._check_tracers_and_jaxprs( + jax.jit(my_f), + jnp.arange(4, dtype=jnp.float32) - 2, + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=('input',), result_paths=('',)", + "None", # TODO(necula): missing tracer debug info + ], + expected_tracer_debug_infos=[ + "traced_for=pallas_call, fun=kernel, arg_names=('x_ref', 'y_ref')", + ], + check_lowering=False, # We need interpret mode on CPU. TODO(necula) + ) + + def test_composite(self): + leaked_tracers: list[core.Tracer] = [] + scale = np.array([0.5, 0.4, 0.3], dtype=np.float32) + @functools.partial(lax.composite, name="my.consts") + def my_consts(x): + leaked_tracers.append(x) + return x / scale + + + x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32) + + self._check_tracers_and_jaxprs( + jax.jit(my_consts), x, + leaked_tracers=leaked_tracers, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_consts, arg_names=('x',), result_paths=('',)", + "None" + ], + expected_tracer_debug_infos=[ + "traced_for=composite, fun=my_consts, arg_names=('x',)"]) class EagerPmapMixin: