Use stream-synchronized copy in rnn_kernels.cc.

May fix flaky wrong outputs sometimes seen in CI.

Also check for errors in another use of gpuStreamSynchronize().

PiperOrigin-RevId: 530391917
This commit is contained in:
Peter Hawkins 2023-05-08 13:25:58 -07:00 committed by jax authors
parent 821b38da12
commit 6b9a109939
2 changed files with 5 additions and 3 deletions

View File

@ -201,9 +201,11 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
auto seq_lengths_buf = buffers[4];
std::vector<int32_t> seq_length_vector(d.batch_size, 0);
int32_t* seq_length_array = &seq_length_vector[0];
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpy(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
seq_length_array, seq_lengths_buf,
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost)));
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost,
stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
cudnnRNNDataDescriptor_t input_data_desc;
JAX_RETURN_IF_ERROR(

View File

@ -465,7 +465,7 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers,
reinterpret_cast<const std::int64_t*>(buffers[1]),
sizeof(batch), gpuMemcpyDeviceToHost,
stream);
gpuStreamSynchronize(stream);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
output_idx = 2;
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(