mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Expanded the error messages due to re-using tracers saved in global state.
Previously these errors were raising Exception (as other internal errors), but these errors may arise out of mis-use of tracers.
This commit is contained in:
parent
938336e08a
commit
deb21ef15d
24
jax/core.py
24
jax/core.py
@ -280,6 +280,13 @@ class Trace(object):
|
||||
self.level = master.level
|
||||
self.sublevel = sublevel
|
||||
|
||||
def escaped_tracer_error(self, detail):
|
||||
msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
|
||||
"through global state from a previously traced function.\n"
|
||||
"The functions being transformed should not save traced values to "
|
||||
"global state.\nDetails: {}.")
|
||||
raise ValueError(msg.format(detail))
|
||||
|
||||
def full_raise(self, val):
|
||||
if not isinstance(val, Tracer):
|
||||
return self.pure(val)
|
||||
@ -291,19 +298,18 @@ class Trace(object):
|
||||
elif val._trace.sublevel < sublevel:
|
||||
return self.sublift(val)
|
||||
else:
|
||||
raise Exception("Can't lift sublevels {} to {}"
|
||||
.format(val._trace.sublevel, sublevel))
|
||||
self.escaped_tracer_error(
|
||||
"Can't lift sublevels {} to {}".format(val._trace.sublevel, sublevel))
|
||||
elif val._trace.level < level:
|
||||
if val._trace.sublevel > sublevel:
|
||||
raise Exception("Incompatible sublevel: {}, {}"
|
||||
.format(val._trace, (level, sublevel)))
|
||||
self.escaped_tracer_error(
|
||||
"Incompatible sublevel: {}, {}".format(val._trace, (level, sublevel)))
|
||||
return self.lift(val)
|
||||
elif val._trace.level > level:
|
||||
raise Exception("Can't lift {} to {}".format(val, self))
|
||||
elif val._trace.level == self.level:
|
||||
raise Exception("Different traces at same level: {}, {}".format(val, self))
|
||||
else:
|
||||
raise Exception("Can't lift {} to {}".format(val, self))
|
||||
self.escaped_tracer_error(
|
||||
"Can't lift level {} to {}".format(val, self))
|
||||
else: # val._trace.level == self.level:
|
||||
self.escaped_tracer_error("Different traces at same level: {}, {}".format(val, self))
|
||||
|
||||
|
||||
def pure(self, val):
|
||||
|
@ -370,8 +370,8 @@ class _BodyTracer(object):
|
||||
in_tracers=in_tracers,
|
||||
out_tracers=body_out_tracers,
|
||||
trace=self.trace)
|
||||
except AssertionError as e:
|
||||
if "Encountered unexpected tracer" == str(e):
|
||||
except ValueError as e:
|
||||
if "Tracer not among input tracers" in str(e):
|
||||
raise ValueError("Body of cond_range or while_range should not use the "
|
||||
"index variable returned by iterator.")
|
||||
raise
|
||||
|
@ -283,8 +283,9 @@ class JaxprTracer(Tracer):
|
||||
def __init__(self, trace, pval, recipe):
|
||||
assert isinstance(pval, PartialVal)
|
||||
pv, const = pval
|
||||
if isinstance(const, Tracer):
|
||||
assert const._trace.level < trace.level
|
||||
if isinstance(const, Tracer) and const._trace.level >= trace.level:
|
||||
trace.escaped_tracer_error(
|
||||
"Tracer from a higher level: {} in trace {}".format(const, trace))
|
||||
self._trace = trace
|
||||
self.pval = pval
|
||||
self.recipe = recipe
|
||||
@ -438,7 +439,8 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
|
||||
eqns.append(recipe_to_eqn(newvar, getvar, recipe))
|
||||
processed_eqn_ids.add(recipe.eqn_id)
|
||||
elif isinstance(recipe, LambdaBinding):
|
||||
assert any(t is in_tracer for in_tracer in in_tracers), "Encountered unexpected tracer"
|
||||
if not any(t is in_tracer for in_tracer in in_tracers):
|
||||
t._trace.escaped_tracer_error("Tracer not among input tracers {}".format(t))
|
||||
assert in_tracers, "Lambda binding with no args"
|
||||
elif isinstance(recipe, FreeVar):
|
||||
env[getvar(t)] = recipe.val
|
||||
|
@ -17,6 +17,7 @@ import collections
|
||||
from contextlib import contextmanager
|
||||
import copy
|
||||
from functools import partial
|
||||
import re
|
||||
import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
@ -1762,6 +1763,69 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
jit(fun)(np.array([0, 1, 2], dtype=np.int32)) # doesn't crash
|
||||
|
||||
def helper_save_tracer(self, x):
|
||||
self._saved_tracer = x
|
||||
return x
|
||||
|
||||
def test_escaped_tracers_diffent_top_level_traces(self):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Different traces at same level",
|
||||
re.DOTALL)):
|
||||
api.jit(lambda x: self._saved_tracer)(0.)
|
||||
|
||||
def test_escaped_tracers_cant_lift_sublevels(self):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
|
||||
re.DOTALL)):
|
||||
api.jit(lambda x: x)(self._saved_tracer)
|
||||
|
||||
def test_escaped_tracers_tracer_from_higher_level(self):
|
||||
api.grad(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Tracer from a higher level",
|
||||
re.DOTALL)):
|
||||
api.grad(lambda x: x)(self._saved_tracer)
|
||||
|
||||
def test_escaped_tracers_incompatible_sublevel(self):
|
||||
def func1(x):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
# Use the tracer
|
||||
return x + self._saved_tracer
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.compile("Encountered an unexpected tracer.*Incompatible sublevel",
|
||||
re.DOTALL)):
|
||||
api.jit(func1)(2.)
|
||||
|
||||
def test_escaped_tracers_cant_lift(self):
|
||||
def func1(x):
|
||||
api.grad(self.helper_save_tracer)(0.)
|
||||
return x + self._saved_tracer
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, re.compile("Encountered an unexpected tracer.*Can't lift",
|
||||
re.DOTALL)):
|
||||
api.grad(func1)(2.)
|
||||
|
||||
def test_escaped_tracers_not_among_input_tracers(self):
|
||||
def func1(x):
|
||||
api.grad(self.helper_save_tracer)(x)
|
||||
# Use the tracer
|
||||
return x + self._saved_tracer
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, re.compile(
|
||||
"Encountered an unexpected tracer.*Tracer not among input tracers",
|
||||
re.DOTALL)):
|
||||
api.jit(func1)(2.)
|
||||
|
||||
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user