[xla:gpu] Make cu_threefry2x32 custom call compatible with command buffers

PiperOrigin-RevId: 600937786
This commit is contained in:
Anlun Xu 2024-01-23 16:13:45 -08:00 committed by jax authors
parent 85fc4a95f2
commit 88f5eaca3e

View File

@ -29,7 +29,12 @@ __global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
const std::uint32_t* data0,
const std::uint32_t* data1,
std::uint32_t* out0, std::uint32_t* out1,
std::int64_t n) {
std::int64_t n, const std::int64_t* n_ptr) {
if (n < 0) {
// n is stored in device memory.
n = *n_ptr;
}
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
idx += blockDim.x * gridDim.x) {
// Rotation distances specified by the Threefry2x32 algorithm.
@ -109,12 +114,10 @@ void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
std::int64_t n = descriptor.n;
int output_idx = 4;
std::int64_t* n_ptr = nullptr;
if (n < 0) {
// n is an operand in device memory.
gpuMemcpyAsync((void*)&n, reinterpret_cast<const std::int64_t*>(buffers[4]),
sizeof(n), gpuMemcpyDeviceToHost,
stream);
gpuStreamSynchronize(stream);
n_ptr = reinterpret_cast<std::int64_t*>(buffers[4]);
output_idx = 5;
}
@ -123,10 +126,11 @@ void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
out[1] = reinterpret_cast<std::uint32_t*>(buffers[output_idx + 1]);
const int block_dim = 128;
const std::int64_t grid_dim =
std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
n < 0 ? 32
: std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
out[1], n);
out[1], n, n_ptr);
}
} // namespace JAX_GPU_NAMESPACE