# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import operator
import re
from typing import Any
import unittest

from absl.testing import absltest
import jax
from jax import ad_checkpoint
from jax import lax

import jax.custom_batching
import jax.custom_derivatives
from jax.experimental import checkify
import jax.experimental.custom_dce
from jax.experimental import pallas as pl
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
import jax.scipy as jsp

from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from jax._src import api_util
from jax._src.ad_checkpoint import saved_residuals
from jax._src import config
from jax._src import core
from jax._src import custom_transpose
from jax._src import test_util as jtu
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lax.control_flow import for_loop
from jax._src.interpreters import mlir


import numpy as np

config.parse_flags_with_absl()
jtu.request_cpu_devices(8)

def _collect_jaxprs(jaxpr: core.Jaxpr,
                    acc: list[core.Jaxpr] | None = None) -> list[core.Jaxpr]:
  """Collect all Jaxprs in a depth-first order."""
  if acc is None: acc = []
  acc.append(jaxpr)
  for e in jaxpr.eqns:
    # Collect the block-mapping Jaxprs first
    if e.primitive.name == "pallas_call":
      # For pallas_call, extract also the jaxprs inside the grid_mapping
      mapping = e.params["grid_mapping"]
      for bm in mapping.block_mappings:
        _collect_jaxprs(bm.index_map_jaxpr.jaxpr, acc)
    for sj in core.jaxprs_in_params(e.params):
      _collect_jaxprs(sj, acc)
  return acc

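# A minimal usage sketch for the helper above (illustrative only; it mirrors
# how _check_tracers_and_jaxprs calls it further below on a traced function):
#   traced = jax.jit(lambda x: x + 1).trace(0)
#   all_jaxprs = _collect_jaxprs(traced.jaxpr.jaxpr)
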
def _debug_info_to_string(dbg: core.DebugInfo) -> str:
  # Strip the absolute path and the line number, but check that it references
  # this file (to catch errors when the source info points into JAX internals)
  func_src_info = re.sub(r"^(\S+)( at .*.debug_info_test.py:\d+)?", "\\1", dbg.func_src_info)
  arg_names_str = ",".join([str(a) for a in dbg.arg_names])
  res = f"traced_for={dbg.traced_for}, fun={func_src_info}, arg_names={arg_names_str}"
  if isinstance(dbg.result_paths, tuple):
    res += f", result_paths={','.join(dbg.result_paths)}"
  elif dbg.result_paths is None:
    res += ", result_paths=<empty>"
  return res

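# Illustrative only: for a function traced as `jax.jit(my_f)(x, y)` the string
# built above looks like
#   "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result"
# which is the format the expected_* lists in the tests below are written in.
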
class TracerSpy:
  """Use to inspect tracers.

  We can `append` tracers from tracing contexts to this object. We collect
  the tracer, along with the error message we get when we try to concretize
  it. This is meant to simulate errors like concretization or leaking.
  """
  tracers: list[tuple[core.Tracer, Exception]]
  def __init__(self):
    self.tracers = []

  def append(self, t: Any) -> None:
    if isinstance(t, core.Tracer):
      try:
        # We plan to do boolean conversion and catch the exception, but this works
        # only for scalars
        t_scalar = t
        while t_scalar.shape:
          t_scalar = t_scalar[0]
        if t_scalar:
          pass
        assert False, t_scalar
      except Exception as e:
        self.tracers.append((t, e))

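# A minimal TracerSpy usage sketch (matches the pattern used throughout the
# tests below): append a tracer from inside a traced function, then inspect
# the recorded (tracer, exception) pairs.
#   tracer_spy = TracerSpy()
#   def my_f(x):
#     tracer_spy.append(x)  # records the tracer and its concretization error
#     return x
#   jax.jit(my_f)(0)
#   assert len(tracer_spy.tracers) == 1
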
@jtu.with_config(jax_mutable_array_checks=True)
class DebugInfoTest(jtu.JaxTestCase):

  def _check_tracers_and_jaxprs(self, traceable: Any,
                                *args,
                                expected_jaxpr_debug_infos: list[str | re.Pattern],
                                tracer_spy: TracerSpy | None = None,
                                expected_tracer_debug_infos: list[str | re.Pattern] = [],
                                check_lowering: bool = True,
                                expected_lowering_lines: list[str | re.Pattern] = [],
                                **kwargs) -> None:
    """Checks the expected debug info in all jaxprs, in spied tracers, and StableHLO.

    `traceable` will be traced as `traceable.trace(*args, **kwargs)` if it has
    a `trace` method (for jit), or will be called as `traceable(*args, **kwargs)`
    otherwise (for eager). We collect all the nested Jaxprs, either from
    the result of `trace`, or by capturing all the lowered jaxprs in eager
    mode. The debug infos in the nested Jaxprs are first converted to
    strings using `_debug_info_to_string` and then
    compared against `expected_jaxpr_debug_infos`. During this conversion,
    we strip occurrences of this test file name and a line number
    (e.g., .*/debug_info_test.py:56)

    An element of `expected_jaxpr_debug_infos` can be a string, in which case
    it is compared by equality, or a `re.Pattern` (the result of `re.compile`)
    in which case it is compared by `.match()`. All elements of
    `expected_jaxpr_debug_infos` must appear, and all Jaxprs must be matched.

    Optionally, we can pass a TracerSpy object into which we have
    `append`ed tracers from the execution of `traceable`. E.g., if we
    do `tracer_spy.append(a)`, where `a` is an argument of a `jit` function,
    we expect to see debugging info
    "traced_for=jit, fun=my_f, arg_names=a,b, from a". These debugging infos
    are compared with `expected_tracer_debug_infos`.

    Finally, if we pass `expected_lowering_lines` then we are looking for those
    matches in the StableHLO MLIR modules that are lowered.
    """
    if hasattr(traceable, "trace"):
      traced = traceable.trace(*args, **kwargs)
      all_jaxprs = _collect_jaxprs(traced.jaxpr.jaxpr)
    else:
      # Just run the function and collect the Jaxprs and modules that are
      # lowered
      traced = None
      with jtu.collect_lowered_jaxprs() as collection:
        traceable(*args, **kwargs)

      all_jaxprs = []
      for jaxpr, _ in collection:
        all_jaxprs.extend(_collect_jaxprs(jaxpr.jaxpr))

    found_jaxprs_debug_infos = [_debug_info_to_string(j.debug_info)
                                for j in all_jaxprs]

    self._check_matches(expected_jaxpr_debug_infos, found_jaxprs_debug_infos,
                        "Jaxprs debug_infos")  # JAXPRS

    found_tracer_debug_infos = []
    if tracer_spy is not None:
      for t, exc in tracer_spy.tracers:
        if hasattr(t, "_debug_info"):
          t_debug_info = _debug_info_to_string(t._debug_info)
          msg = str(exc)
          m = re.match(r".* while tracing the function (.+) for ([^.]+)\.",
                       msg,
                       re.DOTALL)
          if m is None:
            found_tracer_debug_infos.append(
                f"{t_debug_info}, from None")
          else:
            self.assertIsNotNone(m, msg)
            self.assertEqual(t._debug_info.func_src_info, m.group(1))
            self.assertEqual(t._debug_info.traced_for, m.group(2))
            m = re.match(r".* depends on the value of the argument ([^\n]+)\.",
                         msg,
                         re.DOTALL)
            found_tracer_debug_infos.append(
                f"{t_debug_info}, from {m.group(1) if m else None}")
        else:
          found_tracer_debug_infos.append("None")

    self._check_matches(expected_tracer_debug_infos, found_tracer_debug_infos,
                        "Tracer debug_infos")  # INSPECTED TRACERS

    if not check_lowering: return
    # Collect all the lines in all the MLIR modules
    mlir_modules_lines = []
    if traced is not None:
      mlir_modules_lines.extend(
          traced.lower().as_text("stablehlo", debug_info=True).split("\n"))
    else:
      for _, mod in collection:
        mlir_modules_lines.extend(
            mlir.module_to_string(mod, enable_debug_info=True).split("\n"))

    self._check_matches(expected_lowering_lines, mlir_modules_lines,
                        "MLIR module lines", report_found_unexpected=False)

  def _check_matches(self,
                     expected: list[str | re.Pattern],
                     found: list[str],
                     what: str,
                     report_found_unexpected: bool = True) -> None:
    expected_and_found: set[str | re.Pattern] = set()
    found_and_expected: set[str] = set()
    for exp_re in expected:
      for found_line in found:
        ok = exp_re.match(found_line) if isinstance(exp_re, re.Pattern) else exp_re == found_line
        if ok:
          expected_and_found.add(exp_re)
          found_and_expected.add(found_line)

    found_and_unexpected = set(found) - found_and_expected
    all_found = "\n ".join(found)
    if report_found_unexpected and found_and_unexpected:
      unexp_str = "\n ".join(found_and_unexpected)
      msg = f"Found unexpected {what}:\n {unexp_str}\nAll found {what}:\n {all_found}"
      self.assertTrue(False, msg)

    if expected_not_found := {e for e in expected if e not in expected_and_found}:
      exp_str = "\n ".join([str(e) for e in expected_not_found])
      msg = f"Expected but not found in {what}:\n {exp_str}\nAll found {what}:\n {all_found}"
      self.assertTrue(False, msg)

  def test_debug_info_basic(self):
    def my_f(x, y, z, w):
      pass

    dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
    self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+")
    self.assertEqual(dbg.func_name, "my_f")
    self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
    self.assertIsNone(dbg.result_paths)

  def test_debug_info_arg_passed_as_kwarg(self):
    def my_f(x, y, z):
      pass

    dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3))
    self.assertEqual(dbg.arg_names, ("x", "y", "z"))

  def test_debug_info_pytrees(self):
    def my_f(x_tree, *, y_tree):
      pass

    dbg = api_util.debug_info("jit", my_f, ((1, 2),),
                              dict(y_tree=dict(z=3, w=4)))
    self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
                                     "y_tree['w']", "y_tree['z']"))

  def test_debug_info_with_statics(self):
    def my_f(x, y, *, z, w):
      pass

    dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
                              static_argnums=(1,),
                              static_argnames=("w",))
    self.assertEqual(dbg.arg_names, ("x", "z"))

  def test_debug_info_with_pytrees_and_statics(self):
    def my_f(x, y, *, z, w):
      pass

    dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)),
                              dict(z=(3, 4), w=(5, 6)),
                              static_argnums=(1,),
                              static_argnames=("w",))
    self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))

  def test_debug_info_too_many_args(self):
    def my_f(x):
      pass

    dbg = api_util.debug_info("jit", my_f, (1, 2, 3), dict(z=3))
    self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))

  def test_debug_info_no_source_info_built_in(self):
    # built-in function "max" does not have an inspect.Signature
    dbg = api_util.debug_info("jit", max, (1,), {})
    self.assertEqual(dbg.func_src_info, "max")
    self.assertEqual(dbg.arg_names, ("args[0]",))

  def test_debug_info_lambda(self):
    # lambdas are reported as "<lambda>" along with their source location
    dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {})
    self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+")
    self.assertEqual(dbg.arg_names, ("my_arg",))

  def test_debug_info_save_wrapped_fun_source_info(self):
    def wrapper(x, y):
      return x

    dbg = api_util.debug_info("test", wrapper, (1, 2), {})
    self.assertEqual("wrapper", dbg.func_name)

    api_util.save_wrapped_fun_sourceinfo(wrapper, lambda x, y: x)
    dbg = api_util.debug_info("test", wrapper, (1, 2), {})
    self.assertEqual("<lambda>", dbg.func_name)

    def other_f():
      pass
    dbg_other = api_util.debug_info("test other", other_f, (), {})
    api_util.save_wrapped_fun_sourceinfo(wrapper, dbg_other)
    dbg = api_util.debug_info("test", wrapper, (1, 2), {})
    self.assertEqual("other_f", dbg.func_name)
    self.assertEqual("test", dbg.traced_for)

  def test_debug_info_no_source_info_not_callable(self):
    # a non-callable argument has neither a name nor source info
    dbg = api_util.debug_info("jit", False, (1,), {})
    self.assertEqual(dbg.func_src_info, "<unknown>")
    self.assertEqual(dbg.arg_names, ("args[0]",))

  def test_debug_info_no_source_info_callable(self):
    class Foo:
      x: int

      def __call__(self, y):
        return self.x + y

    dbg = api_util.debug_info("jit", Foo(), (1,), {})
    self.assertRegex(dbg.func_src_info, "<unknown>")
    self.assertEqual(dbg.arg_names, ("y",))

  def test_debug_info_no_source_info_callable_with_repr_errors(self):
    class Foo:
      x: int

      def __call__(self, y):
        return self.x + y

      def __repr__(self):
        raise NotImplementedError

    dbg = api_util.debug_info("jit", Foo(), (1,), {})
    self.assertRegex(dbg.func_src_info, "<unknown>")
    self.assertEqual(dbg.arg_names, ("y",))

  def helper_save_tracer(self, x):
    self._saved_tracer = x
    return x

  def test_jit_lower_arg_names_with_error1(self):
    def f(x):
      return jnp.sqrt(x ** 2) + 1.

    x = jnp.array(1, dtype=int)
    x_f32 = x.astype(jnp.float32)
    x_i32 = x.astype(jnp.int32)
    f_exe = jax.jit(f).lower(x_f32).compile()
    self.assertRaisesRegex(
        TypeError,
        r"Argument types differ .*"
        r"The mismatches are:\n"
        r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
        lambda: f_exe(x_i32))

  def test_jit_lower_arg_names_with_error2(self):
    def f(x):
      return x

    err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
               "value is of type .* and was passed to the function at path x.")
    with self.assertRaisesRegex(TypeError, err_str):
      jax.jit(f)("foo")

    # Jax type objects aren't valid data arguments.
    with self.assertRaisesRegex(TypeError, err_str):
      jax.jit(f)(jnp.int32)

  @jtu.thread_unsafe_test()  # logging is not thread-safe
  def test_arg_names_cache_miss_explanations(self):
    @jax.jit
    def f(x, y):
      return jnp.sin(x) * y['hi']

    x = jnp.float32(1.)
    y = {'hi': jnp.arange(3., dtype='float32')}

    expected_log_len = 1 if not is_persistent_cache_enabled() else 3

    # print on first miss, not on hit
    with config.explain_cache_misses(True):
      with self.assertLogs(level='WARNING') as cm:
        f(x, y)
        f(x, y)
    self.assertLen(cm.output, expected_log_len)
    msg = cm.output[0]
    self.assertIn('TRACING CACHE MISS', msg)
    self.assertIn('never seen function', msg)

    # shape change
    y_ = {'hi': jnp.arange(4, dtype='float32')}
    with config.explain_cache_misses(True):
      with self.assertLogs(level='WARNING') as cm:
        f(x, y_)
    self.assertLen(cm.output, expected_log_len)
    msg = cm.output[0]
    self.assertIn('never seen input type signature', msg)
    self.assertIn('closest seen input type signature has 1 mismatches', msg)
    self.assertIn('seen f32[3], but now given f32[4]', msg)

    # weak type change (assuming no x64)
    if not config.enable_x64.value:
      with config.explain_cache_misses(True):
        with self.assertLogs(level='WARNING') as cm:
          f(1., y)
      self.assertLen(cm.output, expected_log_len)
      msg = cm.output[0]
      self.assertIn('weak_type=True', msg)
      self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg)

    # kwarg change
    with config.explain_cache_misses(True):
      with self.assertLogs(level='WARNING') as cm:
        f(1, y=y)
    self.assertLen(cm.output, expected_log_len)
    msg = cm.output[0]
    self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)

    # tracing config change
    with config.explain_cache_misses(True):
      with self.assertLogs(level='WARNING') as cm:
        with jax.numpy_rank_promotion('warn'):
          f(x, y)
    # depending on the backend, we may or may not get persistent cache warnings
    self.assertTrue(1 <= len(cm.output) <= expected_log_len)
    msg = cm.output[0]
    self.assertIn("tracing context doesn't match", msg)

  @jtu.thread_unsafe_test()  # logging is not thread-safe
  def test_arg_names_cache_miss_explanations_new_function_in_loop(self):
    @jax.jit
    def f(x, y):
      return jnp.sin(x) * y['hi']

    x = jnp.float32(1.)

    with config.explain_cache_misses(True):
      with self.assertLogs(level='WARNING') as cm:
        for _ in range(2):
          jax.jit(lambda x: 2 * x)(3)
    if is_persistent_cache_enabled():
      # number of warnings depends on the backend
      self.assertTrue(4 <= len(cm.output) <= 6)
      msg = cm.output[3]
      self.assertIn('another function defined on the same line', msg)
    else:
      self.assertLen(cm.output, 2)
      _, msg = cm.output
      self.assertIn('another function defined on the same line', msg)

  @jtu.thread_unsafe_test()  # logging is not thread-safe
  @unittest.skip("Test fails, probably due to caching")
  def test_arg_names_cache_miss_explanations_unpacks_transforms(self):
    # Tests that the explain_tracing_cache_miss() function does not throw an
    # error when unpacking `transforms` with a length greater than 3.
    @jax.jit
    def f(key):
      return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32)

    with config.explain_cache_misses(True):
      with self.assertLogs(level="WARNING") as cm:
        f(jax.random.key(seed=123))

    if is_persistent_cache_enabled():
      # 5 warnings from tracing cache, 5-10 from persistent cache depending on
      # the backend
      self.assertTrue(10 <= len(cm.output) <= 15)
      self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output))
    else:
      self.assertLen(cm.output, 5)
      for msg in cm.output:
        self.assertIn("TRACING CACHE MISS", msg)

  def test_arg_names_cache_miss_explanations_no_source_info(self):
    # ``operator.add`` is a built-in function and does not have source info.
    with config.explain_cache_misses(True):
      jax.jit(operator.add)(42, 24)

  def test_concrete_error_because_arg_unary(self):
    @jax.jit
    def f(x):
      if x > 0:
        return x
      else:
        return 0

    msg = r"on the value of the argument x"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1)

  def test_concrete_error_because_arg_binary(self):
    @jax.jit
    def f(x, y):
      if x > y:
        return x
      else:
        return y

    msg = r"on the values of the arguments x and y"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, 2)

  def test_concrete_error_because_arg_ternary(self):
    @jax.jit
    def f(x, y, z):
      if x > z:
        return x
      else:
        return y

    msg = r"on the values of the arguments x and z"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, 2, 3)

    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, 2, z=3)

    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, y=2, z=3)

  def test_concrete_error_because_arg_varargs(self):
    @jax.jit
    def f(*args):
      x, y, z = args
      if x > z:
        return x
      else:
        return y

    msg = r"on the values of the arguments args"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, 2, 3)

  def test_concrete_error_because_arg_kwargs(self):
    @jax.jit
    def f(**kwargs):
      x, y, z = kwargs['x'], kwargs['y'], kwargs['z']
      if x > z:
        return x
      else:
        return y

    msg = r"on the values of the arguments kwargs"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(x=1, y=2, z=3)

  def test_concrete_error_because_arg_pytree(self):
    @jax.jit
    def f(xy, z):
      x, y = xy
      if x > 0:
        return x
      else:
        return y

    msg = r"on the value of the argument xy"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f((1, 2), z=3)

  def test_concrete_error_because_const(self):
    @jax.jit
    def f():
      assert jnp.add(1, 1) > 0

    msg = "on these lines"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f()

  def test_concrete_error_because_const_2(self):
    @jax.jit
    def f():
      result = sum(jnp.add(1, 1) for _ in range(6))
      assert result > 0

    msg = "Additional originating lines are not shown."
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f()

  def test_concrete_error_with_nested_call(self):
    @jax.jit
    def f(x, y):
      if y:
        return x

    @jax.jit
    def g(x):
      return f(x, True)

    msg = r"on the value of the argument y"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      g(1)

  def test_remat_concrete_error(self):
    @jax.remat  # no static_argnums or concrete
    def g(x):
      if x > 0:
        return lax.sin(x)
      else:
        return lax.cos(x)

    with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
      g(3.)

    @functools.partial(jax.remat, static_argnums=(0,))  # using static_argnums but...
    def g(x):
      if x > 0:  # jnp operations still get staged!
        return lax.sin(x)
      else:
        return lax.cos(x)

    with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
      g(jnp.array(3.))

    # But don't raise an error mentioning static_argnums here:
    @jax.remat
    def g(x):
      jax.jit(lambda: 0 if jnp.add(1, 1) else 0)()
      return lax.sin(x)

    try:
      g(jnp.array(3.))
    except core.ConcretizationTypeError as e:
      msg = str(e)
      self.assertNotIn('static_argnums', msg)

  def test_simple_jit(self):
    tracer_spy = TracerSpy()
    def my_f(x_dict, y):
      tracer_spy.append(x_dict["a"])
      return dict(c=x_dict["a"] + x_dict["b"], d=y)

    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        dict(a=1, b=2), 3,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=result['c'],result['d']"
        ],
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, from x_dict['a']",
        ])

  def test_jit_with_static_argnums(self):
    tracer_spy = TracerSpy()
    @functools.partial(jax.jit, static_argnums=(1,))
    def my_f(a, d):
      tracer_spy.append(a)
      return a

    def my_g(b, d=1):
      tracer_spy.append(b)
      return my_f(b, d)

    self._check_tracers_and_jaxprs(
        jax.jit(my_g),
        3,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            # TODO(necula): result_paths?
            "traced_for=jit, fun=my_f, arg_names=a, result_paths=",
            "traced_for=jit, fun=my_g, arg_names=b, result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=my_g, arg_names=b, from b",
            "traced_for=jit, fun=my_f, arg_names=a, from a",
        ])

  def test_jit_arg_names(self):
    tracer_spy = TracerSpy()
    def f(x, y, *args, **kwargs):
      # args[0] is dead
      tracer_spy.append(kwargs["w"])
      tracer_spy.append(args[0])
      return y['hi'] + args[1] + sum(kwargs.values())

    self._check_tracers_and_jaxprs(
        jax.jit(f),
        {"ho": 1.}, {"hi": 2.}, 3., 4., z=5., w=6.,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=result",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from kwargs['w']",
            "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from args[0]",
        ],
        expected_lowering_lines=[
            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"y\['hi'\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['w'\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(\"kwargs\['z'\]\"\)"),
            re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
        ]
    )

  def test_jit_arg_names_static_argnums(self):
    tracer_spy = TracerSpy()
    def my_f(x, y, z, *args, **kwargs):
      # z and args[0] and kwargs["w"] are dead
      # x and args[2] are static
      tracer_spy.append(kwargs["w"])
      tracer_spy.append(x[0])
      return x[0] + y["hi"] + args[1] + args[2] + kwargs["t"]

    self._check_tracers_and_jaxprs(
        jax.jit(my_f, static_argnums=(0, 5)),
        (1.,), {"hi": 2.}, 3., 4., 5., 6.,  # x, y, z, args[0], args[1], args[2]
        t=11., w=12.,  # kwargs
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=result",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']",
        ],
        expected_lowering_lines=[
            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"y\['hi'\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['t'\]\"\)"),
            re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
        ])

  def test_jit_arg_names_static_argnames(self):
    tracer_spy = TracerSpy()
    def f(x, y, *args, **kwargs):
      # x and args[0] and kwargs["z"] are dead
      # kwargs[a] is static
      tracer_spy.append(x[0])
      return y['hi'] + args[1] + kwargs['a'] + kwargs['b'] + kwargs['w']

    self._check_tracers_and_jaxprs(
        jax.jit(f, static_argnames=("a",)),
        (1.,), {'hi': 2.}, 3., 4.,  # x, y, args[0], args[1]
        z=5., w=6., a=7., b=8.,  # kwargs
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=result",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], from x[0]",
        ],
        expected_lowering_lines=[
            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"y\['hi'\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['b'\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(\"kwargs\['w'\]\"\)"),
            re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
        ])

  def test_jit_result_info(self):
    def f(x, y, z):
      return {'a': x, 'b': [y]}
    self._check_tracers_and_jaxprs(
        jax.jit(f),
        1., (2.,), [3.],
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=result['a'],result['b'][0][0]",
        ],
        expected_lowering_lines=[
            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"x\"\)"),
            re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"y\[0\]\"\)"),
            re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['a'\]\"\}"),
            re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['b'\]\[0\]\[0\]\"\}"),
        ])

  def test_nested_jit(self):
    tracer_spy = TracerSpy()
    def my_f(x, y):
      tracer_spy.append(x)

      def my_g(u, v):
        tracer_spy.append(u)
        return dict(c=u * v, d=v)

      return jax.jit(my_g)(y, x)["c"]

    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        2, 3,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
            "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']"
        ],
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x,y, from x",
            "traced_for=jit, fun=my_g, arg_names=u,v, from u"
        ])

  def test_nested_jit_with_const_and_unused_args(self):
    def my_f(x, y):  # y is unused
      def my_g(u, v):  # v is unused
        return v + np.ones(v.shape, v.dtype)

      return x + jax.jit(my_g)(y, x)

    x = y = np.ones((8,), dtype=np.float32)
    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        x, y,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
            "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result",
        ],
        expected_lowering_lines=[
            re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> loc\(\"x\"\)\)"),
            re.compile(r".*call @my_g\(%arg.\) : \(tensor<8xf..>\)"),
        ]
    )

  def test_jvp_of_jit(self):
    tracer_spy = TracerSpy()
    def f(x, y, z):
      tracer_spy.append(x)
      return {'a': x, 'b': [y]}
    self._check_tracers_and_jaxprs(
        lambda x, y, z: jax.jvp(jax.jit(f), (x, y, z), (x, y, z)),
        jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
        expected_jaxpr_debug_infos=[
            # TODO(necula): arg_names, result_paths
            "traced_for=jit, fun=f, arg_names=,,,, result_paths=,,,",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=f, arg_names=x,y[0],z[0], from x",
        ],
        expected_lowering_lines=[
            # TODO(necula): missing arg_names
            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(unknown\)"),
            re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(unknown\)"),
            re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(unknown\)"),
            re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(unknown\)"),
            # TODO(necula): missing result names
            re.compile(r".*func.func public @main\(.*-> .*tensor<f..> {jax.result_info = \"\"}"),
        ])

  def test_vjp_of_jit(self):
    # TODO(b/398208230): Re-enable this test after fixing.
    self.skipTest("Enable this after figuring out why it's failing")
    tracer_spy = TracerSpy()
    def my_f(x, y, z):
      tracer_spy.append(y[0])
      return {'a': x * y[0], 'b': [y]}
    self._check_tracers_and_jaxprs(
        lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])),
        jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=result",
            re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths=result"),
            # TODO(necula): arg_names? result_paths?
            "traced_for=jit, fun=my_f, arg_names=,,,, result_paths=['a'],['b'][0][0]",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x,y[0],z[0], from y[0]",
        ],
        expected_lowering_lines=[
            # TODO(necula): missing arg_names
            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(unknown\)"),
            re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(unknown\)"),
            # TODO(necula): result_paths?
            re.compile(r".*func.func public @main\(.*-> \(tensor<f..> {jax.result_info = \"\"}"),
        ])

  def test_vjp_of_nested_jit(self):
    tracer_spy = TracerSpy()
    def my_f(x, y):
      tracer_spy.append(x)

      def my_g(u, v):
        tracer_spy.append(u)
        return dict(c=u * v, d=v)

      return jax.jit(my_g)(y, x)["c"]

    self._check_tracers_and_jaxprs(
        jax.jit(lambda x, y, res_ct: jax.vjp(my_f, x, y)[1](res_ct)),
        2., 3., 0.3,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=result[0],result[1]",
            # TODO(necula): result_paths
            "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
            # TODO(necula): arg_names
            "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=,"
            if config.use_direct_linearize.value else
            "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']",
        ],
        expected_tracer_debug_infos=[
            # TODO(necula): missing debug info
            "None",
            "traced_for=jit, fun=my_g, arg_names=u,v, from u"
        ],
        expected_lowering_lines=[
            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"x\"\)"),
            re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"y\"\)"),
            re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"res_ct\"\)"),
            re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[0\]\"}"),
            re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[1\]\"}"),
        ])

  def test_vjp_remat(self):
    tracer_spy = TracerSpy()
    def apply_fn(inp):
      tracer_spy.append(inp)
      def to_remat(x):
        tracer_spy.append(x)
        return jax.nn.relu(x * x)
      fn = jax.checkpoint(to_remat)
      return jax.vjp(fn, inp)

    self._check_tracers_and_jaxprs(
        jax.jit(apply_fn),
        2.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            # TODO(necula): what are these flat_index components?
            "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=result[0],result[1][0][0][0][0][0]",
            re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"),
            re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"),
        ],
        expected_tracer_debug_infos=[
            "traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x",
            "traced_for=jit, fun=apply_fn, arg_names=inp, from inp",
        ])

  def test_custom_jvp(self):
    tracer_spy = TracerSpy()
    @jax.custom_jvp
    def my_fun(x, y, c=1.):
      tracer_spy.append(y)
      return c * (x + y)
    def my_jvp(primals, tangents):
      x, y, c = primals
      t_x, t_y, t_c = tangents
      tracer_spy.append(t_y)
      return my_fun(x, y, c), t_c
    my_fun.defjvp(my_jvp)

    def top_f(x, y):
      return jnp.square(my_fun(x, y, c=2.)).sum()

    self._check_tracers_and_jaxprs(
        jax.jit(lambda a: jax.jvp(top_f, (a, a),
                                  (jnp.ones_like(a), jnp.ones_like(a)))),
        42.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=a, result_paths=result[0],result[1]",
            "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=result",
        ],
        expected_tracer_debug_infos=[
            # TODO(necula): from None?
            "traced_for=jit, fun=<lambda>, arg_names=a, from None",
            "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y",
        ])

  def test_custom_jvp_nondiff_args(self):
    tracer_spy = TracerSpy()
    def top_f(xy):
      tracer_spy.append(xy[0])
      @functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
      def my_g(h, xy):
        x, y = xy
        tracer_spy.append(x)
        return h(x)
      @my_g.defjvp
      def my_g_jvp(h, primals, tangents):
        (x, y), = primals
        (xt, yt), = tangents
        tracer_spy.append(xt)
        return my_g(h, (x, y)), 2. * xt
      h = lambda y: xy[0] + y  # capture x
      return my_g(h, xy)

    self._check_tracers_and_jaxprs(
        jax.jit(lambda a, b: jax.jvp(top_f, ((a, b),),
                                     ((jnp.ones_like(a), jnp.ones_like(b)),))),
        42., 43.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            # TODO(necula): arg_names
            "traced_for=jit, fun=<lambda>, arg_names=,a,b, result_paths=result[0],result[1]",
            "traced_for=custom_jvp fun, fun=my_g, arg_names=,xy[0],xy[1], result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]",
            # TODO(necula): from None
            "traced_for=jit, fun=<lambda>, arg_names=a,b, from None",
            "None",  # TODO(necula): None
        ])

  def test_custom_vjp(self):
    tracer_spy = TracerSpy()
    @jax.custom_vjp
    def my_f(x):
      tracer_spy.append(x["a"])
      return {"b": jnp.sin(x["a"])}
    def my_f_fwd(x):
      tracer_spy.append(x["a"])
      return my_f(x), {"r": jnp.cos(x["a"])}
    def my_f_bwd(res, g):
      tracer_spy.append(g["b"])
      cos_x = res["r"]
      return ({"a": 2 * cos_x * g["b"]},)
    my_f.defvjp(my_f_fwd, my_f_bwd)

    def to_diff(x):
      return my_f(x)["b"]

    self._check_tracers_and_jaxprs(
        jax.jit(jax.grad(to_diff)),
        {"a" : 3.},
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=result['a']",
            "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=result['b']",
        ],
        expected_tracer_debug_infos=[
            "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']",
            # TODO(necula): from None?
            "traced_for=jit, fun=to_diff, arg_names=x['a'], from None",
            "traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']",
        ])

  def test_custom_vjp_nondiff_args(self):
    tracer_spy = TracerSpy()
    @functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
    def app(f, xy):
      tracer_spy.append(xy[0])
      return f(xy)
    def app_fwd(f, xy):
      tracer_spy.append(xy[0])
      return app(f, xy), jnp.cos(xy[0])
    def app_rev(f, cos_x0, g):
      tracer_spy.append(cos_x0)
      tracer_spy.append(g)
      return ((cos_x0 * g, cos_x0),)
    app.defvjp(app_fwd, app_rev)

    self._check_tracers_and_jaxprs(
        jax.jit(jax.grad(lambda xy: app(lambda x: 2 * x[0], xy))),
        (3., 3.),
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], result_paths=result[0],result[1]",
            "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], from xy[0]",
            "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]",
            # TODO(necula): from None
            "traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], from None",
        ])

  def test_custom_transpose(self):
    # Helpers from api_test.py
    class _custom_transpose:
      def __init__(self, out_types, fun):
        self.out_types = out_types
        self.fun = custom_transpose.custom_transpose(fun)

      def __getattr__(self, name):
        return getattr(self.fun, name)

      def __call__(self, *args):
        return self.fun(self.out_types, *args)

    def custom_transpose_with_example_out(example_out):
      return functools.partial(
          _custom_transpose,
          jax.tree.map(
              lambda x: core.get_aval(x).to_tangent_aval(), example_out))

    tracer_spy = TracerSpy()
    def my_f_with_cond(i, x):
      def my_f(x):
        tracer_spy.append(x)
        @custom_transpose_with_example_out(jnp.ones(2))
        def fn(r, x):
          tracer_spy.append(r)
          tracer_spy.append(x["c"])
          return dict(b=x["c"] / r)

        @fn.def_transpose
        def fn_tp(r, t):
          tracer_spy.append(r)
          return dict(c=2 * t / r)

        return x["c"] + fn(jnp.ones(2) * 3., x)

      return lax.cond(i > 0, my_f, lambda x: x["c"], dict(c=x))

    x = jnp.ones(2) * 6.
    self._check_tracers_and_jaxprs(
        jax.jit(lambda x: jax.linear_transpose(functools.partial(my_f_with_cond, 7.),
                                               x)),
        x,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=result",
            "traced_for=cond, fun=<lambda>, arg_names=x['c'], result_paths=result",
            "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0][0][0],result[0][0][1]",
        ],
        expected_tracer_debug_infos=[
            "traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from r",
            "traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from x['c']",
            "traced_for=custom_transpose transpose_fun, fun=fn_tp, arg_names=r,t, from r",
        ])

  def test_linear_call(self):
    tracer_spy = TracerSpy()
    def my_f(x, y):
      tracer_spy.append(y)
      def fn(r, x):
        tracer_spy.append(r)
        tracer_spy.append(x["c"])
        return dict(b=x["c"] * r)
      def fn_tp(r, t):
        tracer_spy.append(t["b"])
        return dict(c=t["b"] * r)
      return dict(a=x["c"] + jax.custom_derivatives.linear_call(fn, fn_tp, y, x)["b"])

    f1 = lambda x: my_f(x, jnp.ones(2) * 3.)
    x = jnp.ones(2) * 6.

    self._check_tracers_and_jaxprs(
        jax.jit(lambda x: jax.linear_transpose(f1, dict(c=x))(dict(a=x))),
        x,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0]['c']",
            "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']",
            "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']",
        ],
        expected_tracer_debug_infos=[
            # TODO(necula): from None?
            "traced_for=jit, fun=<lambda>, arg_names=x, from None",
            "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], from r",
            "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], from x['c']",
            "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], from t['c']",
        ])

  def test_custom_vmap(self):
    tracer_spy = TracerSpy()
    @jax.custom_batching.custom_vmap
    def my_f(xdict):
      x = xdict["x"]
      tracer_spy.append(x)
      return dict(a=jnp.sin(x))

    @my_f.def_vmap
    def my_rule(axis_size, in_batched, xys):
      xs = xys["x"]
      tracer_spy.append(xs)
      xs_batched, = in_batched
      self.assertEqual(xs_batched["x"], True)
      self.assertEqual(axis_size, xs.shape[0])
      return dict(a=jnp.cos(xs)), dict(a=xs_batched["x"])

    xy = dict(x=np.ones((8,), dtype=np.float32), y=np.zeros((8,), dtype=np.float32))
    self._check_tracers_and_jaxprs(
        jax.jit(jax.vmap(my_f)),
        xy,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=result['a']",
        ],
        expected_tracer_debug_infos=[
            "traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']",
            "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']"
        ])

  def test_cond(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      def my_true_branch(a, b):
        tracer_spy.append(a)
        return a + b

      def my_false_branch(c, d):
        tracer_spy.append(c)
        return c - d

      return lax.cond(x >= 0, my_true_branch, my_false_branch, x, x)

    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        0,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
            "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result",
            "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a",
            "traced_for=cond, fun=my_false_branch, arg_names=c,d, from c"
        ])

  def test_switch(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      def my_branch0(x0):
        tracer_spy.append(x0)
        return x0
      def my_branch1(x1):
        tracer_spy.append(x1)
        return x1 + 1
      def my_branch2(x2):
        tracer_spy.append(x2)
        return x2 + 2
      return lax.switch(x, [my_branch0, my_branch1, my_branch2], x)

    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        2,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
            "traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=result",
            "traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=result",
            "traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=switch, fun=my_branch0, arg_names=x0, from x0",
            "traced_for=switch, fun=my_branch1, arg_names=x1, from x1",
            "traced_for=switch, fun=my_branch2, arg_names=x2, from x2"
        ])

  def test_grad_cond_with_remat(self):
    tracer_spy = TracerSpy()
    def my_f(x, y):
      # The cond branches return two things, and only the first is needed
      # in the residuals.
      def my_true_branch(a, b):
        tracer_spy.append(a)
        return (a + 1, a + b)

      def my_false_branch(c, d):
        tracer_spy.append(c)
        return (c - 1, c - d)

      def my_g(x, y):
        # x1 does not depend on y
        x1, y1 = lax.cond(x >= 0, my_true_branch, my_false_branch, x, y)
        tracer_spy.append(x1)
        return x1, y1

      x2, y2 = jax.remat(my_g)(x, y)
      return y2 + lax.sin(x2)

    self._check_tracers_and_jaxprs(
        jax.jit(jax.grad(my_f)),
        1., 2.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
            # TODO(necula): arg_names? result_paths?
            "traced_for=cond, fun=my_true_branch, arg_names=, result_paths=,",
            "traced_for=cond, fun=my_false_branch, arg_names=, result_paths=,",
            "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result[0],result[1]",
            "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result[0],result[1]",
            "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=,",
        ],
        expected_tracer_debug_infos=[
            "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a",
            "traced_for=cond, fun=my_false_branch, arg_names=c,d, from c",
            # TODO(necula): from None
            "traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None",
        ])

  def test_grad_scan(self):
    # Based on control_flow_test:testScanHigherOrderDifferentiation
    tracer_spy = TracerSpy()
    def f(c, a):
      tracer_spy.append(c)
      d = 0.75
      b = jnp.sin(c * jnp.sum(jnp.cos(d * a)))
      c = 0.9 * jnp.cos(d * jnp.sum(jnp.sin(c * a)))
      return c, b

    as_ = jnp.arange(6.).reshape((3, 2))
    c = jnp.array(1, dtype=as_.dtype)

    @jax.jit
    def my_f(x, as_):
      tracer_spy.append(x)
      def to_remat(a, b):
        return for_loop.scan(f, a, b)
      return jax.remat(to_remat)(c, as_)

    def the_grad(c, as_):
      tracer_spy.append(c)
      _, pullback = jax.vjp(my_f, c, as_)
      return pullback((c, np.arange(3, dtype=c.dtype)))

    self._check_tracers_and_jaxprs(
        jax.jit(the_grad),
        c, as_,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]",
            # TODO(necula): arg names, bad result paths
            "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,",
            "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=",
            "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,",
            "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,",
            "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=",
            "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,",
            "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,",
            "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths="
            if config.use_direct_linearize.value else
            "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=",
        ],
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=the_grad, arg_names=c,as_, from c",
            "traced_for=scan, fun=f, arg_names=c,a, from c",
            "traced_for=jit, fun=my_f, arg_names=x,as_, from x",
            # TODO(necula): arg_names, and "from x"
            "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], from refs[0]",
        ],
        expected_lowering_lines=[
            re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"c\"\)"),
            re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"),
            re.compile(r".*func.func public @main\(.* -> .*tensor<f..> {jax.result_info = \"result\[0\]\""),
            re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""),
            # TODO(necula): unnamed function?
            re.compile(r".*func.func private @None"),
        ])

  def test_while_loop(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      def my_cond(a):
        tracer_spy.append(a)
        return a <= 8

      def my_body(b):
        tracer_spy.append(b)
        return b + 1

      return lax.while_loop(my_cond, my_body, x)

    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        0,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
            "traced_for=while_body, fun=my_body, arg_names=b, result_paths=result",
            "traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=while_cond, fun=my_cond, arg_names=a, from a",
            "traced_for=while_body, fun=my_body, arg_names=b, from b",
        ])

  def test_fori(self):
    # See https://github.com/jax-ml/jax/issues/23637
    tracer_spy = TracerSpy()
    def my_body(_, c):
      tracer_spy.append(c)
      return 0.

    self._check_tracers_and_jaxprs(
        jax.jit(lambda x: jax.lax.fori_loop(0, 5, my_body, x)),
        3.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
            # TODO(necula): bad arg_names, result_paths
            "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=result[0][0],result[0][1]",
        ],
        expected_tracer_debug_infos=[
            # TODO(necula): the arg_names are not right
            "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], from loop_carry[1]",
        ]
    )

    tracer_spy = TracerSpy()

    # When the ubound is not a constant, we use a while_loop
    self._check_tracers_and_jaxprs(
        jax.jit(lambda ub, x: jax.lax.fori_loop(0, ub, my_body, x)),
        5, 3.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=ub,x, result_paths=result",
            re.compile(r"traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths="),
            # TODO(necula): arg_names and result_paths are not right
            "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=result[0],result[1],result[2]",
        ],
        expected_tracer_debug_infos=[
            # TODO(necula): the arg_names are not right
            "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], from loop_carry[2]",
        ])

  def test_scan(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      def my_scan_body(carry, inp):
        tracer_spy.append(carry)
        return (carry + inp, carry)

      return lax.scan(my_scan_body, 0, x)

    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        np.arange(8, dtype=np.int32),
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=result[0],result[1]",
            "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=result[0],result[1]",
        ],
        expected_tracer_debug_infos=[
            "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry"
        ])

  def test_eval_shape(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      tracer_spy.append(x)
      return x

    _ = jax.eval_shape(my_f, 0)
    self._check_tracers_and_jaxprs(
        lambda: jax.eval_shape(my_f, 0),
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[],
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, from x"],
    )

  def test_vmap_of_nested_jit(self):
    tracer_spy = TracerSpy()
    def my_f(x, y):
      tracer_spy.append(x)

      def my_g(u, v):
        tracer_spy.append(u)
        return dict(c=u * v, d=v)

      return jax.jit(my_g)(y, x)["c"]

    self._check_tracers_and_jaxprs(
        jax.jit(jax.vmap(my_f)),
        np.ones((8,), dtype=np.float32), np.zeros((8,), dtype=np.float32),
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
            "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']",
        ],
        expected_tracer_debug_infos=[
            # TODO(necula): missing debug info
            "None",
            "traced_for=jit, fun=my_g, arg_names=u,v, from u"
        ])

  def test_pmap(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      tracer_spy.append(x)
      return jnp.sin(x)

    self._check_tracers_and_jaxprs(
        jax.pmap(my_f),
        np.ones((jax.device_count(),), dtype=np.float32),
        expected_jaxpr_debug_infos=[
            "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result"
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            "traced_for=pmap, fun=my_f, arg_names=x, from x"
        ],
    )

  def test_pmap_with_arg_and_result_names(self):
    tracer_spy = TracerSpy()
    x = np.ones((jax.device_count(),), dtype=np.float32)
    def my_f(x, y, *args, a, **kwargs):
      # y and kwargs[c] is dead
      tracer_spy.append(args[1])
      s = x + a + args[1] + kwargs["d"]
      return dict(u=s, v=x)

    self._check_tracers_and_jaxprs(
        jax.pmap(my_f, static_broadcasted_argnums=(0,)),
        1., x, x, x,  # x, y, args[0], args[1]
        d=x, a=x, b=x,  # kwargs
        expected_jaxpr_debug_infos=[
            "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]",
        ],
        expected_lowering_lines=[
            # TODO(necula): we did not DCE y?
            re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"),
            re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"),
            re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"),
            re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"),
            re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"),
            re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"),
        ]
    )

  def test_pmap_of_grad(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      tracer_spy.append(x)
      return jnp.sin(x)

    self._check_tracers_and_jaxprs(
        jax.pmap(jax.grad(my_f)),
        np.ones((jax.device_count(),), dtype=np.float32),
        expected_jaxpr_debug_infos=[
            "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            # TODO(necula): missing debug_info
            'None'
        ],
    )

  def test_jvp_pmap_eager(self):
    tracer_spy = TracerSpy()
    def my_f(x, y, *args):
      # y is dead, x is static broadcasted
      tracer_spy.append(args[1])
      s = x + args[1]
      return dict(u=s, v=x)

    x = jnp.ones((jax.device_count(), 1), dtype=np.float32)
    x_tan = jnp.full_like(x, .1)

    self._check_tracers_and_jaxprs(
        lambda x, x_tan: jax.jvp(jax.pmap(my_f),
                                 (x, x, x, x), (x_tan, x_tan, x_tan, x_tan)),
        x, x_tan,
        expected_jaxpr_debug_infos=[
            # TODO(necula): why this?
            re.compile(r'traced_for=jit, fun=_multi_slice at .*array_methods.py:.*, arg_names=self, result_paths=.*'),
            "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=result['u'],result['v']",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            # TODO(necula): missing debug_info
            "None"
        ],
    )

  @jtu.ignore_warning(category=UserWarning,
                      message=".* jitted function .* includes a pmap")
  def test_jvp_pmap(self):
    tracer_spy = TracerSpy()
    def my_f(x, y):
      tracer_spy.append(x)
      return jnp.sin(x) + y

    x = np.ones((jax.device_count(), 1), dtype=np.float32)
    x_tan = np.full_like(x, .1)

    self._check_tracers_and_jaxprs(
        jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))),
        x, x_tan,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=result[0],result[1]",
            "traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=result",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            # TODO(necula): missing debug_info
            "None"
        ],
    )

  def test_hessian(self):
    tracer_spy = TracerSpy()

    def my_f(x):
      tracer_spy.append(x)
      return jnp.square(x).mean()

    x = jax.random.uniform(jax.random.key(0), shape=(8, 4))

    self._check_tracers_and_jaxprs(
        jax.jit(jax.hessian(jax.jit(my_f))),
        x,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
            # TODO(necula): arg_names and result_paths?
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,",
            "traced_for=jit, fun=my_f, arg_names=x,, result_paths=,"
            if config.use_direct_linearize.value else
            "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
        ],
        tracer_spy=tracer_spy,
        expected_tracer_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, from x",
        ],
    )

    x.block_until_ready()
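
  # Debug info for a jax.remat-decorated inner function, both for the jaxpr
  # and for the tracer captured inside it.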
  def test_remat(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      @jax.remat
      def my_g(y):
        tracer_spy.append(y)
        return lax.sin(y)

      return my_g(x)

    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        0.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
            # TODO(necula): missing result_paths
            "traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y"
        ])
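
  # Like test_remat, but differentiated; the remat jaxpr currently loses
  # arg_names and result_paths (see TODO below).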
  def test_grad_remat(self):
    tracer_spy = TracerSpy()
    def my_f(x):
      @jax.remat
      def my_g(y):
        tracer_spy.append(y)
        return lax.sin(lax.sin(y))

      return my_g(my_g(x))

    self._check_tracers_and_jaxprs(
        jax.jit(jax.grad(my_f)),
        0.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
            # TODO(necula): arg_names? result_paths?
            "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=",
        ],
        expected_tracer_debug_infos=[
            "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y",
        ])
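
  # remat composed with shard_map under jit(grad(...)); needs two devices.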
  def test_remat_shard_map(self):
    tracer_spy = TracerSpy()
    if len(jax.devices()) < 2:
      self.skipTest("requires at least 2 devices")
    # this tests remat-of-shmap
    mesh = Mesh(np.array(jax.devices()[:2]), ('x',))

    # check param updating is handled
    @jax.remat
    @functools.partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def my_f(x):
      tracer_spy.append(x)
      return jnp.sin(jnp.sin(x))

    self._check_tracers_and_jaxprs(
        jax.jit(jax.grad(lambda x: my_f(x).sum())),
        jnp.arange(2, dtype=np.float32),
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            # TODO(necula): arg_names, result_paths
            "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
            "traced_for=checkpoint / remat, fun=my_f, arg_names=,, result_paths=",
            "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=result",
            "traced_for=shard_map, fun=my_f, arg_names=,, result_paths=",
        ],
        expected_tracer_debug_infos=[
            "None"  # TODO(necula): missing
        ])
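
  # saved_residuals reports where each residual comes from, including names
  # given via ad_checkpoint.checkpoint_name.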
  def test_remat_saved_residuals(self):
    @functools.partial(jax.remat,
                       static_argnums=(1,),
                       policy=lambda p, *_, **__: "mul" in str(p))
    def my_f(x, y):
      x = ad_checkpoint.checkpoint_name(x * x, "foo")
      x = x * x
      return x + y

    res = saved_residuals(my_f, 3., 4.)
    self.assertEqual(res[0][1], "from the argument x")
    self.assertRegex(res[1][1], r"named 'foo' from .*debug_info_test.py:.*my_f")
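
  # checkify around a pmapped function; some jaxpr debug info still points
  # into JAX internals (see TODO below).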
  def test_checkify_pmap_basic(self):
    if len(jax.devices()) < 2:
      self.skipTest("requires at least 2 devices")
    tracer_spy = TracerSpy()
    @jax.pmap
    def my_f(my_x):
      tracer_spy.append(my_x)
      y1 = jnp.sin(1./my_x)
      y2 = jnp.sin(my_x)
      return (y1 + y2,)

    self._check_tracers_and_jaxprs(
        jax.jit(checkify.checkify(my_f, errors=checkify.nan_checks)),
        np.arange(len(jax.devices()), dtype=np.float32),
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            # TODO(necula): this should not be pointing into the JAX internals
            re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"),
            re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths=result"),
            "traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=result[0]",
        ],
        expected_tracer_debug_infos=[
            "traced_for=pmap, fun=my_f, arg_names=my_x, from my_x",
        ],
        check_lowering=False,  # TODO(necula): warning during lowering
    )
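
  # custom_dce with static_argnums: the recorded debug info names only the
  # non-static argument x.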
  def test_custom_dce_static_argnums(self):
    tracer_spy = TracerSpy()
    @functools.partial(jax.experimental.custom_dce.custom_dce,
                       static_argnums=(0,))
    def my_g(f, x):
      tracer_spy.append(x)
      return f(x), 10 * f(x)

    @my_g.def_dce
    def my_g_dce(f, used_outs, x):  # note: static_argnums are always passed first
      tracer_spy.append(x)
      self.assertTrue(callable(f))
      return [2 * v if used else None
              for used, v in zip(used_outs, my_g(f, x))]

    def my_f(x):
      return jnp.exp(x)

    self._check_tracers_and_jaxprs(
        jax.jit(lambda x: my_g(my_f, x)[0]),
        0.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
            "traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=result[0],result[1]",
        ],
        expected_tracer_debug_infos=[
            # TODO(necula): no leaked tracer from my_g_dce?
            "traced_for=custom_dce, fun=my_g, arg_names=x, from x",
        ])
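
  # custom_dce whose primal closes over numpy constants; the hoisted consts
  # currently show up as an unnamed leading argument (see TODO below).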
  def test_custom_dce_consts(self):
    tracer_spy = TracerSpy()
    @jax.experimental.custom_dce.custom_dce
    def my_f(x):
      tracer_spy.append(x)
      return np.eye(1) * jnp.sin(x), jnp.cos(x)

    @my_f.def_dce
    def my_rule(used_outs, y):
      tracer_spy.append(y)
      return (
          np.full((1, 1), 2.0) * jnp.exp(y) if used_outs[0] else None,
          jnp.sqrt(y) if used_outs[1] else None,
      )

    self._check_tracers_and_jaxprs(
        jax.jit(lambda x: my_f(x)[0]),
        np.array(1.1234),
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
            # TODO(necula): bad arg_names (why None), bad result_paths
            'traced_for=custom_dce, fun=my_f, arg_names=,x, result_paths=result[0],result[1]',
        ],
        expected_tracer_debug_infos=[
            # TODO(necula): no leaked tracer from my_rule?
            "traced_for=custom_dce, fun=my_f, arg_names=x, from x",
        ])
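
  # lax.custom_linear_solve under jvp: debug info for the user-supplied
  # solve/transpose_solve callables and for internal lambdas in linalg.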
  def test_custom_linear_solve_complex(self):
    tracer_spy = TracerSpy()
    def solve(a, b):
      tracer_spy.append(a)
      def my_solve(matvec, x):
        tracer_spy.append(x)
        return jsp.linalg.solve(a, x)

      def my_high_precision_dot(a, b):
        tracer_spy.append(a)
        return lax.dot(a, b, precision=lax.Precision.HIGHEST)

      def my_tr_solve(matvec, x):
        tracer_spy.append(x)
        return jsp.linalg.solve(a.T, x)
      matvec = functools.partial(my_high_precision_dot, a)
      return lax.custom_linear_solve(matvec, b, my_solve, my_tr_solve)

    rng = self.rng()
    a = 0.5 * rng.randn(2, 2) + 0.5j * rng.randn(2, 2)
    b = 0.5 * rng.randn(2) + 0.5j * rng.randn(2)

    self._check_tracers_and_jaxprs(
        jax.jit(lambda a, b: jax.jvp(solve, (a, b), (a, b))),
        a, b,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=a,b, result_paths=result[0],result[1]",
            re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths=result"),
            re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths=result"),
            re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths=result"),
            # TODO(necula): why pointers to internal functions, arg_names, result_paths?
            re.compile(r'traced_for=custom_linear_solve solve, fun=<lambda> at .*linalg.py:.*, arg_names=,,x, result_paths='),
            re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,,x, result_paths='),
            re.compile(r'traced_for=custom_linear_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,x, result_paths='),
            re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,x, result_paths='),
            "traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=,b, result_paths=result",
            "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=,x, result_paths=result",
            "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=,x, result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x, from x",
            "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x, from x",
            "None",  # TODO(necula): some debug info is missing
        ])
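
  # lax.custom_root under jvp: tracers from my_f, my_solve, and the tangent
  # solve are inspected; some provenance is still missing (see TODOs).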
  def test_custom_root_errors(self):
    tracer_spy = TracerSpy()
    def dummy_root_usage(x):
      tracer_spy.append(x)
      def my_f(x):
        tracer_spy.append(x)
        return x - 3.
      def my_solve(f, x):
        tracer_spy.append(x)
        return x
      def my_transpose_solve(f, x):
        tracer_spy.append(x)
        return x
      return lax.custom_root(my_f, 0., my_solve, my_transpose_solve)

    self._check_tracers_and_jaxprs(
        jax.jit(lambda x: jax.jvp(dummy_root_usage, (x,), (0.0,))),
        0.,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0],result[1]",
            # TODO(necula): internal function?
            re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=result\[0\]"),
        ],
        expected_tracer_debug_infos=[
            "traced_for=custom_root, fun=my_f, arg_names=x, from x",
            "traced_for=custom_root solve, fun=my_solve, arg_names=x, from x",
            # TODO(necula): from None
            "traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x, from None",
            "None",  # TODO(necula): some debug info is missing
        ])
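
  # pallas_call: debug info for the kernel (labeled with the pallas_call
  # name=...) and for the user-provided index_map.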
  def test_pallas_call(self):
    tracer_spy = TracerSpy()
    def my_kernel(x_ref, y_ref, o_ref):
      tracer_spy.append(x_ref)
      o_ref[...] = x_ref[...] + y_ref[...]

    x = np.arange(256 * 16, dtype=np.float32).reshape((256, 16))

    def my_f(x):
      def my_index_map(i, j):
        tracer_spy.append(i)
        return (i, j)

      return pl.pallas_call(my_kernel,
                            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                            in_specs=(pl.BlockSpec((128, 8), my_index_map),
                                      pl.BlockSpec((128, 8), my_index_map)),
                            out_specs=pl.BlockSpec((128, 8), my_index_map),
                            grid=(pl.cdiv(x.shape[0], 128), pl.cdiv(x.shape[1], 8)),
                            name="my_custom_kernel_name")(x, x)

    self._check_tracers_and_jaxprs(
        jax.jit(my_f), x,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
            "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=result[0],result[1]",
            # TODO(necula): result_paths?
            "traced_for=pallas_call kernel, fun=my_custom_kernel_name, arg_names=x_ref,y_ref,o_ref, result_paths=",
        ],
        expected_tracer_debug_infos=[
            "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, from i",
            "traced_for=pallas_call kernel, fun=my_custom_kernel_name, arg_names=x_ref,y_ref,o_ref, from x_ref",
        ],
        check_lowering=False,  # We need interpret mode on CPU. TODO(necula)
    )
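
  # checkify around a pallas_call: the checked kernel's debug info currently
  # points into pallas internals (see TODOs below).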
  def test_checkify_pallas_call(self):
    tracer_spy = TracerSpy()
    def kernel(x_ref, y_ref):
      tracer_spy.append(x_ref)
      y_ref[...] = jnp.log(x_ref[...])

    def my_f(input):
      out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
      pallas_call = pl.pallas_call(kernel,
                                   out_shape=out_shape,
                                   name="my_custom_kernel_name")
      checked_call = checkify.checkify(pallas_call,
                                       errors=checkify.nan_checks)
      return checked_call(input)[1]

    self._check_tracers_and_jaxprs(
        jax.jit(my_f),
        jnp.arange(4, dtype=jnp.float32) - 2,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_f, arg_names=input, result_paths=result",
            # TODO(necula): function source location points into JAX internals
            # TODO(necula): arg_names and result_paths are wrong
            re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*.pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="),
            re.compile(r"traced_for=pallas_call index_map, fun=default_index_map at .*.pallas.core.py:.*, arg_names=, result_paths=result\[0\].*"),
        ],
        expected_tracer_debug_infos=[
            "traced_for=pallas_call kernel, fun=my_custom_kernel_name, arg_names=x_ref,y_ref, from x_ref",
        ],
        check_lowering=False,  # We need interpret mode on CPU. TODO(necula)
    )
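
  # lax.composite: the composite body records the same debug info as the jit
  # wrapper around it.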
  def test_composite(self):
    tracer_spy = TracerSpy()
    scale = np.array([0.5, 0.4, 0.3], dtype=np.float32)
    @functools.partial(lax.composite, name="my.consts")
    def my_consts(x):
      tracer_spy.append(x)
      return x / scale

    x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

    self._check_tracers_and_jaxprs(
        jax.jit(my_consts), x,
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            "traced_for=jit, fun=my_consts, arg_names=x, result_paths=result",
            "traced_for=composite, fun=my_consts, arg_names=x, result_paths=result",
        ],
        expected_tracer_debug_infos=[
            "traced_for=composite, fun=my_consts, arg_names=x, from x"])


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())