mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Use stream-synchronized copy in rnn_kernels.cc.
May fix flaky wrong outputs sometimes seen in CI. Also check for errors in another use of gpuStreamSynchronize(). PiperOrigin-RevId: 530391917
This commit is contained in:
parent
821b38da12
commit
6b9a109939
@ -201,9 +201,11 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
|
||||
auto seq_lengths_buf = buffers[4];
|
||||
std::vector<int32_t> seq_length_vector(d.batch_size, 0);
|
||||
int32_t* seq_length_array = &seq_length_vector[0];
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpy(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
seq_length_array, seq_lengths_buf,
|
||||
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost)));
|
||||
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost,
|
||||
stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
|
||||
cudnnRNNDataDescriptor_t input_data_desc;
|
||||
JAX_RETURN_IF_ERROR(
|
||||
|
@ -465,7 +465,7 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers,
|
||||
reinterpret_cast<const std::int64_t*>(buffers[1]),
|
||||
sizeof(batch), gpuMemcpyDeviceToHost,
|
||||
stream);
|
||||
gpuStreamSynchronize(stream);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
output_idx = 2;
|
||||
}
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
|
Loading…
x
Reference in New Issue
Block a user