mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
added cmdBuffer traits
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
This commit is contained in:
parent
bb271aaff8
commit
d4bbb4fd84
@ -77,7 +77,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b
|
||||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // c
|
||||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
|
||||
.Attr<size_t>("n"));
|
||||
.Attr<size_t>("n"),
|
||||
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// Backward pass //
|
||||
@ -135,4 +136,5 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
|
||||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // a_grad
|
||||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_grad
|
||||
.Attr<size_t>("n"));
|
||||
.Attr<size_t>("n"),
|
||||
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
|
||||
|
Loading…
x
Reference in New Issue
Block a user