mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6917 from hawkinsp:token
PiperOrigin-RevId: 377997455
This commit is contained in:
commit
7591e16c98
@ -1043,6 +1043,7 @@ core.pytype_aval_mappings[Token] = lambda _: abstract_token
|
||||
xla_shape_handlers[AbstractToken] = lambda _: (xc.Shape.token_shape(),)
|
||||
xla_result_handlers[AbstractToken] = lambda _, __: lambda _: token
|
||||
canonicalize_dtype_handlers[Token] = identity
|
||||
device_put_handlers[Token] = lambda x, _: (x,)
|
||||
|
||||
|
||||
def _forward_method(attrname, self, fun, *args):
|
||||
|
@ -369,6 +369,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIs(z3, x1)
|
||||
self.assertEqual(z2, 1)
|
||||
|
||||
def test_trivial_computations_with_tokens(self):
|
||||
@self.jit
|
||||
def noop(arr, token):
|
||||
return arr, token
|
||||
|
||||
arr = jax.numpy.ones(10)
|
||||
token = jax.lax.create_token()
|
||||
|
||||
self.assertEqual(token, noop(arr, token)[1])
|
||||
|
||||
def test_jit_bad_input(self):
|
||||
def f(x):
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user