Fix cuda custom call example to build with updated XLA FFI API.

PiperOrigin-RevId: 650977379
This commit is contained in:
Tom Ward 2024-07-10 05:29:11 -07:00 committed by jax authors
parent e4b606e38a
commit ebfbd8ac0c

View File

@ -53,7 +53,8 @@ ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::DataType::F32> a,
const int grid_dim = 1;
// Note how we access regular Buffer data vs Result Buffer data:
FooFwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
a.data, b.data, c->data, b_plus_1->data, n);
a.typed_data(), b.typed_data(), c->typed_data(), b_plus_1->typed_data(),
n);
// Check for launch time errors. Note that this function may also
// return error codes from previous, asynchronous launches. This
// means that an error status returned here could have been caused
@ -113,7 +114,8 @@ ffi::Error FooBwdHost(cudaStream_t stream,
const int block_dim = 128;
const int grid_dim = 1;
FooBwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
c_grad.data, a.data, b_plus_1->data, a_grad->data, b_grad->data, n);
c_grad.typed_data(), a.typed_data(), b_plus_1->typed_data(),
a_grad->typed_data(), b_grad->typed_data(), n);
cudaError_t last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
return ffi::Error(