mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[xla:gpu] Make cu_threefry2x32 custom call compatible with command buffers
PiperOrigin-RevId: 600937786
This commit is contained in:
parent
85fc4a95f2
commit
88f5eaca3e
@ -29,7 +29,12 @@ __global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
|
||||
const std::uint32_t* data0,
|
||||
const std::uint32_t* data1,
|
||||
std::uint32_t* out0, std::uint32_t* out1,
|
||||
std::int64_t n) {
|
||||
std::int64_t n, const std::int64_t* n_ptr) {
|
||||
if (n < 0) {
|
||||
// n is stored in device memory.
|
||||
n = *n_ptr;
|
||||
}
|
||||
|
||||
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
|
||||
idx += blockDim.x * gridDim.x) {
|
||||
// Rotation distances specified by the Threefry2x32 algorithm.
|
||||
@ -109,12 +114,10 @@ void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
|
||||
data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
|
||||
std::int64_t n = descriptor.n;
|
||||
int output_idx = 4;
|
||||
std::int64_t* n_ptr = nullptr;
|
||||
if (n < 0) {
|
||||
// n is an operand in device memory.
|
||||
gpuMemcpyAsync((void*)&n, reinterpret_cast<const std::int64_t*>(buffers[4]),
|
||||
sizeof(n), gpuMemcpyDeviceToHost,
|
||||
stream);
|
||||
gpuStreamSynchronize(stream);
|
||||
n_ptr = reinterpret_cast<std::int64_t*>(buffers[4]);
|
||||
output_idx = 5;
|
||||
}
|
||||
|
||||
@ -123,10 +126,11 @@ void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
|
||||
out[1] = reinterpret_cast<std::uint32_t*>(buffers[output_idx + 1]);
|
||||
const int block_dim = 128;
|
||||
const std::int64_t grid_dim =
|
||||
std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
|
||||
n < 0 ? 32
|
||||
: std::min<std::int64_t>(1024, (n + block_dim - 1) / block_dim);
|
||||
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
|
||||
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
|
||||
out[1], n);
|
||||
out[1], n, n_ptr);
|
||||
}
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
|
Loading…
x
Reference in New Issue
Block a user