Merge pull request #6917 from hawkinsp:token

PiperOrigin-RevId: 377997455
This commit is contained in:
jax authors 2021-06-07 13:50:56 -07:00
commit 7591e16c98
2 changed files with 11 additions and 0 deletions

View File

@ -1043,6 +1043,7 @@ core.pytype_aval_mappings[Token] = lambda _: abstract_token
xla_shape_handlers[AbstractToken] = lambda _: (xc.Shape.token_shape(),)
xla_result_handlers[AbstractToken] = lambda _, __: lambda _: token
canonicalize_dtype_handlers[Token] = identity
device_put_handlers[Token] = lambda x, _: (x,)
def _forward_method(attrname, self, fun, *args):

View File

@ -369,6 +369,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIs(z3, x1)
self.assertEqual(z2, 1)
def test_trivial_computations_with_tokens(self):
@self.jit
def noop(arr, token):
return arr, token
arr = jax.numpy.ones(10)
token = jax.lax.create_token()
self.assertEqual(token, noop(arr, token)[1])
def test_jit_bad_input(self):
def f(x):
return x