[GPU] Fix another instance of missing stream synchronization in RNN kernels.

PiperOrigin-RevId: 530660502
This commit is contained in:
Peter Hawkins 2023-05-09 11:07:44 -07:00 committed by jax authors
parent a2b5bd5230
commit a89c377762

View File

@ -201,10 +201,10 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
auto seq_lengths_buf = buffers[4];
std::vector<int32_t> seq_length_vector(d.batch_size, 0);
int32_t* seq_length_array = &seq_length_vector[0];
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
seq_length_array, seq_lengths_buf,
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost,
stream)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(gpuMemcpyAsync(seq_length_array, seq_lengths_buf,
seq_length_vector.size() * sizeof(int32_t),
gpuMemcpyDeviceToHost, stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
cudnnRNNDataDescriptor_t input_data_desc;
@ -321,9 +321,11 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
auto seq_lengths_buf = buffers[10];
std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
int32_t* seq_length_array = &seq_length_vector[0];
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpy(
seq_length_array, seq_lengths_buf,
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(gpuMemcpyAsync(seq_length_array, seq_lengths_buf,
seq_length_vector.size() * sizeof(int32_t),
gpuMemcpyDeviceToHost, stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
cudnnRNNDataDescriptor_t input_data_desc;
JAX_RETURN_IF_ERROR(