Fix exception when tokens are used in AD.

This commit is contained in:
Peter Hawkins 2021-01-22 10:57:33 -05:00
parent 9ccfc9fd48
commit dd34d48fd1
3 changed files with 18 additions and 4 deletions

View File

@ -1134,6 +1134,7 @@ class AbstractToken(AbstractValue):
else:
assert False, f"Cannot join {self} with {other}"
def str_short(self): return 'Tok'
def at_least_vspace(self): return self
abstract_token = AbstractToken()

View File

@ -367,10 +367,9 @@ def _execute_replicated_primitive(prim, compiled, result_handler, *args):
def check_special(prim, bufs):
for buf in bufs:
# TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
# the default.
_check_special(prim.name, getattr(buf, "xla_shape", buf.shape)(), buf)
if FLAGS.jax_debug_infs or FLAGS.jax_debug_nans:
for buf in bufs:
_check_special(prim.name, buf.xla_shape(), buf)
def _check_special(name, xla_shape, buf):
assert not xla_shape.is_tuple()

View File

@ -2160,6 +2160,20 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "tangent values inconsistent"):
f_jvp(np.ones(2, np.int32))
def test_grad_of_token_consuming_primitive(self):
# https://github.com/google/jax/issues/5463
tokentest_p = core.Primitive("tokentest")
tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p))
tokentest_p.def_abstract_eval(lambda x, y: x)
xla.translations[tokentest_p] = lambda c, x, y: x
ad.defjvp(tokentest_p, (lambda g, x, token: x), None)
token = jax.lax.create_token(123)
arr = jnp.ones((3, 2))
res, vjp_fun = jax.vjp(lambda x: tokentest_p.bind(x, token), arr)
# Should not crash.
vjp_fun(arr)
class RematTest(jtu.JaxTestCase):