rm CmdBuffer traits

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
This commit is contained in:
Phuong Nguyen 2024-10-16 10:27:09 -07:00
parent f3775aa233
commit 82113cd047
2 changed files with 2 additions and 5 deletions

View File

@ -27,7 +27,6 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax.extend import ffi
from jax.lib import xla_client
# start test boilerplate
from absl.testing import absltest

View File

@ -77,8 +77,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b
.Ret<ffi::Buffer<ffi::DataType::F32>>() // c
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
.Attr<size_t>("n"),
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
.Attr<size_t>("n"));
//----------------------------------------------------------------------------//
// Backward pass //
@ -136,5 +135,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
.Ret<ffi::Buffer<ffi::DataType::F32>>() // a_grad
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_grad
.Attr<size_t>("n"),
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
.Attr<size_t>("n"));