mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[better_errors] Add more debug info test coverage
Try to cover the tracing of almost all JAX higher-order primitives. Some of the added tests expose missing debug info; these cases are marked with TODO, and the fixes will come separately. The helper function `_check_tracers_and_jaxprs` had to be extended to support regular-expression matching, because some debug info still contains non-deterministic elements.
This commit is contained in:
parent
55efd4b225
commit
e4d5427d13
@ -46,6 +46,7 @@ jax_multiplatform_test(
|
||||
srcs = ["debug_info_test.py"],
|
||||
enable_configs = ["tpu_v3_2x2"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
"//jax:pallas",
|
||||
"//jax:pallas_gpu",
|
||||
"//jax:pallas_gpu_ops",
|
||||
|
@ -5830,7 +5830,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertEqual(jaxpr_text.count(' sin '), 3)
|
||||
self.assertEqual(jaxpr_text.count(' cos '), 3)
|
||||
# Six calls to dot_general in the backward pass because we save the primal
|
||||
# matmuls and only compure the backward pass ones (two for each primal one).
|
||||
# matmuls and only compute the backward pass ones (two for each primal one).
|
||||
self.assertEqual(jaxpr_text.count(' dot_'), 6)
|
||||
|
||||
jtu.check_grads(api.jit(f), (jnp.ones((5, 5)),), order=2,
|
||||
|
@ -23,17 +23,27 @@ from typing import Any
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
import jax
|
||||
from jax import ad_checkpoint
|
||||
from jax import lax
|
||||
|
||||
import jax.custom_batching
|
||||
import jax.custom_derivatives
|
||||
import jax.custom_transpose
|
||||
from jax.experimental import checkify
|
||||
import jax.experimental.custom_dce
|
||||
from jax.experimental import pallas as pl
|
||||
import jax.numpy as jnp
|
||||
import jax.scipy as jsp
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.compilation_cache import is_persistent_cache_enabled
|
||||
import jax.custom_batching
|
||||
import jax.custom_derivatives
|
||||
import jax.custom_transpose
|
||||
from jax.experimental import pallas as pl
|
||||
import jax.numpy as jnp
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -68,19 +78,25 @@ def _debug_info_to_string(dbg: api_util.TracingDebugInfo | core.JaxprDebugInfo |
|
||||
return res
|
||||
|
||||
|
||||
@jtu.with_config(jax_mutable_array_checks=True)
|
||||
class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def _check_tracers_and_jaxprs(self, traceable: Any,
|
||||
*args,
|
||||
expected_jaxpr_debug_infos: list[str],
|
||||
expected_jaxpr_debug_infos: list[str | re.Pattern],
|
||||
leaked_tracers: list[core.Tracer] = [],
|
||||
expected_tracer_debug_infos: list[str] = [],
|
||||
expected_tracer_debug_infos: list[str | re.Pattern] = [],
|
||||
check_lowering: bool = True,
|
||||
**kwargs):
|
||||
"""Checks for expected debug info in all jaxprs, and in leaked tracers.
|
||||
|
||||
The `traceable.trace(*args, **kwargs)` is traced to a Jaxpr, and the
|
||||
debug infos in the nested Jaxprs are first converted to strings using
|
||||
`_debug_info_to_string` and then compared against `expected_jaxpr_debug_infos`.
|
||||
An element of `expected_jaxpr_debug_infos` can be a string, in which case
|
||||
it is looked up by equality, or a `re.Pattern` (the result of `re.compile`)
|
||||
in which case it is looked up by `.match()`. All elements of
|
||||
`expected_jaxpr_debug_infos` must appear, and all Jaxprs must be matched.
|
||||
|
||||
One way in which the debug info is used in JAX is for leaked tracer
|
||||
description, or for ConcretizationErrors. Optionally,
|
||||
@ -91,17 +107,39 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
traced = traceable.trace(*args, **kwargs)
|
||||
all_jaxprs = _collect_jaxprs(traced.jaxpr.jaxpr)
|
||||
|
||||
jaxprs_debug_infos = [_debug_info_to_string(j.debug_info)
|
||||
for j in all_jaxprs]
|
||||
self.assertEqual(expected_jaxpr_debug_infos, jaxprs_debug_infos)
|
||||
found_jaxprs_debug_infos = [_debug_info_to_string(j.debug_info)
|
||||
for j in all_jaxprs]
|
||||
|
||||
self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos)
|
||||
self._check_tracers(leaked_tracers, expected_tracer_debug_infos)
|
||||
# Run the lowering because this one exercises more code with debug_info
|
||||
# TODO(necula): check the lowering
|
||||
if check_lowering:
|
||||
traced.lower()
|
||||
|
||||
def _check_tracers(self,
|
||||
leaked_tracers: list[core.Tracer],
|
||||
expected_tracer_debug_infos: list[str]):
|
||||
leaked_tracer_debug_infos = [_debug_info_to_string(t._debug_info) if hasattr(t, "_debug_info") else "None"
|
||||
for t in leaked_tracers]
|
||||
self.assertEqual(expected_tracer_debug_infos, leaked_tracer_debug_infos)
|
||||
expected_tracer_debug_infos: list[str | re.Pattern]):
|
||||
found_leaked_tracer_debug_infos = [
|
||||
_debug_info_to_string(t._debug_info) if hasattr(t, "_debug_info") else "None"
|
||||
for t in leaked_tracers]
|
||||
self._check_matches(expected_tracer_debug_infos, found_leaked_tracer_debug_infos)
|
||||
|
||||
def _check_matches(self,
|
||||
expected: list[str | re.Pattern],
|
||||
found: list[str]):
|
||||
expected_and_found = set()
|
||||
unexpected: set[str] = set()
|
||||
for debug_info in found:
|
||||
for exp_re in expected:
|
||||
ok = exp_re.match(debug_info) if isinstance(exp_re, re.Pattern) else exp_re == debug_info
|
||||
if ok:
|
||||
expected_and_found.add(exp_re)
|
||||
break
|
||||
else:
|
||||
unexpected.add(debug_info)
|
||||
self.assertEmpty(unexpected) # found unexpected debug_info
|
||||
self.assertEmpty([e for e in expected if e not in expected_and_found]) # expected element that was not found
|
||||
|
||||
def test_debug_info_basic(self):
|
||||
def my_f(x, y, z, w):
|
||||
@ -556,7 +594,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def test_simple_jit(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x_dict, y):
|
||||
leaked_tracers.append(x_dict["a"])
|
||||
return dict(c=x_dict["a"] + x_dict["b"], d=y)
|
||||
@ -572,9 +609,34 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
'traced_for=jit, fun=my_f, arg_names=("x_dict[\'a\']", "x_dict[\'b\']", \'y\')',
|
||||
])
|
||||
|
||||
def test_jit_with_static_argnums(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
@functools.partial(jax.jit, static_argnums=(1,))
|
||||
def my_f(a, d):
|
||||
leaked_tracers.append(a)
|
||||
return a
|
||||
|
||||
def my_g(a, d=1):
|
||||
leaked_tracers.append(a)
|
||||
return my_f(a, d)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(my_g),
|
||||
3,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_g, arg_names=('a',), result_paths=('',)",
|
||||
"traced_for=jit, fun=my_f, arg_names=('a',), result_paths=()"
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=jit, fun=my_g, arg_names=('a',)",
|
||||
# TODO(necula): bad arg name
|
||||
"traced_for=jit, fun=my_f, arg_names=('args[0]',)"
|
||||
])
|
||||
|
||||
|
||||
def test_nested_jit(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x, y):
|
||||
leaked_tracers.append(x)
|
||||
|
||||
@ -599,7 +661,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def test_vjp_of_nested_jit(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x, y):
|
||||
leaked_tracers.append(x)
|
||||
|
||||
@ -616,8 +677,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=<lambda>, arg_names=('x', 'y', 'res_ct'), result_paths=('[0]', '[1]')",
|
||||
# TODO(necula): missing debug info
|
||||
'None',
|
||||
'None'
|
||||
"None",
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
# TODO(necula): missing debug info
|
||||
@ -627,7 +687,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def test_vmap_of_nested_jit(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x, y):
|
||||
leaked_tracers.append(x)
|
||||
|
||||
@ -652,9 +711,38 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
"traced_for=jit, fun=my_g, arg_names=('u', 'v')"
|
||||
])
|
||||
|
||||
def test_custom_vmap(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
@jax.custom_batching.custom_vmap
|
||||
def my_f(xdict):
|
||||
x = xdict["x"]
|
||||
leaked_tracers.append(x)
|
||||
return dict(a=jnp.sin(x))
|
||||
|
||||
@my_f.def_vmap
|
||||
def my_rule(axis_size, in_batched, xys):
|
||||
xs = xys["x"]
|
||||
leaked_tracers.append(xs)
|
||||
xs_batched, = in_batched
|
||||
self.assertEqual(xs_batched["x"], True)
|
||||
self.assertEqual(axis_size, xs.shape[0])
|
||||
return dict(a=jnp.cos(xs)), dict(a=xs_batched["x"])
|
||||
|
||||
xy = dict(x=np.ones((8,), dtype=np.float32), y=np.zeros((8,), dtype=np.float32))
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(jax.vmap(my_f)),
|
||||
xy,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\"), result_paths=(\"[\'a\']\",)",
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=custom_vmap, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\")",
|
||||
"traced_for=jit, fun=my_f, arg_names=(\"xdict[\'x\']\", \"xdict[\'y\']\")"
|
||||
])
|
||||
|
||||
def test_cond(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x):
|
||||
def my_true_branch(a, b):
|
||||
leaked_tracers.append(a)
|
||||
@ -673,16 +761,42 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
'None',
|
||||
'None'],
|
||||
"None"],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=cond, fun=my_true_branch, arg_names=('a', 'b')",
|
||||
"traced_for=cond, fun=my_false_branch, arg_names=('c', 'd')"
|
||||
])
|
||||
|
||||
def test_switch(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
def my_f(x):
|
||||
def my_branch0(x0):
|
||||
leaked_tracers.append(x0)
|
||||
return x0
|
||||
def my_branch1(x1):
|
||||
leaked_tracers.append(x1)
|
||||
return x1
|
||||
def my_branch2(x2):
|
||||
leaked_tracers.append(x2)
|
||||
return x2
|
||||
return lax.switch(x, [my_branch0, my_branch1, my_branch2], x)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(my_f),
|
||||
2,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
"None"],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=switch, fun=my_branch0, arg_names=('x0',)",
|
||||
"traced_for=switch, fun=my_branch1, arg_names=('x1',)",
|
||||
"traced_for=switch, fun=my_branch2, arg_names=('x2',)"
|
||||
])
|
||||
|
||||
def test_grad_cond_with_remat(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x, y):
|
||||
# The cond branches return two things, and only the first is needed
|
||||
# in the residuals.
|
||||
@ -711,12 +825,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
"traced_for=jit, fun=my_f, arg_names=('x', 'y'), result_paths=('',)",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
'None',
|
||||
'None',
|
||||
'None',
|
||||
'None',
|
||||
'None',
|
||||
'None',
|
||||
'None'
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=cond, fun=my_true_branch, arg_names=('a', 'b')",
|
||||
@ -724,9 +832,46 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
"traced_for=checkpoint / remat, fun=my_g, arg_names=('x', 'y')"
|
||||
])
|
||||
|
||||
def test_grad_scan(self):
|
||||
# Based on control_flow_test:testScanHigherOrderDifferentiation
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
def f(c, a):
|
||||
leaked_tracers.append(c)
|
||||
d = 0.75
|
||||
b = jnp.sin(c * jnp.sum(jnp.cos(d * a)))
|
||||
c = 0.9 * jnp.cos(d * jnp.sum(jnp.sin(c * a)))
|
||||
return c, b
|
||||
|
||||
as_ = jnp.arange(6.).reshape((3, 2))
|
||||
c = jnp.array(1, dtype=as_.dtype)
|
||||
|
||||
@jax.jit
|
||||
def my_f(x, as_):
|
||||
leaked_tracers.append(x)
|
||||
return jax.remat(lambda *args: for_loop.scan(f, *args))(c, as_)
|
||||
|
||||
def the_grad(c, as_):
|
||||
leaked_tracers.append(c)
|
||||
_, pullback = jax.vjp(my_f, c, as_)
|
||||
return pullback((c, np.arange(3, dtype=c.dtype)))
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(the_grad),
|
||||
c, as_,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=the_grad, arg_names=('c', 'as_'), result_paths=('[0]', '[1]')",
|
||||
'None', # TODO(necula): some Jaxprs without debug info
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=jit, fun=the_grad, arg_names=('c', 'as_')",
|
||||
"traced_for=scan, fun=f, arg_names=('c', 'a')",
|
||||
"traced_for=jit, fun=my_f, arg_names=('x', 'as_')",
|
||||
'None', # TODO(necula): some missing debug info
|
||||
])
|
||||
|
||||
def test_while_loop(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x):
|
||||
def my_cond(a):
|
||||
leaked_tracers.append(a)
|
||||
@ -745,16 +890,35 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
'None',
|
||||
'None'],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=while_cond, fun=my_cond, arg_names=('a',)",
|
||||
"traced_for=while_loop, fun=my_body, arg_names=('b',)"
|
||||
])
|
||||
|
||||
def test_scan(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
def my_f(x):
|
||||
def my_scan_body(carry, inp):
|
||||
leaked_tracers.append(carry)
|
||||
return (carry + inp, carry)
|
||||
|
||||
return lax.scan(my_scan_body, 0, x)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(my_f),
|
||||
np.arange(8, dtype=np.int32),
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('[0]', '[1]')",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
'None'],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=scan, fun=my_scan_body, arg_names=('carry', 'inp')"
|
||||
])
|
||||
|
||||
def test_eval_shape(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x):
|
||||
leaked_tracers.append(x)
|
||||
return x
|
||||
@ -766,7 +930,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def test_pmap(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x):
|
||||
leaked_tracers.append(x)
|
||||
return jnp.sin(x)
|
||||
@ -785,18 +948,15 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def test_pmap_of_grad(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x):
|
||||
leaked_tracers.append(x)
|
||||
return jnp.sin(x)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(jax.pmap(jax.grad(my_f))),
|
||||
jax.pmap(jax.grad(my_f)),
|
||||
np.ones((jax.device_count(),), dtype=np.float32),
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
|
||||
# TODO(necula): missing debug_info
|
||||
'None'
|
||||
"traced_for=pmap, fun=my_f, arg_names=('x',), result_paths=('',)",
|
||||
],
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_tracer_debug_infos=[
|
||||
@ -807,7 +967,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def test_remat(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x):
|
||||
@jax.remat
|
||||
def my_g(y):
|
||||
@ -830,7 +989,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
|
||||
def test_grad_remat(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_f(x):
|
||||
@jax.remat
|
||||
def my_g(y):
|
||||
@ -846,15 +1004,167 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
'None',
|
||||
'None'],
|
||||
"None"],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=checkpoint / remat, fun=my_g, arg_names=('y',)"
|
||||
])
|
||||
|
||||
def test_remat_saved_residuals(self):
|
||||
@functools.partial(jax.remat,
|
||||
static_argnums=(1,),
|
||||
policy=lambda p, *_, **__: "mul" in str(p))
|
||||
def my_f(x, y):
|
||||
x = ad_checkpoint.checkpoint_name(x * x, "foo")
|
||||
x = x * x
|
||||
return x + y
|
||||
|
||||
res = saved_residuals(my_f, 3., 4.)
|
||||
self.assertEqual(res[0][1], "from the argument x")
|
||||
self.assertRegex(res[1][1], r"named 'foo' from .*debug_info_test.py:.*my_f")
|
||||
|
||||
def test_checkify_pmap_basic(self):
|
||||
if len(jax.devices()) < 2:
|
||||
self.skipTest("requires at least 2 devices")
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
@jax.pmap
|
||||
def my_f(my_x):
|
||||
leaked_tracers.append(my_x)
|
||||
y1 = jnp.sin(1./my_x)
|
||||
y2 = jnp.sin(my_x)
|
||||
return (y1 + y2,)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(checkify.checkify(my_f, errors=checkify.nan_checks)),
|
||||
np.arange(len(jax.devices()), dtype=np.float32),
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
# TODO(necula): this should not be pointing into the JAX internals
|
||||
re.compile(r"traced_for=jit, fun=checked_fun at .*jax/_src/checkify.py:.*, arg_names=\(\'args\[0\]\',\)"),
|
||||
re.compile(r"traced_for=jit, fun=argsort at .*numpy/lax_numpy.py:.*, arg_names=\('a',\), result_paths=\('',\)"),
|
||||
"None", # TODO(necula): missing tracer debug info
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=xla_pmap, fun=my_f, arg_names=('my_x',)",
|
||||
],
|
||||
check_lowering=False, # TODO(necula): warning during lowering
|
||||
)
|
||||
|
||||
def test_custom_dce_static_argnums(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
@functools.partial(jax.experimental.custom_dce.custom_dce,
|
||||
static_argnums=(0,))
|
||||
def my_g(f, x):
|
||||
leaked_tracers.append(x)
|
||||
return f(x), 10 * f(x)
|
||||
|
||||
@my_g.def_dce
|
||||
def my_g_dce(f, used_outs, x): # note: static_argnums are always passed first
|
||||
leaked_tracers.append(x)
|
||||
self.assertTrue(callable(f))
|
||||
return [2 * v if used else None
|
||||
for used, v in zip(used_outs, my_g(f, x))]
|
||||
|
||||
def my_f(x):
|
||||
return jnp.exp(x)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(lambda x: my_g(my_f, x)[0]),
|
||||
0.,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=<lambda>, arg_names=('x',), result_paths=('',)",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
'None'],
|
||||
expected_tracer_debug_infos=[
|
||||
# TODO(necula): bad arg_names
|
||||
"traced_for=custom_dce, fun=my_g, arg_names=('args[0]',)"
|
||||
])
|
||||
|
||||
def test_custom_dce_consts(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
@jax.experimental.custom_dce.custom_dce
|
||||
def f(x):
|
||||
return np.eye(1) * jnp.sin(x), jnp.cos(x)
|
||||
|
||||
@f.def_dce
|
||||
def rule(used_outs, x):
|
||||
return (
|
||||
np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None,
|
||||
jnp.sqrt(x) if used_outs[1] else None,
|
||||
)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(lambda x: f(x)[0]),
|
||||
np.array(1.1234),
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=<lambda>, arg_names=('x',), result_paths=('',)",
|
||||
# TODO(necula): some Jaxprs without debug info
|
||||
'None'],
|
||||
expected_tracer_debug_infos=[
|
||||
# TODO(necula): bad arg_names
|
||||
# "traced_for=custom_dce, fun=my_g, arg_names=('args[0]',)"
|
||||
])
|
||||
|
||||
def test_custom_linear_solve_complex(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
def solve(a, b):
|
||||
leaked_tracers.append(a)
|
||||
def solve(matvec, x):
|
||||
leaked_tracers.append(x)
|
||||
return jsp.linalg.solve(a, x)
|
||||
|
||||
def high_precision_dot(a, b):
|
||||
return lax.dot(a, b, precision=lax.Precision.HIGHEST)
|
||||
|
||||
def tr_solve(matvec, x):
|
||||
return jsp.linalg.solve(a.T, x)
|
||||
matvec = functools.partial(high_precision_dot, a)
|
||||
return lax.custom_linear_solve(matvec, b, solve, tr_solve)
|
||||
|
||||
rng = self.rng()
|
||||
a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2)
|
||||
b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(lambda a, b: jax.jvp(solve, (a, b), (a, b))),
|
||||
a, b,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=<lambda>, arg_names=('a', 'b'), result_paths=('[0]', '[1]')",
|
||||
re.compile(r"traced_for=jit, fun=_solve at .*scipy/linalg.py:.*, arg_names=\('a', 'b'\), result_paths=\('',\)"),
|
||||
re.compile(r"traced_for=jit, fun=solve at .*/linalg.py:.*, arg_names=\('a', 'b'\), result_paths=\('',\)"),
|
||||
re.compile(r"traced_for=jit, fun=_lu_solve at .*/linalg.py:.*, arg_names=\('lu', 'permutation', 'b'\), result_paths=\('',\)"),
|
||||
"None", # TODO(necula): there are missing jaxpr debug info
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
# TODO(necula): we don't see any leaks from tr_solve?
|
||||
"None", # TODO(necula): there are missing debug info
|
||||
re.compile(r"traced_for=custom_linear_solve, fun=f at .*control_flow/solves.py:.*, arg_names=\('x',\)"),
|
||||
])
|
||||
|
||||
def test_custom_root_errors(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
def dummy_root_usage(x):
|
||||
leaked_tracers.append(x)
|
||||
def my_f(x):
|
||||
return x - 3.
|
||||
return lax.custom_root(my_f, 0., lambda my_f, x: x, lambda my_f, x: x)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(lambda x: jax.jvp(dummy_root_usage, (x,), (0.0,))),
|
||||
0.,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=<lambda>, arg_names=('x',), result_paths=('[0]', '[1]')",
|
||||
"None", # TODO(necula): there are missing debug info
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
"None", # TODO(necula): there are missing debug info
|
||||
])
|
||||
|
||||
def test_pallas_call(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
|
||||
def my_kernel(x_ref, y_ref, o_ref):
|
||||
leaked_tracers.append(x_ref)
|
||||
o_ref[...] = x_ref[...] + y_ref[...]
|
||||
@ -879,7 +1189,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=('x',), result_paths=('',)",
|
||||
# TODO(necula): missing Jaxpr debug info
|
||||
"None", "None", "None", "None"],
|
||||
"None"],
|
||||
expected_tracer_debug_infos=[
|
||||
# TODO(necula): arg_names seem to be wrong
|
||||
# One tracer from every index map
|
||||
@ -887,7 +1197,58 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
|
||||
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=('i[0]', 'i[1]')",
|
||||
"traced_for=pallas_call, fun=my_kernel, arg_names=('x_ref', 'y_ref', 'o_ref')",
|
||||
])
|
||||
],
|
||||
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
|
||||
)
|
||||
|
||||
def test_checkify_pallas_call(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
def kernel(x_ref, y_ref):
|
||||
leaked_tracers.append(x_ref)
|
||||
y_ref[...] = jnp.log(x_ref[...])
|
||||
|
||||
def my_f(input):
|
||||
out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
|
||||
pallas_call = pl.pallas_call(kernel,
|
||||
out_shape=out_shape)
|
||||
checked_call = checkify.checkify(pallas_call,
|
||||
errors=checkify.nan_checks)
|
||||
return checked_call(input)[1]
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(my_f),
|
||||
jnp.arange(4, dtype=jnp.float32) - 2,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_f, arg_names=('input',), result_paths=('',)",
|
||||
"None", # TODO(necula): missing tracer debug info
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=pallas_call, fun=kernel, arg_names=('x_ref', 'y_ref')",
|
||||
],
|
||||
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
|
||||
)
|
||||
|
||||
def test_composite(self):
|
||||
leaked_tracers: list[core.Tracer] = []
|
||||
scale = np.array([0.5, 0.4, 0.3], dtype=np.float32)
|
||||
@functools.partial(lax.composite, name="my.consts")
|
||||
def my_consts(x):
|
||||
leaked_tracers.append(x)
|
||||
return x / scale
|
||||
|
||||
|
||||
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
|
||||
|
||||
self._check_tracers_and_jaxprs(
|
||||
jax.jit(my_consts), x,
|
||||
leaked_tracers=leaked_tracers,
|
||||
expected_jaxpr_debug_infos=[
|
||||
"traced_for=jit, fun=my_consts, arg_names=('x',), result_paths=('',)",
|
||||
"None"
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
"traced_for=composite, fun=my_consts, arg_names=('x',)"])
|
||||
|
||||
|
||||
class EagerPmapMixin:
|
||||
|
Loading…
x
Reference in New Issue
Block a user