mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix exception when tokens are used in AD.
This commit is contained in:
parent
9ccfc9fd48
commit
dd34d48fd1
@ -1134,6 +1134,7 @@ class AbstractToken(AbstractValue):
|
||||
else:
|
||||
assert False, f"Cannot join {self} with {other}"
|
||||
def str_short(self): return 'Tok'
|
||||
def at_least_vspace(self): return self
|
||||
|
||||
abstract_token = AbstractToken()
|
||||
|
||||
|
@ -367,10 +367,9 @@ def _execute_replicated_primitive(prim, compiled, result_handler, *args):
|
||||
|
||||
|
||||
def check_special(prim, bufs):
|
||||
for buf in bufs:
|
||||
# TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
|
||||
# the default.
|
||||
_check_special(prim.name, getattr(buf, "xla_shape", buf.shape)(), buf)
|
||||
if FLAGS.jax_debug_infs or FLAGS.jax_debug_nans:
|
||||
for buf in bufs:
|
||||
_check_special(prim.name, buf.xla_shape(), buf)
|
||||
|
||||
def _check_special(name, xla_shape, buf):
|
||||
assert not xla_shape.is_tuple()
|
||||
|
@ -2160,6 +2160,20 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "tangent values inconsistent"):
|
||||
f_jvp(np.ones(2, np.int32))
|
||||
|
||||
def test_grad_of_token_consuming_primitive(self):
|
||||
# https://github.com/google/jax/issues/5463
|
||||
tokentest_p = core.Primitive("tokentest")
|
||||
tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p))
|
||||
tokentest_p.def_abstract_eval(lambda x, y: x)
|
||||
xla.translations[tokentest_p] = lambda c, x, y: x
|
||||
ad.defjvp(tokentest_p, (lambda g, x, token: x), None)
|
||||
|
||||
token = jax.lax.create_token(123)
|
||||
arr = jnp.ones((3, 2))
|
||||
res, vjp_fun = jax.vjp(lambda x: tokentest_p.bind(x, token), arr)
|
||||
# Should not crash.
|
||||
vjp_fun(arr)
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user