# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import functools import operator import re from typing import Any import unittest from absl.testing import absltest import jax from jax import ad_checkpoint from jax import lax import jax.custom_batching import jax.custom_derivatives from jax.experimental import checkify import jax.experimental.custom_dce from jax.experimental import pallas as pl from jax.experimental.shard_map import shard_map import jax.numpy as jnp import jax.scipy as jsp from jax.sharding import Mesh from jax.sharding import PartitionSpec as P from jax._src import api_util from jax._src.ad_checkpoint import saved_residuals from jax._src import config from jax._src import core from jax._src import custom_transpose from jax._src import test_util as jtu from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lax.control_flow import for_loop from jax._src.interpreters import mlir import numpy as np config.parse_flags_with_absl() jtu.request_cpu_devices(8) def _collect_jaxprs(jaxpr: core.Jaxpr, acc: list[core.Jaxpr] | None = None) -> list[core.Jaxpr]: """Collect all Jaxprs in a depth-first order.""" if acc is None: acc = [] acc.append(jaxpr) for e in jaxpr.eqns: # Take first the block mapping Jaxprs if e.primitive.name == "pallas_call": # For pallas_call, extract also jaxprs inside the grid_mapping mapping = e.params["grid_mapping"] for bm in mapping.block_mappings: _collect_jaxprs(bm.index_map_jaxpr.jaxpr, acc) for sj in core.jaxprs_in_params(e.params): _collect_jaxprs(sj, acc) return acc def _debug_info_to_string(dbg: core.DebugInfo) -> list[str]: # Strip the absolute path and the line number but check that it references # this file (to catch errors when the source info points in JAX internals) func_src_info = re.sub(r"^(\S+)( at .*.debug_info_test.py:\d+)?", "\\1", dbg.func_src_info) arg_names_str = ",".join([str(a) for a in dbg.arg_names]) res = f"traced_for={dbg.traced_for}, fun={func_src_info}, arg_names={arg_names_str}" if isinstance(dbg.result_paths, tuple): res += f", result_paths={','.join(dbg.result_paths)}" elif dbg.result_paths is None: res += ", result_paths=" return res class TracerSpy: """Use to inspect tracers. We can `append` tracers from tracing contexts to this object. We collect the tracer, along with the error message we get when we try to concretize it. This is meant to simulate errors like concretization or leaking. """ tracers: list[tuple[core.Tracer, Exception]] def __init__(self): self.tracers = [] def append(self, t: Any) -> None: if isinstance(t, core.Tracer): try: # We plan to do boolean conversion and catch the exception, but this works # only for scalars t_scalar = t while t_scalar.shape: t_scalar = t_scalar[0] if t_scalar: pass assert False, t_scalar except Exception as e: self.tracers.append((t, e)) @jtu.with_config(jax_mutable_array_checks=True) class DebugInfoTest(jtu.JaxTestCase): def _check_tracers_and_jaxprs(self, traceable: Any, *args, expected_jaxpr_debug_infos: list[str | re.Pattern], tracer_spy: TracerSpy | None = None, expected_tracer_debug_infos: list[str | re.Pattern] = [], check_lowering: bool = True, expected_lowering_lines: list[str | re.Pattern] = [], **kwargs) -> None: """Checks the expected debug info in all jaxprs, in spied tracers, and StableHLO. `traceable` will be traced as `traceable.trace(*args, **kwargs)` if it has a `trace` method (for jit), or will be called as `traceable(*args, **kwargs)` otherwise (for eager). We collect all the nested Jaxprs, either from the result of `trace`, or by capturing all the lowered jaxprs in eager mode. The debug infos in the nested Jaxprs are first converted to strings using `_debug_info_to_string` and then compared against `expected_jaxpr_debug_infos`. During this conversion, we strip occurences of this test file name and a line number (e.g., .*/debug_info_test.py:56) An element of `expected_jaxpr_debug_infos` can be a string, in which case it is compared by equality, or a `re.Pattern` (the result of `re.compile`) in which case it is compared by `.match()`. All elements of `expected_jaxpr_debug_infos` must appear, and all Jaxprs must be matched. Optionally, we can pass a TracerSpy object into which we have `append`ed tracers from the execution of `traceable`. E.g., if we do `tracer_spy.append(a)`, where `a` is an argument of a `jit` function, we expect to see debugging info "traced_for=jit, fun=my_f, arg_names=a,b, from a". These debugging infos are compared with `expected_tracer_debug_infos`. Finally, if we pass `expected_lowering_lines` then we are looking for those matches in the StableHLO MLIR modules that are lowered. """ if hasattr(traceable, "trace"): traced = traceable.trace(*args, **kwargs) all_jaxprs = _collect_jaxprs(traced.jaxpr.jaxpr) else: # Just run the function and collect the Jaxprs and modules that are # lowered traced = None with jtu.collect_lowered_jaxprs() as collection: traceable(*args, **kwargs) all_jaxprs = [] for jaxpr, _ in collection: all_jaxprs.extend(_collect_jaxprs(jaxpr.jaxpr)) found_jaxprs_debug_infos = [_debug_info_to_string(j.debug_info) for j in all_jaxprs] self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos, "Jaxprs debug_infos") # JAXPRS found_tracer_debug_infos = [] if tracer_spy is not None: for t, exc in tracer_spy.tracers: if hasattr(t, "_debug_info"): t_debug_info = _debug_info_to_string(t._debug_info) msg = str(exc) m = re.match(r".* while tracing the function (.+) for ([^.]+)\.", msg, re.DOTALL) if m is None: found_tracer_debug_infos.append( f"{t_debug_info}, from None") else: self.assertIsNotNone(m, msg) self.assertEqual(t._debug_info.func_src_info, m.group(1)) self.assertEqual(t._debug_info.traced_for, m.group(2)) m = re.match(r".* depends on the value of the argument ([^\n]+)\.", msg, re.DOTALL) found_tracer_debug_infos.append( f"{t_debug_info}, from {m.group(1) if m else None}") else: found_tracer_debug_infos.append("None") self._check_matches(expected_tracer_debug_infos, found_tracer_debug_infos, "Tracer debug_infos") # INSPECTED TRACERS if not check_lowering: return # Collect all the lines in all the MLIR modules mlir_modules_lines = [] if traced is not None: mlir_modules_lines.extend( traced.lower().as_text("stablehlo", debug_info=True).split("\n")) else: for _, mod in collection: mlir_modules_lines.extend( mlir.module_to_string(mod, enable_debug_info=True).split("\n")) self._check_matches(expected_lowering_lines, mlir_modules_lines, "MLIR module lines", report_found_unexpected=False) def _check_matches(self, expected: list[str | re.Pattern], found: list[str], what: str, report_found_unexpected: bool = True) -> None: expected_and_found: set[str | re.Pattern] = set() found_and_expected: set[str] = set() for exp_re in expected: for found_line in found: ok = exp_re.match(found_line) if isinstance(exp_re, re.Pattern) else exp_re == found_line if ok: expected_and_found.add(exp_re) found_and_expected.add(found_line) found_and_unexpected = set(found) - found_and_expected all_found = "\n ".join(found) if report_found_unexpected and found_and_unexpected: unexp_str = "\n ".join(found_and_unexpected) msg = f"Found unexpected {what}:\n {unexp_str}\nAll found {what}:\n {all_found}" self.assertTrue(False, msg) if expected_not_found := {e for e in expected if e not in expected_and_found}: exp_str = "\n ".join([str(e) for e in expected_not_found]) msg = f"Expected but not found in {what}:\n {exp_str}\nAll found {what}:\n {all_found}" self.assertTrue(False, msg) def test_debug_info_basic(self): def my_f(x, y, z, w): pass dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") self.assertEqual(dbg.func_name, "my_f") self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) self.assertIsNone(dbg.result_paths) def test_debug_info_arg_passed_as_kwarg(self): def my_f(x, y, z): pass dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3)) self.assertEqual(dbg.arg_names, ("x", "y", "z")) def test_debug_info_pytrees(self): def my_f(x_tree, *, y_tree): pass dbg = api_util.debug_info("jit", my_f, ((1, 2),), dict(y_tree=dict(z=3, w=4))) self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]", "y_tree['w']", "y_tree['z']")) def test_debug_info_with_statics(self): def my_f(x, y, *, z, w): pass dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4), static_argnums=(1,), static_argnames=("w",)) self.assertEqual(dbg.arg_names, ("x", "z")) def test_debug_info_with_pytrees_and_statics(self): def my_f(x, y, *, z, w): pass dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)), dict(z=(3, 4), w=(5, 6)), static_argnums=(1,), static_argnames=("w",)) self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) def test_debug_info_too_many_args(self): def my_f(x): pass dbg = api_util.debug_info("jit", my_f, (1, 2, 3), dict(z=3)) self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) def test_debug_info_no_source_info_built_in(self): # built-in function "int" does not have an inspect.Signature dbg = api_util.debug_info("jit", max, (1,), {}) self.assertEqual(dbg.func_src_info, "max") self.assertEqual(dbg.arg_names, ("args[0]",)) def test_debug_info_lambda(self): # built-in function "int" does not have an inspect.Signature dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {}) self.assertRegex(dbg.func_src_info, r"^ at .*debug_info_test.py:\d+") self.assertEqual(dbg.arg_names, ("my_arg",)) def test_debug_info_save_wrapped_fun_source_info(self): def wrapper(x, y): return x dbg = api_util.debug_info("test", wrapper, (1, 2), {}) self.assertEqual("wrapper", dbg.func_name) api_util.save_wrapped_fun_sourceinfo(wrapper, lambda x, y: x) dbg = api_util.debug_info("test", wrapper, (1, 2), {}) self.assertEqual("", dbg.func_name) def other_f(): pass dbg_other = api_util.debug_info("test other", other_f, (), {}) api_util.save_wrapped_fun_sourceinfo(wrapper, dbg_other) dbg = api_util.debug_info("test", wrapper, (1, 2), {}) self.assertEqual("other_f", dbg.func_name) self.assertEqual("test", dbg.traced_for) def test_debug_info_no_source_info_not_callable(self): # built-in function "int" does not have an inspect.Signature dbg = api_util.debug_info("jit", False, (1,), {}) self.assertEqual(dbg.func_src_info, "") self.assertEqual(dbg.arg_names, ("args[0]",)) def test_debug_info_no_source_info_callable(self): class Foo: x: int def __call__(self, y): return self.x + y dbg = api_util.debug_info("jit", Foo(), (1,), {}) self.assertRegex(dbg.func_src_info, "") self.assertEqual(dbg.arg_names, ("y",)) def test_debug_info_no_source_info_callable_with_repr_errors(self): class Foo: x: int def __call__(self, y): return self.x + y def __repr__(self): raise NotImplementedError dbg = api_util.debug_info("jit", Foo(), (1,), {}) self.assertRegex(dbg.func_src_info, "") self.assertEqual(dbg.arg_names, ("y",)) def helper_save_tracer(self, x): self._saved_tracer = x return x def test_jit_lower_arg_names_with_error1(self): def f(x): return jnp.sqrt(x ** 2) + 1. x = jnp.array(1, dtype=int) x_f32 = x.astype(jnp.float32) x_i32 = x.astype(jnp.int32) f_exe = jax.jit(f).lower(x_f32).compile() self.assertRaisesRegex( TypeError, r"Argument types differ .*" r"The mismatches are:\n" r"Argument 'x' compiled with.*float32.*and called with.*int32.*", lambda: f_exe(x_i32)) def test_jit_lower_arg_names_with_error2(self): def f(x): return x err_str = ("Error interpreting argument to .* as an abstract array. The problematic " "value is of type .* and was passed to the function at path x.") with self.assertRaisesRegex(TypeError, err_str): jax.jit(f)("foo") # Jax type objects aren't valid data arguments. with self.assertRaisesRegex(TypeError, err_str): jax.jit(f)(jnp.int32) @jtu.thread_unsafe_test() # logging is not thread-safe def test_arg_names_cache_miss_explanations(self): @jax.jit def f(x, y): return jnp.sin(x) * y['hi'] x = jnp.float32(1.) y = {'hi': jnp.arange(3., dtype='float32')} expected_log_len = 1 if not is_persistent_cache_enabled() else 3 # print on first miss, not on hit with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: f(x, y) f(x, y) self.assertLen(cm.output, expected_log_len) msg = cm.output[0] self.assertIn('TRACING CACHE MISS', msg) self.assertIn('never seen function', msg) # shape change y_ = {'hi': jnp.arange(4, dtype='float32')} with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: f(x, y_) self.assertLen(cm.output, expected_log_len) msg = cm.output[0] self.assertIn('never seen input type signature', msg) self.assertIn('closest seen input type signature has 1 mismatches', msg) self.assertIn('seen f32[3], but now given f32[4]', msg) # weak type change (assuming no x64) if not config.enable_x64.value: with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: f(1., y) self.assertLen(cm.output, expected_log_len) msg = cm.output[0] self.assertIn('weak_type=True', msg) self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg) # kwarg change with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: f(1, y=y) self.assertLen(cm.output, expected_log_len) msg = cm.output[0] self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) # tracing config change with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: with jax.numpy_rank_promotion('warn'): f(x, y) # depending on the backend, we may or may not get persistent cache warnings self.assertTrue(1 <= len(cm.output) <= expected_log_len) msg = cm.output[0] self.assertIn("tracing context doesn't match", msg) @jtu.thread_unsafe_test() # logging is not thread-safe def test_arg_names_cache_miss_explanations_new_function_in_loop(self): @jax.jit def f(x, y): return jnp.sin(x) * y['hi'] x = jnp.float32(1.) with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: for _ in range(2): jax.jit(lambda x: 2 * x)(3) if is_persistent_cache_enabled(): # number of warnings depends on the backend self.assertTrue(4 <= len(cm.output) <= 6) msg = cm.output[3] self.assertIn('another function defined on the same line', msg) else: self.assertLen(cm.output, 2) _, msg = cm.output self.assertIn('another function defined on the same line', msg) @jtu.thread_unsafe_test() # logging is not thread-safe @unittest.skip("Test fails, probably due to caching") def test_arg_names_cache_miss_explanations_unpacks_transforms(self): # Tests that the explain_tracing_cache_miss() function does not throw an # error when unpacking `transforms` with a length greater than 3. @jax.jit def f(key): return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32) with config.explain_cache_misses(True): with self.assertLogs(level="WARNING") as cm: f(jax.random.key(seed=123)) if is_persistent_cache_enabled(): # 5 warnings from tracing cache, 5-10 from persistent cache depending on # the backend self.assertTrue(10 <= len(cm.output) <= 15) self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output)) else: self.assertLen(cm.output, 5) for msg in cm.output: self.assertIn("TRACING CACHE MISS", msg) def test_arg_names_cache_miss_explanations_no_source_info(self): # ``operator.add`` is a built-in function and does not have source info. with config.explain_cache_misses(True): jax.jit(operator.add)(42, 24) def test_concrete_error_because_arg_unary(self): @jax.jit def f(x): if x > 0: return x else: return 0 msg = r"on the value of the argument x" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f(1) def test_concrete_error_because_arg_binary(self): @jax.jit def f(x, y): if x > y: return x else: return y msg = r"on the values of the arguments x and y" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f(1, 2) def test_concrete_error_because_arg_ternary(self): @jax.jit def f(x, y, z): if x > z: return x else: return y msg = r"on the values of the arguments x and z" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f(1, 2, 3) with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f(1, 2, z=3) with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f(1, y=2, z=3) def test_concrete_error_because_arg_varargs(self): @jax.jit def f(*args): x, y, z = args if x > z: return x else: return y msg = r"on the values of the arguments args" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f(1, 2, 3) def test_concrete_error_because_arg_kwargs(self): @jax.jit def f(**kwargs): x, y, z = kwargs['x'], kwargs['y'], kwargs['z'] if x > z: return x else: return y msg = r"on the values of the arguments kwargs" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f(x=1, y=2, z=3) def test_concrete_error_because_arg_pytree(self): @jax.jit def f(xy, z): x, y = xy if x > 0: return x else: return y msg = r"on the value of the argument xy" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f((1, 2), z=3) def test_concrete_error_because_const(self): @jax.jit def f(): assert jnp.add(1, 1) > 0 msg = "on these lines" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f() def test_concrete_error_because_const_2(self): @jax.jit def f(): result = sum(jnp.add(1, 1) for _ in range(6)) assert result > 0 msg = "Additional originating lines are not shown." with self.assertRaisesRegex(core.ConcretizationTypeError, msg): f() def test_concrete_error_with_nested_call(self): @jax.jit def f(x, y): if y: return x @jax.jit def g(x): return f(x, True) msg = r"on the value of the argument y" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): g(1) def test_remat_concrete_error(self): @jax.remat # no static_argnums or concrete def g(x): if x > 0: return lax.sin(x) else: return lax.cos(x) with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"): g(3.) @functools.partial(jax.remat, static_argnums=(0,)) # using static_argnums but... def g(x): if x > 0: # jnp operations still get staged! return lax.sin(x) else: return lax.cos(x) with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"): g(jnp.array(3.)) # But don't raise an error mentioning static_argnums here: @jax.remat def g(x): jax.jit(lambda: 0 if jnp.add(1, 1) else 0)() return lax.sin(x) try: g(jnp.array(3.)) except core.ConcretizationTypeError as e: msg = str(e) self.assertNotIn('static_argnums', msg) def test_simple_jit(self): tracer_spy = TracerSpy() def my_f(x_dict, y): tracer_spy.append(x_dict["a"]) return dict(c=x_dict["a"] + x_dict["b"], d=y) self._check_tracers_and_jaxprs( jax.jit(my_f), dict(a=1, b=2), 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=result['c'],result['d']" ], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, from x_dict['a']", ]) def test_jit_with_static_argnums(self): tracer_spy = TracerSpy() @functools.partial(jax.jit, static_argnums=(1,)) def my_f(a, d): tracer_spy.append(a) return a def my_g(b, d=1): tracer_spy.append(b) return my_f(b, d) self._check_tracers_and_jaxprs( jax.jit(my_g), 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): result_paths? "traced_for=jit, fun=my_f, arg_names=a, result_paths=", "traced_for=jit, fun=my_g, arg_names=b, result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_g, arg_names=b, from b", "traced_for=jit, fun=my_f, arg_names=a, from a", ]) def test_jit_arg_names(self): tracer_spy = TracerSpy() def f(x, y, *args, **kwargs): # args[0] is dead tracer_spy.append(kwargs["w"]) tracer_spy.append(args[0]) return y['hi'] + args[1] + sum(kwargs.values()) self._check_tracers_and_jaxprs( jax.jit(f), {"ho": 1.}, {"hi": 2.}, 3., 4., z=5., w=6., expected_jaxpr_debug_infos=[ "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from kwargs['w']", "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from args[0]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"y\['hi'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['w'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor loc\(\"kwargs\['z'\]\"\)"), re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ] ) def test_jit_arg_names_static_argnums(self): tracer_spy = TracerSpy() def my_f(x, y, z, *args, **kwargs): # z and args[0] and kwargs["w"] are dead # x and args[2] are static tracer_spy.append(kwargs["w"]) tracer_spy.append(x[0]) return x[0] + y["hi"] + args[1] + args[2] + kwargs["t"] self._check_tracers_and_jaxprs( jax.jit(my_f, static_argnums=(0, 5)), (1.,), {"hi": 2.}, 3., 4., 5., 6., # x, y, z, args[0], args[1], args[2] t=11., w=12., # kwargs expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"y\['hi'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['t'\]\"\)"), re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) def test_jit_arg_names_static_argnames(self): tracer_spy = TracerSpy() def f(x, y, *args, **kwargs): # x and args[0] and kwargs["z"] are dead # kwargs[a] is static tracer_spy.append(x[0]) return y['hi'] + args[1] + kwargs['a'] + kwargs['b'] + kwargs['w'] self._check_tracers_and_jaxprs( jax.jit(f, static_argnames=("a",)), (1.,), {'hi': 2.}, 3., 4., # x, y, args[0], args[1] z=5., w=6., a=7., b=8., # kwargs expected_jaxpr_debug_infos=[ "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], from x[0]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"y\['hi'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['b'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor loc\(\"kwargs\['w'\]\"\)"), re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) def test_jit_result_info(self): def f(x, y, z): return {'a': x, 'b': [y]} self._check_tracers_and_jaxprs( jax.jit(f), 1., (2.,), [3.], expected_jaxpr_debug_infos=[ "traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=result['a'],result['b'][0][0]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"x\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"y\[0\]\"\)"), re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['a'\]\"\}"), re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['b'\]\[0\]\[0\]\"\}"), ]) def test_nested_jit(self): tracer_spy = TracerSpy() def my_f(x, y): tracer_spy.append(x) def my_g(u, v): tracer_spy.append(u) return dict(c=u * v, d=v) return jax.jit(my_g)(y, x)["c"] self._check_tracers_and_jaxprs( jax.jit(my_f), 2, 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']" ], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, from x", "traced_for=jit, fun=my_g, arg_names=u,v, from u" ]) def test_nested_jit_with_const_and_unused_args(self): def my_f(x, y): # y is unused def my_g(u, v): # v is unused return v + np.ones(v.shape, v.dtype) return x + jax.jit(my_g)(y, x) x = y = np.ones((8,), dtype=np.float32) self._check_tracers_and_jaxprs( jax.jit(my_f), x, y, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> loc\(\"x\"\)\)"), re.compile(r".*call @my_g\(%arg.\) : \(tensor<8xf..>\)"), ] ) def test_jvp_of_jit(self): tracer_spy = TracerSpy() def f(x, y, z): tracer_spy.append(x) return {'a': x, 'b': [y]} self._check_tracers_and_jaxprs( lambda x, y, z: jax.jvp(jax.jit(f), (x, y, z), (x, y, z)), jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ # TODO(necula): arg_names, result_paths "traced_for=jit, fun=f, arg_names=,,,, result_paths=,,,", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=f, arg_names=x,y[0],z[0], from x", ], expected_lowering_lines=[ # TODO(necula): missing arg_names re.compile(r".*func.func public @main\(%arg0: tensor loc\(unknown\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(unknown\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(unknown\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor loc\(unknown\)"), # TODO(necula): missing result names re.compile(r".*func.func public @main\(.*-> .*tensor {jax.result_info = \"\"}"), ]) def test_vjp_of_jit(self): # TODO(b/398208230): Re-enable this test after fixing. self.skipTest("Enable this after figuring out why it's failing") tracer_spy = TracerSpy() def my_f(x, y, z): tracer_spy.append(y[0]) return {'a': x * y[0], 'b': [y]} self._check_tracers_and_jaxprs( lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])), jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=result", re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths=result"), # TODO(necula): arg_names? result_paths? "traced_for=jit, fun=my_f, arg_names=,,,, result_paths=['a'],['b'][0][0]", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y[0],z[0], from y[0]", ], expected_lowering_lines=[ # TODO(necula): missing arg_names re.compile(r".*func.func public @main\(%arg0: tensor loc\(unknown\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(unknown\)"), # TODO(necula): result_paths? re.compile(r".*func.func public @main\(.*-> \(tensor {jax.result_info = \"\"}"), ]) def test_vjp_of_nested_jit(self): tracer_spy = TracerSpy() def my_f(x, y): tracer_spy.append(x) def my_g(u, v): tracer_spy.append(u) return dict(c=u * v, d=v) return jax.jit(my_g)(y, x)["c"] self._check_tracers_and_jaxprs( jax.jit(lambda x, y, res_ct: jax.vjp(my_f, x, y)[1](res_ct)), 2., 3., 0.3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", # TODO(necula): result_paths "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", # TODO(necula): arg_names "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=," if config.use_direct_linearize.value else "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", "traced_for=jit, fun=my_g, arg_names=u,v, from u" ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"x\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"y\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"res_ct\"\)"), re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[0\]\"}"), re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[1\]\"}"), ]) def test_vjp_remat(self): tracer_spy = TracerSpy() def apply_fn(inp): tracer_spy.append(inp) def to_remat(x): tracer_spy.append(x) return jax.nn.relu(x * x) fn = jax.checkpoint(to_remat) return jax.vjp(fn, inp) self._check_tracers_and_jaxprs( jax.jit(apply_fn), 2., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): what are these flat_index components? "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=result[0],result[1][0][0][0][0][0]", re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"), re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"), ], expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x", "traced_for=jit, fun=apply_fn, arg_names=inp, from inp", ]) def test_custom_jvp(self): tracer_spy = TracerSpy() @jax.custom_jvp def my_fun(x, y, c=1.): tracer_spy.append(y) return c * (x + y) def my_jvp(primals, tangents): x, y, c = primals t_x, t_y, t_c = tangents tracer_spy.append(t_y) return my_fun(x, y, c), t_c my_fun.defjvp(my_jvp) def top_f(x, y): return jnp.square(my_fun(x, y, c=2.)).sum() self._check_tracers_and_jaxprs( jax.jit(lambda a: jax.jvp(top_f, (a, a), (jnp.ones_like(a), jnp.ones_like(a)))), 42., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=a, result_paths=result[0],result[1]", "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=result", ], expected_tracer_debug_infos=[ # TODO(necula): from None? "traced_for=jit, fun=, arg_names=a, from None", "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y", ]) def test_custom_jvp_nondiff_args(self): tracer_spy = TracerSpy() def top_f(xy): tracer_spy.append(xy[0]) @functools.partial(jax.custom_jvp, nondiff_argnums=(0,)) def my_g(h, xy): x, y = xy tracer_spy.append(x) return h(x) @my_g.defjvp def my_g_jvp(h, primals, tangents): (x, y), = primals (xt, yt), = tangents tracer_spy.append(xt) return my_g(h, (x, y)), 2. * xt h = lambda y: xy[0] + y # capture x return my_g(h, xy) self._check_tracers_and_jaxprs( jax.jit(lambda a, b: jax.jvp(top_f, ((a, b),), ((jnp.ones_like(a), jnp.ones_like(b)),))), 42., 43., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): arg_names "traced_for=jit, fun=, arg_names=,a,b, result_paths=result[0],result[1]", "traced_for=custom_jvp fun, fun=my_g, arg_names=,xy[0],xy[1], result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]", # TODO(necula): from None "traced_for=jit, fun=, arg_names=a,b, from None", "None", # TODO(necula): None ]) def test_custom_vjp(self): tracer_spy = TracerSpy() @jax.custom_vjp def my_f(x): tracer_spy.append(x["a"]) return {"b": jnp.sin(x["a"])} def my_f_fwd(x): tracer_spy.append(x["a"]) return my_f(x), {"r": jnp.cos(x["a"])} def my_f_bwd(res, g): tracer_spy.append(g["b"]) cos_x = res["r"] return ({"a": 2 * cos_x * g["b"]},) my_f.defvjp(my_f_fwd, my_f_bwd) def to_diff(x): return my_f(x)["b"] self._check_tracers_and_jaxprs( jax.jit(jax.grad(to_diff)), {"a" : 3.}, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=result['a']", "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=result['b']", ], expected_tracer_debug_infos=[ "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']", # TODO(necula): from None? "traced_for=jit, fun=to_diff, arg_names=x['a'], from None", "traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']", ]) def test_custom_vjp_nondiff_args(self): tracer_spy = TracerSpy() @functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) def app(f, xy): tracer_spy.append(xy[0]) return f(xy) def app_fwd(f, xy): tracer_spy.append(xy[0]) return app(f, xy), jnp.cos(xy[0]) def app_rev(f, cos_x0, g): tracer_spy.append(cos_x0) tracer_spy.append(g) return ((cos_x0 * g, cos_x0),) app.defvjp(app_fwd, app_rev) self._check_tracers_and_jaxprs( jax.jit(jax.grad(lambda xy: app(lambda x: 2 * x[0], xy))), (3., 3.), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=xy[0],xy[1], result_paths=result[0],result[1]", "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=, arg_names=xy[0],xy[1], from xy[0]", "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]", # TODO(necula): from None "traced_for=jit, fun=, arg_names=xy[0],xy[1], from None", ]) def test_custom_transpose(self): # Helpers from api_test.py class _custom_transpose: def __init__(self, out_types, fun): self.out_types = out_types self.fun = custom_transpose.custom_transpose(fun) def __getattr__(self, name): return getattr(self.fun, name) def __call__(self, *args): return self.fun(self.out_types, *args) def custom_transpose_with_example_out(example_out): return functools.partial( _custom_transpose, jax.tree.map( lambda x: core.get_aval(x).to_tangent_aval(), example_out)) tracer_spy = TracerSpy() def my_f_with_cond(i, x): def my_f(x): tracer_spy.append(x) @custom_transpose_with_example_out(jnp.ones(2)) def fn(r, x): tracer_spy.append(r) tracer_spy.append(x["c"]) return dict(b=x["c"] / r) @fn.def_transpose def fn_tp(r, t): tracer_spy.append(r) return dict(c=2 * t / r) return x["c"] + fn(jnp.ones(2) * 3., x) return lax.cond(i > 0, my_f, lambda x: x["c"], dict(c=x)) x = jnp.ones(2) * 6. self._check_tracers_and_jaxprs( jax.jit(lambda x: jax.linear_transpose(functools.partial(my_f_with_cond, 7.), x)), x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=result", "traced_for=cond, fun=, arg_names=x['c'], result_paths=result", "traced_for=jit, fun=, arg_names=x, result_paths=result[0][0][0],result[0][0][1]", ], expected_tracer_debug_infos=[ "traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from r", "traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from x['c']", "traced_for=custom_transpose transpose_fun, fun=fn_tp, arg_names=r,t, from r", ]) def test_linear_call(self): tracer_spy = TracerSpy() def my_f(x, y): tracer_spy.append(y) def fn(r, x): tracer_spy.append(r) tracer_spy.append(x["c"]) return dict(b=x["c"] * r) def fn_tp(r, t): tracer_spy.append(t["b"]) return dict(c=t["b"] * r) return dict(a=x["c"] + jax.custom_derivatives.linear_call(fn, fn_tp, y, x)["b"]) f1 = lambda x: my_f(x, jnp.ones(2) * 3.) x = jnp.ones(2) * 6. self._check_tracers_and_jaxprs( jax.jit(lambda x: jax.linear_transpose(f1, dict(c=x))(dict(a=x))), x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result[0]['c']", "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']", "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']", ], expected_tracer_debug_infos=[ # TODO(necula): from None? "traced_for=jit, fun=, arg_names=x, from None", "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], from r", "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], from x['c']", "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], from t['c']", ]), def test_custom_vmap(self): tracer_spy = TracerSpy() @jax.custom_batching.custom_vmap def my_f(xdict): x = xdict["x"] tracer_spy.append(x) return dict(a=jnp.sin(x)) @my_f.def_vmap def my_rule(axis_size, in_batched, xys): xs = xys["x"] tracer_spy.append(xs) xs_batched, = in_batched self.assertEqual(xs_batched["x"], True) self.assertEqual(axis_size, xs.shape[0]) return dict(a=jnp.cos(xs)), dict(a=xs_batched["x"]) xy = dict(x=np.ones((8,), dtype=np.float32), y=np.zeros((8,), dtype=np.float32)) self._check_tracers_and_jaxprs( jax.jit(jax.vmap(my_f)), xy, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=result['a']", ], expected_tracer_debug_infos=[ "traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']", "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']" ]) def test_cond(self): tracer_spy = TracerSpy() def my_f(x): def my_true_branch(a, b): tracer_spy.append(a) return a + b def my_false_branch(c, d): tracer_spy.append(c) return c - d return lax.cond(x >= 0, my_true_branch, my_false_branch, x, x) self._check_tracers_and_jaxprs( jax.jit(my_f), 0, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result", "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a", "traced_for=cond, fun=my_false_branch, arg_names=c,d, from c" ]) def test_switch(self): tracer_spy = TracerSpy() def my_f(x): def my_branch0(x0): tracer_spy.append(x0) return x0 def my_branch1(x1): tracer_spy.append(x1) return x1 + 1 def my_branch2(x2): tracer_spy.append(x2) return x2 + 2 return lax.switch(x, [my_branch0, my_branch1, my_branch2], x) self._check_tracers_and_jaxprs( jax.jit(my_f), 2, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", "traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=result", "traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=result", "traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=switch, fun=my_branch0, arg_names=x0, from x0", "traced_for=switch, fun=my_branch1, arg_names=x1, from x1", "traced_for=switch, fun=my_branch2, arg_names=x2, from x2" ]) def test_grad_cond_with_remat(self): tracer_spy = TracerSpy() def my_f(x, y): # The cond branches return two things, and only the first is needed # in the residuals. def my_true_branch(a, b): tracer_spy.append(a) return (a + 1, a + b) def my_false_branch(c, d): tracer_spy.append(c) return (c - 1, c - d) def my_g(x, y): # x1 does not depend on y x1, y1 = lax.cond(x >= 0, my_true_branch, my_false_branch, x, y) tracer_spy.append(x1) return x1, y1 x2, y2 = jax.remat(my_g)(x, y) return y2 + lax.sin(x2) self._check_tracers_and_jaxprs( jax.jit(jax.grad(my_f)), 1., 2., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", # TODO(necula): arg_names? result_paths? "traced_for=cond, fun=my_true_branch, arg_names=, result_paths=,", "traced_for=cond, fun=my_false_branch, arg_names=, result_paths=,", "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result[0],result[1]", "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result[0],result[1]", "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=,", ], expected_tracer_debug_infos=[ "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a", "traced_for=cond, fun=my_false_branch, arg_names=c,d, from c", # TODO(necula): from None "traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None", ]) def test_grad_scan(self): # Based on control_flow_test:testScanHigherOrderDifferentiation tracer_spy = TracerSpy() def f(c, a): tracer_spy.append(c) d = 0.75 b = jnp.sin(c * jnp.sum(jnp.cos(d * a))) c = 0.9 * jnp.cos(d * jnp.sum(jnp.sin(c * a))) return c, b as_ = jnp.arange(6.).reshape((3, 2)) c = jnp.array(1, dtype=as_.dtype) @jax.jit def my_f(x, as_): tracer_spy.append(x) def to_remat(a, b): return for_loop.scan(f, a, b) return jax.remat(to_remat)(c, as_) def the_grad(c, as_): tracer_spy.append(c) _, pullback = jax.vjp(my_f, c, as_) return pullback((c, np.arange(3, dtype=c.dtype))) self._check_tracers_and_jaxprs( jax.jit(the_grad), c, as_, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", # TODO(necula): arg names, bad result paths "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" if config.use_direct_linearize.value else "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", "traced_for=scan, fun=f, arg_names=c,a, from c", "traced_for=jit, fun=my_f, arg_names=x,as_, from x", # TODO(necula): arg_names, and "from x" "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], from refs[0]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"c\"\)"), re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"), re.compile(r".*func.func public @main\(.* -> .*tensor {jax.result_info = \"result\[0\]\""), re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""), # TODO(necula): unnamed function? re.compile(r".*func.func private @None"), ]) def test_while_loop(self): tracer_spy = TracerSpy() def my_f(x): def my_cond(a): tracer_spy.append(a) return a <= 8 def my_body(b): tracer_spy.append(b) return b + 1 return lax.while_loop(my_cond, my_body, x) self._check_tracers_and_jaxprs( jax.jit(my_f), 0, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", "traced_for=while_body, fun=my_body, arg_names=b, result_paths=result", "traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=while_cond, fun=my_cond, arg_names=a, from a", "traced_for=while_body, fun=my_body, arg_names=b, from b", ]) def test_fori(self): # See https://github.com/jax-ml/jax/issues/23637 tracer_spy = TracerSpy() def my_body(_, c): tracer_spy.append(c) return 0. self._check_tracers_and_jaxprs( jax.jit(lambda x: jax.lax.fori_loop(0, 5, my_body, x)), 3., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result", # TODO(necula): bad arg_names, result_paths "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=result[0][0],result[0][1]", ], expected_tracer_debug_infos=[ # TODO(necula): the arg_names are not right "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], from loop_carry[1]", ] ) tracer_spy = TracerSpy() # When the ubound is not a constant, we use a while_loop self._check_tracers_and_jaxprs( jax.jit(lambda ub, x: jax.lax.fori_loop(0, ub, my_body, x)), 5, 3., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=ub,x, result_paths=result", re.compile(r"traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths="), # TODO(necula): arg_names and result_paths are not right "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=result[0],result[1],result[2]", ], expected_tracer_debug_infos=[ # TODO(necula): the arg_names are not right "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], from loop_carry[2]", ]) def test_scan(self): tracer_spy = TracerSpy() def my_f(x): def my_scan_body(carry, inp): tracer_spy.append(carry) return (carry + inp, carry) return lax.scan(my_scan_body, 0, x) self._check_tracers_and_jaxprs( jax.jit(my_f), np.arange(8, dtype=np.int32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result[0],result[1]", "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=result[0],result[1]", ], expected_tracer_debug_infos=[ "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry" ]) def test_eval_shape(self): tracer_spy = TracerSpy() def my_f(x): tracer_spy.append(x) return x _ = jax.eval_shape(my_f, 0) self._check_tracers_and_jaxprs( lambda: jax.eval_shape(my_f, 0), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, from x"], ) def test_vmap_of_nested_jit(self): tracer_spy = TracerSpy() def my_f(x, y): tracer_spy.append(x) def my_g(u, v): tracer_spy.append(u) return dict(c=u * v, d=v) return jax.jit(my_g)(y, x)["c"] self._check_tracers_and_jaxprs( jax.jit(jax.vmap(my_f)), np.ones((8,), dtype=np.float32), np.zeros((8,), dtype=np.float32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", "traced_for=jit, fun=my_g, arg_names=u,v, from u" ]) def test_pmap(self): tracer_spy = TracerSpy() def my_f(x): tracer_spy.append(x) return jnp.sin(x) self._check_tracers_and_jaxprs( jax.pmap(my_f), np.ones((jax.device_count(),), dtype=np.float32), expected_jaxpr_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result" ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=x, from x" ], ) def test_pmap_with_arg_and_result_names(self): tracer_spy = TracerSpy() x = np.ones((jax.device_count(),), dtype=np.float32) def my_f(x, y, *args, a, **kwargs): # y and kwargs[c] is dead tracer_spy.append(args[1]) s = x + a + args[1] + kwargs["d"] return dict(u=s, v=x) self._check_tracers_and_jaxprs( jax.pmap(my_f, static_broadcasted_argnums=(0,)), 1., x, x, x, # x, y, args[0], args[1] d=x, a=x, b=x, # kwargs expected_jaxpr_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", ], expected_lowering_lines=[ # TODO(necula): we did not DCE y? re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"), re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"), ] ) def test_pmap_of_grad(self): tracer_spy = TracerSpy() def my_f(x): tracer_spy.append(x) return jnp.sin(x) self._check_tracers_and_jaxprs( jax.pmap(jax.grad(my_f)), np.ones((jax.device_count(),), dtype=np.float32), expected_jaxpr_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ # TODO(necula): missing debug_info 'None' ], ) def test_jvp_pmap_eager(self): tracer_spy = TracerSpy() def my_f(x, y, *args): # y is dead, x is static broadcasted tracer_spy.append(args[1]) s = x + args[1] return dict(u=s, v=x) x = jnp.ones((jax.device_count(), 1), dtype=np.float32) x_tan = jnp.full_like(x, .1) self._check_tracers_and_jaxprs( lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x, x, x), (x_tan, x_tan, x_tan, x_tan)), x, x_tan, expected_jaxpr_debug_infos=[ # TODO(necula): why this? re.compile(r'traced_for=jit, fun=_multi_slice at .*array_methods.py:.*, arg_names=self, result_paths=.*'), "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=result['u'],result['v']", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ # TODO(necula): missing debug_info "None" ], ) @jtu.ignore_warning(category=UserWarning, message=".* jitted function .* includes a pmap") def test_jvp_pmap(self): tracer_spy = TracerSpy() def my_f(x, y): tracer_spy.append(x) return jnp.sin(x) + y x = np.ones((jax.device_count(), 1), dtype=np.float32) x_tan = np.full_like(x, .1) self._check_tracers_and_jaxprs( jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))), x, x_tan, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x,x_tan, result_paths=result[0],result[1]", "traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ # TODO(necula): missing debug_info "None" ], ) def test_hessian(self): tracer_spy = TracerSpy() def my_f(x): tracer_spy.append(x) return jnp.square(x).mean() x = jax.random.uniform(jax.random.key(0), shape=(8, 4)) self._check_tracers_and_jaxprs( jax.jit(jax.hessian(jax.jit(my_f))), x, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): arg_names and result_paths? "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," if config.use_direct_linearize.value else "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, from x", ], ) (x).block_until_ready() def test_remat(self): tracer_spy = TracerSpy() def my_f(x): @jax.remat def my_g(y): tracer_spy.append(y) return lax.sin(y) return my_g(x) self._check_tracers_and_jaxprs( jax.jit(my_f), 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): missing result_paths "traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y" ]) def test_grad_remat(self): tracer_spy = TracerSpy() def my_f(x): @jax.remat def my_g(y): tracer_spy.append(y) return lax.sin(lax.sin(y)) return my_g(my_g(x)) self._check_tracers_and_jaxprs( jax.jit(jax.grad(my_f)), 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): arg_names? result_paths? "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=", ], expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y", ]) def test_remat_shard_map(self): tracer_spy = TracerSpy() if len(jax.devices()) < 2: self.skipTest("requires at least 2 devices") # this tests remat-of-shmap mesh = Mesh(np.array(jax.devices()[:2]), ('x',)) # check param updating is handled @jax.remat @functools.partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def my_f(x): tracer_spy.append(x) return jnp.sin(jnp.sin(x)) self._check_tracers_and_jaxprs( jax.jit(jax.grad(lambda x: my_f(x).sum())), jnp.arange(2, dtype=np.float32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): arg_names, result_paths "traced_for=jit, fun=, arg_names=x, result_paths=result", "traced_for=checkpoint / remat, fun=my_f, arg_names=,, result_paths=", "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=result", "traced_for=shard_map, fun=my_f, arg_names=,, result_paths=", ], expected_tracer_debug_infos=[ "None" # TODO(necula): missing ]) def test_remat_saved_residuals(self): @functools.partial(jax.remat, static_argnums=(1,), policy=lambda p, *_, **__: "mul" in str(p)) def my_f(x, y): x = ad_checkpoint.checkpoint_name(x * x, "foo") x = x * x return x + y res = saved_residuals(my_f, 3., 4.) self.assertEqual(res[0][1], "from the argument x") self.assertRegex(res[1][1], r"named 'foo' from .*debug_info_test.py:.*my_f") def test_checkify_pmap_basic(self): if len(jax.devices()) < 2: self.skipTest("requires at least 2 devices") tracer_spy = TracerSpy() @jax.pmap def my_f(my_x): tracer_spy.append(my_x) y1 = jnp.sin(1./my_x) y2 = jnp.sin(my_x) return (y1 + y2,) self._check_tracers_and_jaxprs( jax.jit(checkify.checkify(my_f, errors=checkify.nan_checks)), np.arange(len(jax.devices()), dtype=np.float32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): this should not be pointing into the JAX internals re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"), re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths=result"), "traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=result[0]", ], expected_tracer_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=my_x, from my_x", ], check_lowering=False, # TODO(necula): warning during lowering ) def test_custom_dce_static_argnums(self): tracer_spy = TracerSpy() @functools.partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) def my_g(f, x): tracer_spy.append(x) return f(x), 10 * f(x) @my_g.def_dce def my_g_dce(f, used_outs, x): # note: static_argnums are always passed first tracer_spy.append(x) self.assertTrue(callable(f)) return [2 * v if used else None for used, v in zip(used_outs, my_g(f, x))] def my_f(x): return jnp.exp(x) self._check_tracers_and_jaxprs( jax.jit(lambda x: my_g(my_f, x)[0]), 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result", "traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=result[0],result[1]", ], expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_g_dce? "traced_for=custom_dce, fun=my_g, arg_names=x, from x", ]) def test_custom_dce_consts(self): tracer_spy = TracerSpy() @jax.experimental.custom_dce.custom_dce def my_f(x): tracer_spy.append(x) return np.eye(1) * jnp.sin(x), jnp.cos(x) @my_f.def_dce def my_rule(used_outs, y): tracer_spy.append(y) return ( np.full((1, 1), 2.0) * jnp.exp(y) if used_outs[0] else None, jnp.sqrt(y) if used_outs[1] else None, ) self._check_tracers_and_jaxprs( jax.jit(lambda x: my_f(x)[0]), np.array(1.1234), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result", # TODO(necula): bad arg_names (why None), bad result_paths 'traced_for=custom_dce, fun=my_f, arg_names=,x, result_paths=result[0],result[1]', ], expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_rule? "traced_for=custom_dce, fun=my_f, arg_names=x, from x", ]) def test_custom_linear_solve_complex(self): tracer_spy = TracerSpy() def solve(a, b): tracer_spy.append(a) def my_solve(matvec, x): tracer_spy.append(x) return jsp.linalg.solve(a, x) def my_high_precision_dot(a, b): tracer_spy.append(a) return lax.dot(a, b, precision=lax.Precision.HIGHEST) def my_tr_solve(matvec, x): tracer_spy.append(x) return jsp.linalg.solve(a.T, x) matvec = functools.partial(my_high_precision_dot, a) return lax.custom_linear_solve(matvec, b, my_solve, my_tr_solve) rng = self.rng() a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2) b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2) self._check_tracers_and_jaxprs( jax.jit(lambda a, b: jax.jvp(solve, (a, b), (a, b))), a, b, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=a,b, result_paths=result[0],result[1]", re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths=result"), re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths=result"), re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths=result"), # TODO(necula): why pointers to internal functions, arg_names, result_paths? re.compile(r'traced_for=custom_linear_solve solve, fun= at .*linalg.py:.*, arg_names=,,x, result_paths='), re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=,,x, result_paths='), re.compile(r'traced_for=custom_linear_solve, fun= at .*linalg.py:.*, arg_names=,x, result_paths='), re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=,x, result_paths='), "traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=,b, result_paths=result", "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=,x, result_paths=result", "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=,x, result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x, from x", "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x, from x", "None", # TODO(necula): there are missing debug info ]) def test_custom_root_errors(self): tracer_spy = TracerSpy() def dummy_root_usage(x): tracer_spy.append(x) def my_f(x): tracer_spy.append(x) return x - 3. def my_solve(f, x): tracer_spy.append(x) return x def my_transpose_solve(f, x): tracer_spy.append(x) return x return lax.custom_root(my_f, 0., my_solve, my_transpose_solve) self._check_tracers_and_jaxprs( jax.jit(lambda x: jax.jvp(dummy_root_usage, (x,), (0.0,))), 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result[0],result[1]", # TODO(necula): internal function? re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=result\[0\]"), ], expected_tracer_debug_infos=[ "traced_for=custom_root, fun=my_f, arg_names=x, from x", "traced_for=custom_root solve, fun=my_solve, arg_names=x, from x", # TODO(necula): from None "traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x, from None", "None", # TODO(necula): there are missing debug info ]) def test_pallas_call(self): tracer_spy = TracerSpy() def my_kernel(x_ref, y_ref, o_ref): tracer_spy.append(x_ref) o_ref[...] = x_ref[...] + y_ref[...] x = np.arange(256 * 16, dtype=np.float32).reshape((256, 16)) def my_f(x): def my_index_map(i, j): tracer_spy.append(i) return (i, j) return pl.pallas_call(my_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), in_specs=(pl.BlockSpec((128, 8), my_index_map), pl.BlockSpec((128, 8), my_index_map)), out_specs=pl.BlockSpec((128, 8), my_index_map), grid=(pl.cdiv(x.shape[0], 128), pl.cdiv(x.shape[1], 8)), name="my_custom_kernel_name")(x, x) self._check_tracers_and_jaxprs( jax.jit(my_f), x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=result[0],result[1]", # TODO(necula): result_paths? "traced_for=pallas_call kernel, fun=my_custom_kernel_name, arg_names=x_ref,y_ref,o_ref, result_paths=", ], expected_tracer_debug_infos=[ "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, from i", "traced_for=pallas_call kernel, fun=my_custom_kernel_name, arg_names=x_ref,y_ref,o_ref, from x_ref", ], check_lowering=False, # We need interpret mode on CPU. TODO(necula) ) def test_checkify_pallas_call(self): tracer_spy = TracerSpy() def kernel(x_ref, y_ref): tracer_spy.append(x_ref) y_ref[...] = jnp.log(x_ref[...]) def my_f(input): out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype) pallas_call = pl.pallas_call(kernel, out_shape=out_shape, name="my_custom_kernel_name") checked_call = checkify.checkify(pallas_call, errors=checkify.nan_checks) return checked_call(input)[1] self._check_tracers_and_jaxprs( jax.jit(my_f), jnp.arange(4, dtype=jnp.float32) - 2, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=input, result_paths=result", # TODO(necula): function source location points in JAX internals # TODO(necula): arg_names and result_paths are wrong re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*.pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="), re.compile(r"traced_for=pallas_call index_map, fun=default_index_map at .*.pallas.core.py:.*, arg_names=, result_paths=result\[0\].*"), ], expected_tracer_debug_infos=[ "traced_for=pallas_call kernel, fun=my_custom_kernel_name, arg_names=x_ref,y_ref, from x_ref", ], check_lowering=False, # We need interpret mode on CPU. TODO(necula) ) def test_composite(self): tracer_spy = TracerSpy() scale = np.array([0.5, 0.4, 0.3], dtype=np.float32) @functools.partial(lax.composite, name="my.consts") def my_consts(x): tracer_spy.append(x) return x / scale x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32) self._check_tracers_and_jaxprs( jax.jit(my_consts), x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_consts, arg_names=x, result_paths=result", "traced_for=composite, fun=my_consts, arg_names=x, result_paths=result", ], expected_tracer_debug_infos=[ "traced_for=composite, fun=my_consts, arg_names=x, from x"]) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())