Update RNN kernels to use FFI.

PiperOrigin-RevId: 724151647
Authored by Dan Foreman-Mackey on 2025-02-06 18:27:10 -08:00; committed by jax authors.
parent cce3df1071
commit c6e83903de
7 changed files with 114 additions and 86 deletions

View File

@ -154,12 +154,14 @@ cc_library(
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":ffi_wrapper",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cudnn",

View File

@ -32,14 +32,16 @@ nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
int workspace_size, int reserve_space_size) {
return PackDescriptor(RnnDescriptor{
input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size
});
bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size});
}
// Returns the table of custom-call registrations exposed to Python: a dict
// mapping target names to encapsulated handler capsules.
nb::dict Registrations() {
nb::dict dict;
// Legacy custom-call entry points (opaque-descriptor API, api_version 0).
dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward);
dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward);
// New FFI-based entry points, registered alongside the legacy ones under
// distinct "_ffi"-suffixed names so callers can pick the API explicitly.
dict[JAX_GPU_PREFIX "dnn_rnn_ffi"] = EncapsulateFfiHandler(RNNForwardFfi);
dict[JAX_GPU_PREFIX "dnn_rnn_bwd_ffi"] =
EncapsulateFfiHandler(RNNBackwardFfi);
return dict;
}

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/ffi_wrapper.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
@ -92,14 +93,17 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
#ifdef JAX_GPU_HIP
void* dropout_states_dev = nullptr;
// Allocate minimal memory for dropout states (can be very small since it's not used)
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
// Allocate minimal memory for dropout states (can be very small since it's
// not used)
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
if (!dropout_states_dev) {
return absl::InternalError("Failed to allocate minimal GPU memory for dropout states.");
return absl::InternalError(
"Failed to allocate minimal GPU memory for dropout states.");
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
dropout_desc, handle.get(), dropout, dropout_states_dev, state_size, 123, false, false,
MIOPEN_RNG_PSEUDO_XORWOW)));
dropout_desc, handle.get(), dropout, dropout_states_dev, state_size, 123,
false, false, MIOPEN_RNG_PSEUDO_XORWOW)));
#else // JAX_GPU_CUDA
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
dropout_desc, handle.get(), dropout, nullptr, state_size, 123)));
@ -123,7 +127,8 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type)));
#else // JAX_GPU_CUDA
gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT;
gpudnnMathType_t math_type = cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH;
gpudnnMathType_t math_type =
cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH;
int32_t proj_size = hidden_size;
uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor(
@ -149,9 +154,9 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
size_t workSpaceSize;
size_t reserveSpaceSize;
#ifdef JAX_GPU_HIP
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnGetRNNTempSpaceSizes(
handle.get(), rnn_desc, input_data_desc, fwdMode, &workSpaceSize,
&reserveSpaceSize)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
gpudnnGetRNNTempSpaceSizes(handle.get(), rnn_desc, input_data_desc,
fwdMode, &workSpaceSize, &reserveSpaceSize)));
#else // JAX_GPU_CUDA
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnGetRNNTempSpaceSizes(
handle.get(), rnn_desc, fwdMode, input_data_desc, &workSpaceSize,
@ -199,14 +204,17 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
#ifdef JAX_GPU_HIP
void* dropout_states_dev = nullptr;
// Allocate minimal memory for dropout states (can be very small since it's not used).
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
// Allocate minimal memory for dropout states (can be very small since it's
// not used).
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
if (!dropout_states_dev) {
return absl::InternalError("Failed to allocate minimal GPU memory for dropout states.");
return absl::InternalError(
"Failed to allocate minimal GPU memory for dropout states.");
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, 123, false, false,
MIOPEN_RNG_PSEUDO_XORWOW)));
dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size,
123, false, false, MIOPEN_RNG_PSEUDO_XORWOW)));
#else // JAX_GPU_CUDA
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123)));
@ -230,7 +238,8 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type)));
#else // JAX_GPU_CUDA
gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT;
gpudnnMathType_t math_type = d.cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH;
gpudnnMathType_t math_type =
d.cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH;
int32_t proj_size = d.hidden_size;
uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor(
@ -288,8 +297,9 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
size_t weight_space_size;
#ifdef JAX_GPU_HIP
miopenTensorDescriptor_t input_tensor_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc)));
miopenTensorDescriptor_t input_tensor_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc)));
int dimsA[2] = {d.batch_size, d.input_size};
int stridesA[2] = {dimsA[1], 1}; // Row-major order, similar to GPUDNN
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenSetTensorDescriptor(
@ -314,10 +324,10 @@ miopenTensorDescriptor_t input_tensor_desc;
#ifdef JAX_GPU_HIP
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNForward(
handle.get(), rnn_desc, fwdMode, input_data_desc, input_buf,
h_desc, h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf,
output_data_desc, output_buf, weights_buf, weight_space_size,
workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size)));
handle.get(), rnn_desc, fwdMode, input_data_desc, input_buf, h_desc,
h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf, output_data_desc, output_buf,
weights_buf, weight_space_size, workspace_buf, d.workspace_size,
reserve_space_buf, d.reserve_space_size)));
#else // JAX_GPU_CUDA
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNForward(
handle.get(), rnn_desc, fwdMode, (const int32_t*)seq_lengths_buf,
@ -361,14 +371,17 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
#ifdef JAX_GPU_HIP
void* dropout_states_dev = nullptr;
// Allocate minimal memory for dropout states (can be very small since it's not used)
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
// Allocate minimal memory for dropout states (can be very small since it's
// not used)
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
if (!dropout_states_dev) {
return absl::InternalError("Failed to allocate minimal GPU memory for dropout states.");
return absl::InternalError(
"Failed to allocate minimal GPU memory for dropout states.");
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, 123, false, false,
MIOPEN_RNG_PSEUDO_XORWOW)));
dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size,
123, false, false, MIOPEN_RNG_PSEUDO_XORWOW)));
#else // JAX_GPU_CUDA
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123)));
@ -392,7 +405,8 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type)));
#else // JAX_GPU_CUDA
gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT;
gpudnnMathType_t math_type = d.cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH;
gpudnnMathType_t math_type =
d.cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH;
int32_t proj_size = d.hidden_size;
uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor(
@ -449,7 +463,8 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
size_t weight_space_size;
#ifdef JAX_GPU_HIP
miopenTensorDescriptor_t input_tensor_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc)));
int input_dims[2] = {d.batch_size, d.input_size};
int input_strides[2] = {input_dims[1], 1}; // row-major order
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenSetTensorDescriptor(
@ -482,16 +497,15 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
#ifdef JAX_GPU_HIP
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardData(
handle.get(), rnn_desc, output_data_desc, y_buf, dy_buf,
h_desc, h_0_buf, dh_n_buf, dh_0_buf,
c_desc, c_0_buf, dc_n_buf, dc_0_buf,
input_data_desc, dx_buf, w_buf, weight_space_size,
workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size)));
handle.get(), rnn_desc, output_data_desc, y_buf, dy_buf, h_desc, h_0_buf,
dh_n_buf, dh_0_buf, c_desc, c_0_buf, dc_n_buf, dc_0_buf, input_data_desc,
dx_buf, w_buf, weight_space_size, workspace_buf, d.workspace_size,
reserve_space_buf, d.reserve_space_size)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardWeights(
handle.get(), rnn_desc, input_data_desc, x_buf, h_desc, h_0_buf,
output_data_desc, y_buf, zeroed_dw_buf, weight_space_size,
workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size)));
output_data_desc, y_buf, zeroed_dw_buf, weight_space_size, workspace_buf,
d.workspace_size, reserve_space_buf, d.reserve_space_size)));
#else // JAX_GPU_CUDA
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardData(
handle.get(), rnn_desc, (const int32_t*)seq_lengths_buf, output_data_desc,
@ -539,5 +553,8 @@ void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
}
}
JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNForwardFfi, DnnRNNForward_);
JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNBackwardFfi, DnnRNNBackward_);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
#include "xla/service/custom_call_status.h"
namespace jax {
@ -38,11 +39,10 @@ struct RnnDescriptor {
};
// Return (workspace size, reserve space size).
absl::StatusOr<std::pair<int, int>>
RnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
int num_layers, int batch_size,
int max_seq_length, float dropout,
bool bidirectional, bool cudnn_allow_tf32);
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
int input_size, int hidden_size, int num_layers, int batch_size,
int max_seq_length, float dropout, bool bidirectional,
bool cudnn_allow_tf32);
void RNNForward(gpuStream_t stream, void **buffers, const char *opaque,
size_t opaque_len, XlaCustomCallStatus *status);
@ -50,6 +50,9 @@ void RNNForward(gpuStream_t stream, void **buffers, const char *opaque,
void RNNBackward(gpuStream_t stream, void **buffers, const char *opaque,
size_t opaque_len, XlaCustomCallStatus *status);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNForwardFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNBackwardFfi);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -33,7 +33,9 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
if _cuda_rnn:
for _name, _value in _cuda_rnn.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
api_version = 1 if _name.endswith("_ffi") else 0
xla_client.register_custom_call_target(_name, _value, platform='CUDA',
api_version=api_version)
compute_rnn_workspace_reserve_space_sizes = _cuda_rnn.compute_rnn_workspace_reserve_space_sizes
@ -47,7 +49,9 @@ for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
if _hip_rnn:
for _name, _value in _hip_rnn.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform='ROCM')
api_version = 1 if _name.endswith("_ffi") else 0
xla_client.register_custom_call_target(_name, _value, platform='ROCM',
api_version=api_version)
compute_rnn_workspace_reserve_space_sizes = _hip_rnn.compute_rnn_workspace_reserve_space_sizes
@ -93,10 +97,10 @@ def _rnn_fwd_lowering(_rnn, platform, ctx, input, h_0, c_0, weights, seq_lengths
out = hlo.CustomCallOp(
[output_type, h_0.type, c_0.type, workspace_type, reserve_space_type],
[input, h_0, c_0, weights, seq_lengths],
call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn"),
call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_ffi"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
backend_config=ir.DictAttr.get({"opaque": ir.StringAttr.get(opaque)}),
api_version=ir.IntegerAttr.get(i32_type, 4),
called_computations=ir.ArrayAttr.get([]),
)
return out.results[:-2] + out.results[-1:] # drop workspace output
@ -140,10 +144,10 @@ def _rnn_bwd_lowering(_rnn, platform, ctx, dy, dhn, dcn, x, h0, c0, w, y,
dy, dhn, dcn, x, h0, c0, w, y, reserve_space, zeroed_dw,
seq_lengths
],
call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_bwd"),
call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_bwd_ffi"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
backend_config=ir.DictAttr.get({"opaque": ir.StringAttr.get(opaque)}),
api_version=ir.IntegerAttr.get(i32_type, 4),
called_computations=ir.ArrayAttr.get([]),
output_operand_aliases=ir.ArrayAttr.get([
hlo.OutputOperandAlias.get(

View File

@ -140,6 +140,7 @@ cc_library(
srcs = ["//jaxlib/gpu:rnn_kernels.cc"],
hdrs = ["//jaxlib/gpu:rnn_kernels.h"],
deps = [
":ffi_wrapper",
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:handle_pool",
@ -149,6 +150,7 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@local_config_rocm//rocm:miopen",
"@local_config_rocm//rocm:rocm_headers",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
],
)

View File

@ -213,9 +213,7 @@ class RnnTest(jtu.JaxTestCase):
k = jax.random.split(jax.random.PRNGKey(1), 4)
stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
self.assertIn('stablehlo.custom_call @cudnn_rnn(%0, %1, %2, %6, %5) '
'{api_version = 2 : i32, backend_config = '
'"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"}',
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
stablehlo)
if __name__ == '__main__':