mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Fix cuda custom call example to build with updated XLA FFI API.
PiperOrigin-RevId: 650977379
This commit is contained in:
parent
e4b606e38a
commit
ebfbd8ac0c
@ -53,7 +53,8 @@ ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::DataType::F32> a,
|
||||
const int grid_dim = 1;
|
||||
// Note how we access regular Buffer data vs Result Buffer data:
|
||||
FooFwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
|
||||
a.data, b.data, c->data, b_plus_1->data, n);
|
||||
a.typed_data(), b.typed_data(), c->typed_data(), b_plus_1->typed_data(),
|
||||
n);
|
||||
// Check for launch time errors. Note that this function may also
|
||||
// return error codes from previous, asynchronous launches. This
|
||||
// means that an error status returned here could have been caused
|
||||
@ -113,7 +114,8 @@ ffi::Error FooBwdHost(cudaStream_t stream,
|
||||
const int block_dim = 128;
|
||||
const int grid_dim = 1;
|
||||
FooBwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
|
||||
c_grad.data, a.data, b_plus_1->data, a_grad->data, b_grad->data, n);
|
||||
c_grad.typed_data(), a.typed_data(), b_plus_1->typed_data(),
|
||||
a_grad->typed_data(), b_grad->typed_data(), n);
|
||||
cudaError_t last_error = cudaGetLastError();
|
||||
if (last_error != cudaSuccess) {
|
||||
return ffi::Error(
|
||||
|
Loading…
x
Reference in New Issue
Block a user