Merge pull request #25988 from gnecula:debug_info_tests

PiperOrigin-RevId: 717592689
This commit is contained in:
jax authors 2025-01-20 11:57:35 -08:00
commit e41f4caa3e
5 changed files with 478 additions and 287 deletions

View File

@ -38,6 +38,12 @@ jax_multiplatform_test(
shard_count = 10,
)
jax_multiplatform_test(
name = "debug_info_test",
srcs = ["debug_info_test.py"],
enable_configs = ["tpu_v3_2x2"],
)
jax_multiplatform_test(
name = "device_test",
srcs = ["device_test.py"],

View File

@ -1322,75 +1322,6 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertIsNotNone(f.runtime_executable())
self.assertIsNotNone(g.runtime_executable())
def test_jit_lower_arg_info(self):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)
@parameterized.parameters([0, 2, [(0, 2)]])
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.jit(f, static_argnums=static_argnums).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)
@parameterized.parameters(['a', 'b', [('a', 'b')]])
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']
lowered = jax.jit(f, static_argnames=static_argnames).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
self.assertNotIn("kwargs['a']", hlo_str)
self.assertNotIn("kwargs['b']", hlo_str)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in (
"\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']",
"kwargs['w']", "kwargs['a']", "kwargs['b']"
):
self.assertNotIn(s, hlo_str)
def test_jit_lower_result_info(self):
def f(x, y, z):
return {'a': x, 'b': [y]}
hlo_str = jax.jit(f).lower(1., (2,), [3]).as_text("stablehlo", debug_info=True)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def test_jit_lower_compile_with_compiler_options(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
@ -3598,21 +3529,6 @@ class APITest(jtu.JaxTestCase):
# level, which is no longer live.
jax.jit(jnp.add)(jnp.ones(()), count)
def test_escaped_tracer_transform_name(self):
with self.assertRaisesRegex(UnexpectedTracerError,
"for jit"):
jax.jit(self.helper_save_tracer)(1)
_ = self._saved_tracer+1
with self.assertRaisesRegex(UnexpectedTracerError,
"for pmap"):
jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
_ = self._saved_tracer+1
with self.assertRaisesRegex(UnexpectedTracerError,
"for jit"):
jax.eval_shape(self.helper_save_tracer, 1)
_ = self._saved_tracer+1
def test_escaped_tracer_shape_dtype(self):
with self.assertRaisesRegex(core.UnexpectedTracerError, r"int32\[4,3\]"):
@ -3659,120 +3575,6 @@ class APITest(jtu.JaxTestCase):
f() # doesn't crash
def test_concrete_error_because_arg_unary(self):
@jax.jit
def f(x):
if x > 0:
return x
else:
return 0
msg = r"on the value of the argument x"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1)
def test_concrete_error_because_arg_binary(self):
@jax.jit
def f(x, y):
if x > y:
return x
else:
return y
msg = r"on the values of the arguments x and y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)
def test_concrete_error_because_arg_ternary(self):
@jax.jit
def f(x, y, z):
if x > z:
return x
else:
return y
msg = r"on the values of the arguments x and z"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, z=3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, y=2, z=3)
def test_concrete_error_because_arg_varargs(self):
@jax.jit
def f(*args):
x, y, z = args
if x > z:
return x
else:
return y
msg = r"on the values of the arguments args"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
def test_concrete_error_because_arg_kwargs(self):
@jax.jit
def f(**kwargs):
x, y, z = kwargs['x'], kwargs['y'], kwargs['z']
if x > z:
return x
else:
return y
msg = r"on the values of the arguments kwargs"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(x=1, y=2, z=3)
def test_concrete_error_because_arg_pytree(self):
@jax.jit
def f(xy, z):
x, y = xy
if x > 0:
return x
else:
return y
msg = r"on the value of the argument xy"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f((1, 2), z=3)
def test_concrete_error_because_const(self):
@jax.jit
def f():
assert jnp.add(1, 1) > 0
msg = "on these lines"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()
def test_concrete_error_because_const_2(self):
@jax.jit
def f():
result = sum(jnp.add(1, 1) for _ in range(6))
assert result > 0
msg = "Additional originating lines are not shown."
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()
def test_concrete_error_with_nested_call(self):
@jax.jit
def f(x, y):
if y:
return x
@jax.jit
def g(x):
return f(x, True)
msg = r"on the value of the argument y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)
def test_linearize_aux(self):
def fn(x):
return x * 2 - 3, x > 0
@ -4940,39 +4742,6 @@ class RematTest(jtu.JaxTestCase):
expected = f_lin_expected(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_concrete_error(self):
@jax.remat # no static_argnums or concrete
def g(x):
if x > 0:
return lax.sin(x)
else:
return lax.cos(x)
with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
g(3.)
@partial(jax.remat, static_argnums=(0,)) # using static_argnums but...
def g(x):
if x > 0: # jnp operations still get staged!
return lax.sin(x)
else:
return lax.cos(x)
with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
g(jnp.array(3.))
# But don't raise an error mentioning static_argnums here:
@jax.remat
def g(x):
jax.jit(lambda: 0 if jnp.add(1, 1) else 0)()
return lax.sin(x)
try:
g(jnp.array(3.))
except core.ConcretizationTypeError as e:
msg = str(e)
self.assertNotIn('static_argnums', msg)
@unittest.skip
def test_remat_grad_python_control_flow_static_argnums(self):
@partial(jax.remat, static_argnums=(0,))

472
tests/debug_info_test.py Normal file
View File

@ -0,0 +1,472 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import functools
import math
import operator
from absl.testing import absltest, parameterized
import jax
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src.compilation_cache import is_persistent_cache_enabled
import jax.custom_batching
import jax.custom_derivatives
import jax.custom_transpose
from jax.errors import UnexpectedTracerError
import jax.numpy as jnp
import numpy as np
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
class DebugInfoTest(jtu.JaxTestCase):
def helper_save_tracer(self, x):
self._saved_tracer = x
return x
def test_jit_lower_arg_info(self):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)
@parameterized.parameters([0, 2, [(0, 2)]])
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.jit(f, static_argnums=static_argnums).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
self.assertNotIn(s, hlo_str)
@parameterized.parameters(['a', 'b', [('a', 'b')]])
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']
lowered = jax.jit(f, static_argnames=static_argnames).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.)
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
self.assertNotIn("kwargs['a']", hlo_str)
self.assertNotIn("kwargs['b']", hlo_str)
hlo_str = lowered.as_text("stablehlo", debug_info=False)
for s in (
"\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']",
"kwargs['w']", "kwargs['a']", "kwargs['b']"
):
self.assertNotIn(s, hlo_str)
def test_jit_lower_result_info(self):
def f(x, y, z):
return {'a': x, 'b': [y]}
hlo_str = jax.jit(f).lower(1., (2,), [3]).as_text("stablehlo", debug_info=True)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def test_jit_lower_compile_arg_type_mismatch(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
x = jnp.array(1, dtype=int)
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
f_exe = jax.jit(f).lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
r"Argument types differ .*"
r"The mismatches are:\n"
r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
lambda: f_exe(x_i32))
def test_jit_bad_input(self):
def f(x):
return x
err_str = ("Error interpreting argument to .* as an abstract array. The problematic "
"value is of type .* and was passed to the function at path x.")
with self.assertRaisesRegex(TypeError, err_str):
jax.jit(f)("foo")
# Jax type objects aren't valid data arguments.
with self.assertRaisesRegex(TypeError, err_str):
jax.jit(f)(jnp.int32)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations(self):
@jax.jit
def f(x, y):
return jnp.sin(x) * y['hi']
x = jnp.float32(1.)
y = {'hi': jnp.arange(3., dtype='float32')}
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
# print on first miss, not on hit
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(x, y)
f(x, y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('TRACING CACHE MISS', msg)
self.assertIn('never seen function', msg)
# shape change
y_ = {'hi': jnp.arange(4, dtype='float32')}
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(x, y_)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('never seen input type signature', msg)
self.assertIn('closest seen input type signature has 1 mismatches', msg)
self.assertIn('seen f32[3], but now given f32[4]', msg)
# weak type change (assuming no x64)
if not config.enable_x64.value:
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(1., y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('weak_type=True', msg)
self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg)
# kwarg change
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(1, y=y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
# tracing config change
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
with jax.numpy_rank_promotion('warn'):
f(x, y)
# depending on the backend, we may or may not get persistent cache warnings
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
msg = cm.output[0]
self.assertIn("tracing context doesn't match", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_new_function_in_loop(self):
@jax.jit
def f(x, y):
return jnp.sin(x) * y['hi']
x = jnp.float32(1.)
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
for _ in range(2):
jax.jit(lambda x: 2 * x)(3)
if is_persistent_cache_enabled():
# number of warnings depends on the backend
self.assertTrue(4 <= len(cm.output) <= 6)
msg = cm.output[3]
self.assertIn('another function defined on the same line', msg)
else:
self.assertLen(cm.output, 2)
_, msg = cm.output
self.assertIn('another function defined on the same line', msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_unpacks_transforms(self):
# Tests that the explain_tracing_cache_miss() function does not throw an
# error when unpacking `transforms` with a length greater than 3.
@jax.jit
def f(key):
return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32)
with config.explain_cache_misses(True):
with self.assertLogs(level="WARNING") as cm:
f(jax.random.key(seed=123))
if is_persistent_cache_enabled():
# 5 warnings from tracing cache, 5-10 from persistent cache depending on
# the backend
self.assertTrue(10 <= len(cm.output) <= 15)
self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output))
else:
self.assertLen(cm.output, 5)
for msg in cm.output:
self.assertIn("TRACING CACHE MISS", msg)
def test_cache_miss_explanations_no_source_info(self):
# ``operator.add`` is a built-in function and does not have source info.
with config.explain_cache_misses(True):
jax.jit(operator.add)(42, 24)
def test_concrete_error_because_arg_unary(self):
@jax.jit
def f(x):
if x > 0:
return x
else:
return 0
msg = r"on the value of the argument x"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1)
def test_concrete_error_because_arg_binary(self):
@jax.jit
def f(x, y):
if x > y:
return x
else:
return y
msg = r"on the values of the arguments x and y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2)
def test_concrete_error_because_arg_ternary(self):
@jax.jit
def f(x, y, z):
if x > z:
return x
else:
return y
msg = r"on the values of the arguments x and z"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, z=3)
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, y=2, z=3)
def test_concrete_error_because_arg_varargs(self):
@jax.jit
def f(*args):
x, y, z = args
if x > z:
return x
else:
return y
msg = r"on the values of the arguments args"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(1, 2, 3)
def test_concrete_error_because_arg_kwargs(self):
@jax.jit
def f(**kwargs):
x, y, z = kwargs['x'], kwargs['y'], kwargs['z']
if x > z:
return x
else:
return y
msg = r"on the values of the arguments kwargs"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f(x=1, y=2, z=3)
def test_concrete_error_because_arg_pytree(self):
@jax.jit
def f(xy, z):
x, y = xy
if x > 0:
return x
else:
return y
msg = r"on the value of the argument xy"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f((1, 2), z=3)
def test_concrete_error_because_const(self):
@jax.jit
def f():
assert jnp.add(1, 1) > 0
msg = "on these lines"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()
def test_concrete_error_because_const_2(self):
@jax.jit
def f():
result = sum(jnp.add(1, 1) for _ in range(6))
assert result > 0
msg = "Additional originating lines are not shown."
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
f()
def test_concrete_error_with_nested_call(self):
@jax.jit
def f(x, y):
if y:
return x
@jax.jit
def g(x):
return f(x, True)
msg = r"on the value of the argument y"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)
def test_escaped_tracer_transform_name(self):
with self.assertRaisesRegex(UnexpectedTracerError,
"for jit"):
jax.jit(self.helper_save_tracer)(1)
_ = self._saved_tracer+1
with self.assertRaisesRegex(UnexpectedTracerError,
"for pmap"):
jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
_ = self._saved_tracer+1
with self.assertRaisesRegex(UnexpectedTracerError,
"for jit"):
jax.eval_shape(self.helper_save_tracer, 1)
_ = self._saved_tracer+1
def test_remat_concrete_error(self):
@jax.remat # no static_argnums or concrete
def g(x):
if x > 0:
return lax.sin(x)
else:
return lax.cos(x)
with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
g(3.)
@functools.partial(jax.remat, static_argnums=(0,)) # using static_argnums but...
def g(x):
if x > 0: # jnp operations still get staged!
return lax.sin(x)
else:
return lax.cos(x)
with self.assertRaisesRegex(core.ConcretizationTypeError, "static_argnums"):
g(jnp.array(3.))
# But don't raise an error mentioning static_argnums here:
@jax.remat
def g(x):
jax.jit(lambda: 0 if jnp.add(1, 1) else 0)()
return lax.sin(x)
try:
g(jnp.array(3.))
except core.ConcretizationTypeError as e:
msg = str(e)
self.assertNotIn('static_argnums', msg)
class EagerPmapMixin:
def setUp(self):
super().setUp()
stack = contextlib.ExitStack()
stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
stack.enter_context(jtu.ignore_warning(
message="Some donated buffers were not usable", category=UserWarning))
self.addCleanup(stack.close)
@jtu.pytest_mark_if_available('multiaccelerator')
class PythonPmapEagerTest(EagerPmapMixin, jtu.JaxTestCase):
def test_pmap_lower_arg_info(self):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.pmap(f).lower(
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
def test_pmap_lower_result_info(self):
def f(x, y, z):
return {'a': x, 'b': [y]}
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
[jnp.array([3])])
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def testLowerCompileArgTypeMismatch(self):
f = jax.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(math.prod(shape), dtype=int).reshape(shape)
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
f_exe = f.lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
r"Argument types differ .*"
r"The mismatches are:\n"
r"Argument 'x' compiled with.*float32.*and called with.*int32.*",
lambda: f_exe(x_i32))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -29,7 +29,6 @@ import numpy as np
import jax
from jax._src import core
from jax import dtypes
from jax.errors import UnexpectedTracerError
from jax import lax
from jax import random
from jax._src import test_util as jtu
@ -589,20 +588,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
init = jnp.float32(10)
self.assertEqual(fori_loop_with_static_upper_and_lower(init), init)
def test_fori_error_points_to_user_code(self):
# See https://github.com/jax-ml/jax/issues/23637
def my_body(_, c):
return bool(c)
with self.assertRaisesRegex(
jax.errors.TracerBoolConversionError,
"occurred while tracing the function my_body at .*control_flow_test.py.* for scan"):
jax.lax.fori_loop(0, 5, my_body, 3.)
with self.assertRaisesRegex(
jax.errors.TracerBoolConversionError,
"occurred while tracing the function my_body at .*control_flow_test.py.* for while_loop"):
jax.jit(lambda ubound: jax.lax.fori_loop(0, ubound, my_body, 3.))(5)
def testForiLoopBatched(self):
def body_fun(i, loop_carry):
@ -2750,22 +2735,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False)
def test_unexpected_tracer_error(self):
with self.assertRaisesRegex(UnexpectedTracerError, "for while_loop"):
lst = []
def side_effecting_body(val):
lst.append(val)
return val+1
lax.while_loop(lambda x: x < 2, side_effecting_body, 1)
lst[0] += 1
with self.assertRaisesRegex(UnexpectedTracerError, "for scan"):
lst = []
def side_effecting_scan(carry, val):
lst.append(val)
return carry, val+1
lax.scan(side_effecting_scan, None, jnp.ones((2, 2)))
lst[0] += 1
def test_while_loop_fixed_point_with_batched_pred_and_consts(self):
def f(i, x):

View File

@ -2119,31 +2119,6 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)
def test_pmap_lower_arg_info(self):
def f(x, y, *args, **kwargs):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.pmap(f).lower(
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
def test_pmap_lower_result_info(self):
def f(x, y, z):
return {'a': x, 'b': [y]}
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
[jnp.array([3])])
hlo_str = lowered.as_text("stablehlo", debug_info=True)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def test_axis_name_shadowing_with_vmap(self):
# vmap-of-pmap with mismatched axis sizes
jax.vmap(jax.pmap(lambda x: 2 * x, axis_name='i'),