added cmdBuffer traits

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
This commit is contained in:
Phuong Nguyen 2024-10-16 10:37:49 -07:00
parent bb271aaff8
commit d4bbb4fd84

View File

@ -77,7 +77,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b
.Ret<ffi::Buffer<ffi::DataType::F32>>() // c
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
.Attr<size_t>("n"));
.Attr<size_t>("n"),
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
//----------------------------------------------------------------------------//
// Backward pass //
@ -135,4 +136,5 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
.Ret<ffi::Buffer<ffi::DataType::F32>>() // a_grad
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_grad
.Attr<size_t>("n"));
.Attr<size_t>("n"),
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled