Expanded the error messages due to re-using tracers saved in global state.

Previously these errors were raising Exception (as other internal errors),
but these errors may arise out of mis-use of tracers.
This commit is contained in:
George Necula 2020-02-15 06:35:49 +01:00
parent 938336e08a
commit deb21ef15d
4 changed files with 86 additions and 14 deletions

View File

@ -280,6 +280,13 @@ class Trace(object):
self.level = master.level
self.sublevel = sublevel
def escaped_tracer_error(self, detail):
msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
"through global state from a previously traced function.\n"
"The functions being transformed should not save traced values to "
"global state.\nDetails: {}.")
raise ValueError(msg.format(detail))
def full_raise(self, val):
if not isinstance(val, Tracer):
return self.pure(val)
@ -291,19 +298,18 @@ class Trace(object):
elif val._trace.sublevel < sublevel:
return self.sublift(val)
else:
raise Exception("Can't lift sublevels {} to {}"
.format(val._trace.sublevel, sublevel))
self.escaped_tracer_error(
"Can't lift sublevels {} to {}".format(val._trace.sublevel, sublevel))
elif val._trace.level < level:
if val._trace.sublevel > sublevel:
raise Exception("Incompatible sublevel: {}, {}"
.format(val._trace, (level, sublevel)))
self.escaped_tracer_error(
"Incompatible sublevel: {}, {}".format(val._trace, (level, sublevel)))
return self.lift(val)
elif val._trace.level > level:
raise Exception("Can't lift {} to {}".format(val, self))
elif val._trace.level == self.level:
raise Exception("Different traces at same level: {}, {}".format(val, self))
else:
raise Exception("Can't lift {} to {}".format(val, self))
self.escaped_tracer_error(
"Can't lift level {} to {}".format(val, self))
else: # val._trace.level == self.level:
self.escaped_tracer_error("Different traces at same level: {}, {}".format(val, self))
def pure(self, val):

View File

@ -370,8 +370,8 @@ class _BodyTracer(object):
in_tracers=in_tracers,
out_tracers=body_out_tracers,
trace=self.trace)
except AssertionError as e:
if "Encountered unexpected tracer" == str(e):
except ValueError as e:
if "Tracer not among input tracers" in str(e):
raise ValueError("Body of cond_range or while_range should not use the "
"index variable returned by iterator.")
raise

View File

@ -283,8 +283,9 @@ class JaxprTracer(Tracer):
def __init__(self, trace, pval, recipe):
assert isinstance(pval, PartialVal)
pv, const = pval
if isinstance(const, Tracer):
assert const._trace.level < trace.level
if isinstance(const, Tracer) and const._trace.level >= trace.level:
trace.escaped_tracer_error(
"Tracer from a higher level: {} in trace {}".format(const, trace))
self._trace = trace
self.pval = pval
self.recipe = recipe
@ -438,7 +439,8 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
eqns.append(recipe_to_eqn(newvar, getvar, recipe))
processed_eqn_ids.add(recipe.eqn_id)
elif isinstance(recipe, LambdaBinding):
assert any(t is in_tracer for in_tracer in in_tracers), "Encountered unexpected tracer"
if not any(t is in_tracer for in_tracer in in_tracers):
t._trace.escaped_tracer_error("Tracer not among input tracers {}".format(t))
assert in_tracers, "Lambda binding with no args"
elif isinstance(recipe, FreeVar):
env[getvar(t)] = recipe.val

View File

@ -17,6 +17,7 @@ import collections
from contextlib import contextmanager
import copy
from functools import partial
import re
import unittest
import warnings
import weakref
@ -1762,6 +1763,69 @@ class APITest(jtu.JaxTestCase):
jit(fun)(np.array([0, 1, 2], dtype=np.int32)) # doesn't crash
def helper_save_tracer(self, x):
self._saved_tracer = x
return x
def test_escaped_tracers_diffent_top_level_traces(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
ValueError,
re.compile(
"Encountered an unexpected tracer.*Different traces at same level",
re.DOTALL)):
api.jit(lambda x: self._saved_tracer)(0.)
def test_escaped_tracers_cant_lift_sublevels(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
ValueError,
re.compile(
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
re.DOTALL)):
api.jit(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_tracer_from_higher_level(self):
api.grad(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
ValueError,
re.compile(
"Encountered an unexpected tracer.*Tracer from a higher level",
re.DOTALL)):
api.grad(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_incompatible_sublevel(self):
def func1(x):
api.jit(self.helper_save_tracer)(0.)
# Use the tracer
return x + self._saved_tracer
with self.assertRaisesRegex(
ValueError,
re.compile("Encountered an unexpected tracer.*Incompatible sublevel",
re.DOTALL)):
api.jit(func1)(2.)
def test_escaped_tracers_cant_lift(self):
def func1(x):
api.grad(self.helper_save_tracer)(0.)
return x + self._saved_tracer
with self.assertRaisesRegex(
ValueError, re.compile("Encountered an unexpected tracer.*Can't lift",
re.DOTALL)):
api.grad(func1)(2.)
def test_escaped_tracers_not_among_input_tracers(self):
def func1(x):
api.grad(self.helper_save_tracer)(x)
# Use the tracer
return x + self._saved_tracer
with self.assertRaisesRegex(
ValueError, re.compile(
"Encountered an unexpected tracer.*Tracer not among input tracers",
re.DOTALL)):
api.jit(func1)(2.)
class JaxprTest(jtu.JaxTestCase):