mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[GPU] Fix another instance of missing stream synchronization in RNN kernels.
PiperOrigin-RevId: 530660502
This commit is contained in:
parent
a2b5bd5230
commit
a89c377762
@ -201,10 +201,10 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
|
||||
auto seq_lengths_buf = buffers[4];
|
||||
std::vector<int32_t> seq_length_vector(d.batch_size, 0);
|
||||
int32_t* seq_length_array = &seq_length_vector[0];
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
seq_length_array, seq_lengths_buf,
|
||||
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost,
|
||||
stream)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(gpuMemcpyAsync(seq_length_array, seq_lengths_buf,
|
||||
seq_length_vector.size() * sizeof(int32_t),
|
||||
gpuMemcpyDeviceToHost, stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
|
||||
cudnnRNNDataDescriptor_t input_data_desc;
|
||||
@ -321,9 +321,11 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
|
||||
auto seq_lengths_buf = buffers[10];
|
||||
std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
|
||||
int32_t* seq_length_array = &seq_length_vector[0];
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpy(
|
||||
seq_length_array, seq_lengths_buf,
|
||||
seq_length_vector.size() * sizeof(int32_t), gpuMemcpyDeviceToHost)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(gpuMemcpyAsync(seq_length_array, seq_lengths_buf,
|
||||
seq_length_vector.size() * sizeof(int32_t),
|
||||
gpuMemcpyDeviceToHost, stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
|
||||
cudnnRNNDataDescriptor_t input_data_desc;
|
||||
JAX_RETURN_IF_ERROR(
|
||||
|
Loading…
x
Reference in New Issue
Block a user