// Source: rocm_jax/jaxlib/gpu/rnn_kernels.cc
// Snapshot: 2022-11-22 18:53:29 -08:00 (550 lines, 21 KiB, C++)

/* Copyright 2022 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/rnn_kernels.h"

#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
using DnnHandlePool = HandlePool<gpudnnHandle_t, gpuStream_t>;

// Borrows a DNN-library handle for `stream`.
//
// A previously returned handle for the same stream is reused when one is
// pooled; otherwise a fresh handle is created. If `stream` is non-null, the
// handle is bound to it before being handed out. Any failure from handle
// creation or stream binding is returned as an error status. The pool mutex
// is held for the whole operation.
template <>
/*static*/ absl::StatusOr<DnnHandlePool::Handle> DnnHandlePool::Borrow(
    gpuStream_t stream) {
  DnnHandlePool* const pool = Instance();
  absl::MutexLock lock(&pool->mu_);
  auto& idle_handles = pool->handles_[stream];
  gpudnnHandle_t handle;
  if (!idle_handles.empty()) {
    // Reuse the most recently returned handle for this stream.
    handle = idle_handles.back();
    idle_handles.pop_back();
  } else {
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreate(&handle)));
  }
  if (stream) {
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetStream(handle, stream)));
  }
  return Handle(pool, handle, stream);
}
namespace JAX_GPU_NAMESPACE {
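// For reference: RnnDescriptor is defined outside this file (presumably in
// jaxlib/gpu/rnn_kernels.h). It is the payload that UnpackDescriptor decodes
// from the custom-call `opaque` bytes, and it carries these fields:
// struct RnnDescriptor {
//   int input_size;
//   int hidden_size;
//   int num_layers;
//   int batch_size;
//   int max_seq_length;
//   float dropout;
//   bool bidirectional;
//   int workspace_size;
//   int reserve_space_size;
// };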
// Asks cuDNN for the workspace and reserve-space sizes (in bytes) needed to
// run a training-mode LSTM with the given configuration.
//
// Both sizes are rounded up to a multiple of 4, so callers can describe the
// buffers as f32 arrays. Returns an error status in three cases: borrowing a
// handle fails, any cuDNN call fails, or a rounded size does not fit in an
// `int`.
//
// Every cuDNN descriptor created here is destroyed on all return paths.
static absl::StatusOr<std::pair<int, int>>
DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
                                       int num_layers, int batch_size,
                                       int max_seq_length, float dropout,
                                       bool bidirectional) {
  // RAII owners. If a cuDNN call fails, the early return destroys every
  // descriptor created so far. On success the owners are released and the
  // descriptors are destroyed explicitly, so destroy failures are still
  // reported.
  using RnnDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnRNNDescriptor_t>,
                      decltype(&cudnnDestroyRNNDescriptor)>;
  using DropoutDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnDropoutDescriptor_t>,
                      decltype(&cudnnDestroyDropoutDescriptor)>;
  using RnnDataDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnRNNDataDescriptor_t>,
                      decltype(&cudnnDestroyRNNDataDescriptor)>;

  auto h = DnnHandlePool::Borrow();
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  cudnnRNNDescriptor_t rnn_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));
  RnnDescOwner rnn_desc_owner(rnn_desc, &cudnnDestroyRNNDescriptor);

  cudnnDropoutDescriptor_t dropout_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
  DropoutDescOwner dropout_desc_owner(dropout_desc,
                                      &cudnnDestroyDropoutDescriptor);
  size_t state_size;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
  // Fixed RNG seed (123). NOTE(review): no RNG state buffer is supplied
  // (states == nullptr); confirm this is valid when dropout > 0.
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
      dropout_desc, handle.get(), dropout, /*states=*/nullptr, state_size,
      /*seed=*/123)));

  // TODO(zhangqiaorjc): Handle other kinds of RNN.
  const cudnnRNNMode_t cell_mode = CUDNN_LSTM;
  const cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
  const cudnnDirectionMode_t dir_mode =
      bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
  const cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
  const cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
  const cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
  const cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
  // projSize == hiddenSize: no LSTM recurrent projection (per cuDNN docs).
  const int32_t proj_size = hidden_size;
  const uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
      rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
      input_mode, data_type, math_prec, math_type, input_size, hidden_size,
      proj_size, num_layers, dropout_desc, aux_flags)));

  // Sizes are queried for training mode, which also needs a reserve space.
  const cudnnForwardMode_t fwd_mode = CUDNN_FWD_MODE_TRAINING;
  const cudnnRNNDataLayout_t layout =
      CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
  float padding = 0.0f;
  // Every sequence is described as full length. `.data()` (unlike `&v[0]`)
  // is well-defined when batch_size == 0.
  std::vector<int32_t> seq_length_vector(batch_size, max_seq_length);
  cudnnRNNDataDescriptor_t input_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
  RnnDataDescOwner input_data_desc_owner(input_data_desc,
                                         &cudnnDestroyRNNDataDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      input_data_desc, data_type, layout, max_seq_length, batch_size,
      input_size, seq_length_vector.data(), &padding)));

  size_t workspace_size;
  size_t reserve_space_size;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnGetRNNTempSpaceSizes(
      handle.get(), rnn_desc, fwd_mode, input_data_desc, &workspace_size,
      &reserve_space_size)));

  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnDestroyDropoutDescriptor(dropout_desc_owner.release())));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnDestroyRNNDataDescriptor(input_data_desc_owner.release())));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc_owner.release())));

  // Round up to the nearest multiple of 4 so the sizes can be returned as
  // f32 arrays.
  auto round_up_to_multiple_of_4 = [](size_t bytes) {
    return (bytes + 3) / 4 * 4;
  };
  workspace_size = round_up_to_multiple_of_4(workspace_size);
  reserve_space_size = round_up_to_multiple_of_4(reserve_space_size);

  // The sizes are returned as `int`. Silently truncating them would make
  // callers allocate buffers smaller than cuDNN expects, so fail instead.
  constexpr size_t kMaxIntSize =
      static_cast<size_t>(std::numeric_limits<int>::max());
  if (workspace_size > kMaxIntSize || reserve_space_size > kMaxIntSize) {
    return absl::InvalidArgumentError(
        "cuDNN RNN workspace or reserve space size does not fit in an int");
  }
  return std::make_pair(static_cast<int>(workspace_size),
                        static_cast<int>(reserve_space_size));
}
// Public entry point for the workspace/reserve-space size query. It forwards
// unchanged to the file-local DoRnnComputeWorkspaceReserveSpaceSizes and
// returns its (workspace, reserve space) pair or error status.
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
    int input_size, int hidden_size, int num_layers, int batch_size,
    int max_seq_length, float dropout, bool bidirectional) {
  absl::StatusOr<std::pair<int, int>> sizes =
      DoRnnComputeWorkspaceReserveSpaceSizes(input_size, hidden_size,
                                             num_layers, batch_size,
                                             max_seq_length, dropout,
                                             bidirectional);
  return sizes;
}
// XLA custom-call body for the cuDNN LSTM forward pass in training mode.
// `opaque`/`opaque_len` hold a serialized RnnDescriptor `d`. The borrowed
// cuDNN handle is bound to `stream`.
//
// Buffers read by cuDNN:
//   0 x, 1 h_0, 2 c_0, 3 packed weights,
//   4 int32 per-batch sequence lengths (cuDNN devSeqLengths).
// Buffers written by cuDNN:
//   5 y, 6 h_n, 7 c_n, 8 workspace, 9 reserve space.
// h_0/c_0/h_n/c_n are float tensors of shape
// (num_directions * num_layers, batch_size, hidden_size). x and y are
// batch-major and padded to d.max_seq_length.
//
// Every cuDNN descriptor created here is destroyed on all return paths,
// including early returns on cuDNN errors.
static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
                                   const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<RnnDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const RnnDescriptor& d = **s;
  auto h = DnnHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  // RAII owners. If a cuDNN call fails, the early return destroys every
  // descriptor created so far. On success the owners are released and the
  // descriptors are destroyed explicitly, so destroy failures are still
  // reported.
  using RnnDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnRNNDescriptor_t>,
                      decltype(&cudnnDestroyRNNDescriptor)>;
  using DropoutDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnDropoutDescriptor_t>,
                      decltype(&cudnnDestroyDropoutDescriptor)>;
  using RnnDataDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnRNNDataDescriptor_t>,
                      decltype(&cudnnDestroyRNNDataDescriptor)>;
  using TensorDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnTensorDescriptor_t>,
                      decltype(&cudnnDestroyTensorDescriptor)>;

  cudnnRNNDescriptor_t rnn_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));
  RnnDescOwner rnn_desc_owner(rnn_desc, &cudnnDestroyRNNDescriptor);

  cudnnDropoutDescriptor_t dropout_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
  DropoutDescOwner dropout_desc_owner(dropout_desc,
                                      &cudnnDestroyDropoutDescriptor);
  size_t state_size;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
  // Fixed RNG seed (123). NOTE(review): no RNG state buffer is supplied
  // (states == nullptr); confirm this is valid when d.dropout > 0.
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
      dropout_desc, handle.get(), d.dropout, /*states=*/nullptr, state_size,
      /*seed=*/123)));

  // TODO(zhangqiaorjc): Handle other kinds of RNN.
  const cudnnRNNMode_t cell_mode = CUDNN_LSTM;
  const cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
  const int num_directions = d.bidirectional ? 2 : 1;
  const cudnnDirectionMode_t dir_mode =
      d.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
  const cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
  const cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
  const cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
  const cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
  // projSize == hiddenSize: no LSTM recurrent projection (per cuDNN docs).
  const int32_t proj_size = d.hidden_size;
  const uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
      rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
      input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size,
      proj_size, d.num_layers, dropout_desc, aux_flags)));

  const cudnnForwardMode_t fwd_mode = CUDNN_FWD_MODE_TRAINING;

  // x and y descriptors. The host-side length array describes every sequence
  // as full length; the actual per-batch lengths come from the device buffer
  // passed to cudnnRNNForward. `.data()` (unlike `&v[0]`) is well-defined
  // when batch_size == 0.
  const cudnnRNNDataLayout_t layout =
      CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
  float padding = 0.0f;
  std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
  cudnnRNNDataDescriptor_t input_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
  RnnDataDescOwner input_data_desc_owner(input_data_desc,
                                         &cudnnDestroyRNNDataDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      input_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
      d.input_size, seq_length_vector.data(), &padding)));
  cudnnRNNDataDescriptor_t output_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&output_data_desc)));
  RnnDataDescOwner output_data_desc_owner(output_data_desc,
                                          &cudnnDestroyRNNDataDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      output_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
      d.hidden_size * num_directions, seq_length_vector.data(), &padding)));

  // Hidden/cell state tensors:
  // (num_directions * num_layers, batch_size, hidden_size), densely packed
  // in row-major order.
  const int dims[3] = {num_directions * d.num_layers, d.batch_size,
                       d.hidden_size};
  const int strides[3] = {dims[1] * dims[2], dims[2], 1};
  cudnnTensorDescriptor_t h_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&h_desc)));
  TensorDescOwner h_desc_owner(h_desc, &cudnnDestroyTensorDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides)));
  cudnnTensorDescriptor_t c_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&c_desc)));
  TensorDescOwner c_desc_owner(c_desc, &cudnnDestroyTensorDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides)));

  size_t weight_space_size;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size)));

  void* input_buf = buffers[0];
  void* h_0_buf = buffers[1];
  void* c_0_buf = buffers[2];
  void* weights_buf = buffers[3];
  const auto* seq_lengths = static_cast<const int32_t*>(buffers[4]);
  void* output_buf = buffers[5];
  void* h_n_buf = buffers[6];
  void* c_n_buf = buffers[7];
  void* workspace_buf = buffers[8];
  void* reserve_space_buf = buffers[9];

  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNForward(
      handle.get(), rnn_desc, fwd_mode, /*devSeqLengths=*/seq_lengths,
      /*xDesc=*/input_data_desc, /*x=*/input_buf,
      /*yDesc=*/output_data_desc, /*y=*/output_buf,
      /*hDesc=*/h_desc, /*hx=*/h_0_buf, /*hy=*/h_n_buf,
      /*cDesc=*/c_desc, /*cx=*/c_0_buf, /*cy=*/c_n_buf,
      weight_space_size, /*weightSpace=*/weights_buf,
      d.workspace_size, /*workSpace=*/workspace_buf,
      d.reserve_space_size, /*reserveSpace=*/reserve_space_buf)));

  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyTensorDescriptor(h_desc_owner.release())));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyTensorDescriptor(c_desc_owner.release())));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnDestroyDropoutDescriptor(dropout_desc_owner.release())));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnDestroyRNNDataDescriptor(input_data_desc_owner.release())));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnDestroyRNNDataDescriptor(output_data_desc_owner.release())));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc_owner.release())));
  return absl::OkStatus();
}
// XLA custom-call body for the cuDNN LSTM backward pass: data gradients
// first, then weight gradients. `opaque`/`opaque_len` hold a serialized
// RnnDescriptor `d`. The borrowed cuDNN handle is bound to `stream`.
//
// Buffer roles:
//   0 dy, 1 dh_n, 2 dc_n          incoming gradients
//   3 x, 4 h_0, 5 c_0, 6 w, 7 y   forward-pass values
//   8 workspace, 9 reserve space  presumably the buffers written by
//                                 DnnRNNForward_ for the same configuration
//   10 weight-gradient accumulator (cuDNN adds into it, see below)
//   11 int32 per-batch sequence lengths (cuDNN devSeqLengths)
//   12 dx, 13 dh_0, 14 dc_0       computed gradients
// NOTE(review): buffers[15] (dw) is not referenced here. It is presumably
// aliased with buffers[10]; confirm against the custom-call lowering.
//
// Every cuDNN descriptor created here is destroyed on all return paths,
// including early returns on cuDNN errors.
static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
                                    const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<RnnDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const RnnDescriptor& d = **s;
  auto h = DnnHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  // RAII owners. If a cuDNN call fails, the early return destroys every
  // descriptor created so far. On success the owners are released and the
  // descriptors are destroyed explicitly, so destroy failures are still
  // reported.
  using RnnDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnRNNDescriptor_t>,
                      decltype(&cudnnDestroyRNNDescriptor)>;
  using DropoutDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnDropoutDescriptor_t>,
                      decltype(&cudnnDestroyDropoutDescriptor)>;
  using RnnDataDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnRNNDataDescriptor_t>,
                      decltype(&cudnnDestroyRNNDataDescriptor)>;
  using TensorDescOwner =
      std::unique_ptr<std::remove_pointer_t<cudnnTensorDescriptor_t>,
                      decltype(&cudnnDestroyTensorDescriptor)>;

  cudnnRNNDescriptor_t rnn_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));
  RnnDescOwner rnn_desc_owner(rnn_desc, &cudnnDestroyRNNDescriptor);

  cudnnDropoutDescriptor_t dropout_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
  DropoutDescOwner dropout_desc_owner(dropout_desc,
                                      &cudnnDestroyDropoutDescriptor);
  size_t state_size;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
  // Fixed RNG seed (123). NOTE(review): no RNG state buffer is supplied
  // (states == nullptr); confirm this is valid when d.dropout > 0.
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
      dropout_desc, handle.get(), d.dropout, /*states=*/nullptr, state_size,
      /*seed=*/123)));

  // TODO(zhangqiaorjc): Handle other kinds of RNN.
  const cudnnRNNMode_t cell_mode = CUDNN_LSTM;
  const cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
  const int num_directions = d.bidirectional ? 2 : 1;
  const cudnnDirectionMode_t dir_mode =
      d.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
  const cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
  const cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
  const cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
  const cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
  // projSize == hiddenSize: no LSTM recurrent projection (per cuDNN docs).
  const int32_t proj_size = d.hidden_size;
  const uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
      rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
      input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size,
      proj_size, d.num_layers, dropout_desc, aux_flags)));

  // x and y descriptors. The host-side length array describes every sequence
  // as full length; the actual per-batch lengths come from the device buffer
  // passed to the backward calls. `.data()` (unlike `&v[0]`) is well-defined
  // when batch_size == 0.
  const cudnnRNNDataLayout_t layout =
      CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
  float padding = 0.0f;
  std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
  cudnnRNNDataDescriptor_t input_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
  RnnDataDescOwner input_data_desc_owner(input_data_desc,
                                         &cudnnDestroyRNNDataDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      input_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
      d.input_size, seq_length_vector.data(), &padding)));
  cudnnRNNDataDescriptor_t output_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&output_data_desc)));
  RnnDataDescOwner output_data_desc_owner(output_data_desc,
                                          &cudnnDestroyRNNDataDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      output_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
      d.hidden_size * num_directions, seq_length_vector.data(), &padding)));

  // Hidden/cell state tensors:
  // (num_directions * num_layers, batch_size, hidden_size), densely packed
  // in row-major order.
  const int dims[3] = {num_directions * d.num_layers, d.batch_size,
                       d.hidden_size};
  const int strides[3] = {dims[1] * dims[2], dims[2], 1};
  cudnnTensorDescriptor_t h_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&h_desc)));
  TensorDescOwner h_desc_owner(h_desc, &cudnnDestroyTensorDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides)));
  cudnnTensorDescriptor_t c_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&c_desc)));
  TensorDescOwner c_desc_owner(c_desc, &cudnnDestroyTensorDescriptor);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides)));

  size_t weight_space_size;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size)));

  void* dy_buf = buffers[0];
  void* dh_n_buf = buffers[1];
  void* dc_n_buf = buffers[2];
  void* x_buf = buffers[3];
  void* h_0_buf = buffers[4];
  void* c_0_buf = buffers[5];
  void* w_buf = buffers[6];
  void* y_buf = buffers[7];
  void* workspace_buf = buffers[8];
  void* reserve_space_buf = buffers[9];
  void* zeroed_dw_buf = buffers[10];
  const auto* seq_lengths = static_cast<const int32_t*>(buffers[11]);
  void* dx_buf = buffers[12];
  void* dh_0_buf = buffers[13];
  void* dc_0_buf = buffers[14];

  // Data gradients are computed before weight gradients. Keep this order:
  // the cuDNN API reference describes cudnnRNNBackwardWeights_v8 as
  // following cudnnRNNBackwardData_v8.
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNBackwardData_v8(
      handle.get(), rnn_desc, /*devSeqLengths=*/seq_lengths,
      /*yDesc=*/output_data_desc, /*y=*/y_buf, /*dy=*/dy_buf,
      /*xDesc=*/input_data_desc, /*dx=*/dx_buf,
      /*hDesc=*/h_desc, /*hx=*/h_0_buf, /*dhy=*/dh_n_buf, /*dhx=*/dh_0_buf,
      /*cDesc=*/c_desc, /*cx=*/c_0_buf, /*dcy=*/dc_n_buf, /*dcx=*/dc_0_buf,
      weight_space_size, /*weightSpace=*/w_buf,
      d.workspace_size, /*workSpace=*/workspace_buf,
      d.reserve_space_size, /*reserveSpace=*/reserve_space_buf)));
  // CUDNN_WGRAD_MODE_ADD accumulates into zeroed_dw_buf instead of
  // overwriting it, so that buffer must arrive zero-filled.
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNBackwardWeights_v8(
      handle.get(), rnn_desc, CUDNN_WGRAD_MODE_ADD,
      /*devSeqLengths=*/seq_lengths, /*xDesc=*/input_data_desc, /*x=*/x_buf,
      /*hDesc=*/h_desc, /*hx=*/h_0_buf,
      /*yDesc=*/output_data_desc, /*y=*/y_buf,
      weight_space_size, /*dweightSpace=*/zeroed_dw_buf,
      d.workspace_size, /*workSpace=*/workspace_buf,
      d.reserve_space_size, /*reserveSpace=*/reserve_space_buf)));

  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyTensorDescriptor(h_desc_owner.release())));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyTensorDescriptor(c_desc_owner.release())));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnDestroyDropoutDescriptor(dropout_desc_owner.release())));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnDestroyRNNDataDescriptor(input_data_desc_owner.release())));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnDestroyRNNDataDescriptor(output_data_desc_owner.release())));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc_owner.release())));
  return absl::OkStatus();
}
// XLA custom-call entry point for the LSTM forward pass. It runs
// DnnRNNForward_ and, on failure, reports the status message to XLA through
// `status`.
void RNNForward(gpuStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len, XlaCustomCallStatus* status) {
  const absl::Status result =
      DnnRNNForward_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  const std::string message(result.message());
  XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
}
// XLA custom-call entry point for the LSTM backward pass. It runs
// DnnRNNBackward_ and, on failure, reports the status message to XLA through
// `status`.
void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
                 size_t opaque_len, XlaCustomCallStatus* status) {
  const absl::Status result =
      DnnRNNBackward_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  const std::string message(result.message());
  XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax