mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
rm CmdBuffer traits
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
This commit is contained in:
parent
f3775aa233
commit
82113cd047
@ -27,7 +27,6 @@ import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.extend import ffi
|
||||
from jax.lib import xla_client
|
||||
|
||||
# start test boilerplate
|
||||
from absl.testing import absltest
|
||||
|
@ -77,8 +77,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b
|
||||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // c
|
||||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
|
||||
.Attr<size_t>("n"),
|
||||
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
|
||||
.Attr<size_t>("n"));
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// Backward pass //
|
||||
@ -136,5 +135,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
|
||||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // a_grad
|
||||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_grad
|
||||
.Attr<size_t>("n"),
|
||||
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
|
||||
.Attr<size_t>("n"));
|
||||
|
Loading…
x
Reference in New Issue
Block a user