From c6e83903de5a7d02d8c2d7fe145338441501f374 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 6 Feb 2025 18:27:10 -0800 Subject: [PATCH] Update RNN kernels to use FFI. PiperOrigin-RevId: 724151647 --- jaxlib/cuda/BUILD | 2 + jaxlib/gpu/rnn.cc | 8 +- jaxlib/gpu/rnn_kernels.cc | 145 ++++++++++++++++++--------------- jaxlib/gpu/rnn_kernels.h | 19 +++-- jaxlib/gpu_rnn.py | 20 +++-- jaxlib/rocm/BUILD | 2 + tests/experimental_rnn_test.py | 4 +- 7 files changed, 114 insertions(+), 86 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 8bfc25d80..395144ba1 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -154,12 +154,14 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + ":ffi_wrapper", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index 9f6040198..c88b164e6 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -29,17 +29,19 @@ namespace nb = nanobind; nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32, - int workspace_size, int reserve_space_size) { + int workspace_size, int reserve_space_size) { return PackDescriptor(RnnDescriptor{ input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout, - bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size - }); + bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size}); } nb::dict Registrations() { nb::dict dict; dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward); dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward); + dict[JAX_GPU_PREFIX 
"dnn_rnn_ffi"] = EncapsulateFfiHandler(RNNForwardFfi); + dict[JAX_GPU_PREFIX "dnn_rnn_bwd_ffi"] = + EncapsulateFfiHandler(RNNBackwardFfi); return dict; } diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index a1e4cc377..89a6d0a30 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" @@ -75,7 +76,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional, - bool cudnn_allow_tf32) { + bool cudnn_allow_tf32) { auto h = DnnHandlePool::Borrow(/*stream=*/nullptr); JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; @@ -92,18 +93,21 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, #ifdef JAX_GPU_HIP void* dropout_states_dev = nullptr; - // Allocate minimal memory for dropout states (can be very small since it's not used) - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size))); + // Allocate minimal memory for dropout states (can be very small since it's + // not used) + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size))); if (!dropout_states_dev) { - return absl::InternalError("Failed to allocate minimal GPU memory for dropout states."); + return absl::InternalError( + "Failed to allocate minimal GPU memory for dropout states."); } JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( - dropout_desc, handle.get(), dropout, dropout_states_dev, state_size, 123, false, false, - MIOPEN_RNG_PSEUDO_XORWOW))); -#else // JAX_GPU_CUDA + dropout_desc, handle.get(), dropout, dropout_states_dev, state_size, 123, + false, false, MIOPEN_RNG_PSEUDO_XORWOW))); +#else // JAX_GPU_CUDA 
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( dropout_desc, handle.get(), dropout, nullptr, state_size, 123))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP // TODO(zhangqiaorjc): Handle other kinds of RNN. gpudnnRNNMode_t cell_mode = GPUDNN_LSTM; @@ -121,16 +125,17 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( rnn_desc, hidden_size, num_layers, dropout_desc, input_mode, dir_mode, cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type))); -#else // JAX_GPU_CUDA +#else // JAX_GPU_CUDA gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT; - gpudnnMathType_t math_type = cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH; + gpudnnMathType_t math_type = + cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH; int32_t proj_size = hidden_size; uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED; JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( rnn_desc, GPUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, input_mode, data_type, math_prec, math_type, input_size, hidden_size, proj_size, num_layers, dropout_desc, aux_flags))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP gpudnnForwardMode_t fwdMode = GPUDNN_FWD_MODE_TRAINING; gpudnnRNNDataLayout_t layout = GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; @@ -149,14 +154,14 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, size_t workSpaceSize; size_t reserveSpaceSize; #ifdef JAX_GPU_HIP - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnGetRNNTempSpaceSizes( - handle.get(), rnn_desc, input_data_desc, fwdMode, &workSpaceSize, - &reserveSpaceSize))); -#else // JAX_GPU_CUDA + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + gpudnnGetRNNTempSpaceSizes(handle.get(), rnn_desc, input_data_desc, + fwdMode, &workSpaceSize, &reserveSpaceSize))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnGetRNNTempSpaceSizes( handle.get(), rnn_desc, fwdMode, input_data_desc, &workSpaceSize, &reserveSpaceSize))); 
-#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); JAX_RETURN_IF_ERROR( @@ -199,18 +204,21 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, #ifdef JAX_GPU_HIP void* dropout_states_dev = nullptr; - // Allocate minimal memory for dropout states (can be very small since it's not used). - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size))); + // Allocate minimal memory for dropout states (can be very small since it's + // not used). + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size))); if (!dropout_states_dev) { - return absl::InternalError("Failed to allocate minimal GPU memory for dropout states."); + return absl::InternalError( + "Failed to allocate minimal GPU memory for dropout states."); } JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( - dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, 123, false, false, - MIOPEN_RNG_PSEUDO_XORWOW))); -#else // JAX_GPU_CUDA + dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, + 123, false, false, MIOPEN_RNG_PSEUDO_XORWOW))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP // TODO(zhangqiaorjc): Handle other kinds of RNN. gpudnnRNNMode_t cell_mode = GPUDNN_LSTM; @@ -228,16 +236,17 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( rnn_desc, d.hidden_size, d.num_layers, dropout_desc, input_mode, dir_mode, cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type))); -#else // JAX_GPU_CUDA +#else // JAX_GPU_CUDA gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT; - gpudnnMathType_t math_type = d.cudnn_allow_tf32? 
GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH; + gpudnnMathType_t math_type = + d.cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH; int32_t proj_size = d.hidden_size; uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED; JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( rnn_desc, GPUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size, proj_size, d.num_layers, dropout_desc, aux_flags))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP gpudnnForwardMode_t fwdMode = GPUDNN_FWD_MODE_TRAINING; gpudnnRNNDataLayout_t layout = GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; @@ -288,19 +297,20 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, size_t weight_space_size; #ifdef JAX_GPU_HIP -miopenTensorDescriptor_t input_tensor_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc))); + miopenTensorDescriptor_t input_tensor_desc; + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc))); int dimsA[2] = {d.batch_size, d.input_size}; - int stridesA[2] = {dimsA[1], 1}; // Row-major order, similar to GPUDNN + int stridesA[2] = {dimsA[1], 1}; // Row-major order, similar to GPUDNN JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenSetTensorDescriptor( - input_tensor_desc, data_type, 2, dimsA, stridesA))); + input_tensor_desc, data_type, 2, dimsA, stridesA))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( gpudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, input_tensor_desc, - &weight_space_size, data_type))); -#else // JAX_GPU_CUDA + &weight_space_size, data_type))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS( gpudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP auto input_buf = buffers[0]; auto h_0_buf = buffers[1]; @@ -314,18 +324,18 @@ miopenTensorDescriptor_t input_tensor_desc; #ifdef JAX_GPU_HIP JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNForward( - 
handle.get(), rnn_desc, fwdMode, input_data_desc, input_buf, - h_desc, h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf, - output_data_desc, output_buf, weights_buf, weight_space_size, - workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size))); -#else // JAX_GPU_CUDA + handle.get(), rnn_desc, fwdMode, input_data_desc, input_buf, h_desc, + h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf, output_data_desc, output_buf, + weights_buf, weight_space_size, workspace_buf, d.workspace_size, + reserve_space_buf, d.reserve_space_size))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNForward( handle.get(), rnn_desc, fwdMode, (const int32_t*)seq_lengths_buf, input_data_desc, input_buf, output_data_desc, output_buf, h_desc, h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf, weight_space_size, weights_buf, d.workspace_size, workspace_buf, d.reserve_space_size, reserve_space_buf))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(h_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(c_desc))); @@ -361,18 +371,21 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, #ifdef JAX_GPU_HIP void* dropout_states_dev = nullptr; - // Allocate minimal memory for dropout states (can be very small since it's not used) - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size))); + // Allocate minimal memory for dropout states (can be very small since it's + // not used) + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size))); if (!dropout_states_dev) { - return absl::InternalError("Failed to allocate minimal GPU memory for dropout states."); + return absl::InternalError( + "Failed to allocate minimal GPU memory for dropout states."); } JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( - dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, 123, false, false, - MIOPEN_RNG_PSEUDO_XORWOW))); -#else 
// JAX_GPU_CUDA + dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, + 123, false, false, MIOPEN_RNG_PSEUDO_XORWOW))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP // TODO(zhangqiaorjc): Handle other kinds of RNN. gpudnnRNNMode_t cell_mode = GPUDNN_LSTM; @@ -390,16 +403,17 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( rnn_desc, d.hidden_size, d.num_layers, dropout_desc, input_mode, dir_mode, cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type))); -#else // JAX_GPU_CUDA +#else // JAX_GPU_CUDA gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT; - gpudnnMathType_t math_type = d.cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH; + gpudnnMathType_t math_type = + d.cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH; int32_t proj_size = d.hidden_size; uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED; JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( rnn_desc, GPUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size, proj_size, d.num_layers, dropout_desc, aux_flags))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP gpudnnRNNDataLayout_t layout = GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; float padding = 0.0f; @@ -449,18 +463,19 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, size_t weight_space_size; #ifdef JAX_GPU_HIP miopenTensorDescriptor_t input_tensor_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc))); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc))); int input_dims[2] = {d.batch_size, d.input_size}; - int input_strides[2] = {input_dims[1], 1}; // row-major order + int input_strides[2] = {input_dims[1], 1}; 
// row-major order JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenSetTensorDescriptor( - input_tensor_desc, data_type, 2, input_dims, input_strides))); + input_tensor_desc, data_type, 2, input_dims, input_strides))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( gpudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, input_tensor_desc, - &weight_space_size, data_type))); -#else // JAX_GPU_CUDA + &weight_space_size, data_type))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS( gpudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP auto dy_buf = buffers[0]; auto dh_n_buf = buffers[1]; @@ -482,17 +497,16 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, #ifdef JAX_GPU_HIP JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardData( - handle.get(), rnn_desc, output_data_desc, y_buf, dy_buf, - h_desc, h_0_buf, dh_n_buf, dh_0_buf, - c_desc, c_0_buf, dc_n_buf, dc_0_buf, - input_data_desc, dx_buf, w_buf, weight_space_size, - workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size))); + handle.get(), rnn_desc, output_data_desc, y_buf, dy_buf, h_desc, h_0_buf, + dh_n_buf, dh_0_buf, c_desc, c_0_buf, dc_n_buf, dc_0_buf, input_data_desc, + dx_buf, w_buf, weight_space_size, workspace_buf, d.workspace_size, + reserve_space_buf, d.reserve_space_size))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardWeights( handle.get(), rnn_desc, input_data_desc, x_buf, h_desc, h_0_buf, - output_data_desc, y_buf, zeroed_dw_buf, weight_space_size, - workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size))); -#else // JAX_GPU_CUDA + output_data_desc, y_buf, zeroed_dw_buf, weight_space_size, workspace_buf, + d.workspace_size, reserve_space_buf, d.reserve_space_size))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardData( handle.get(), rnn_desc, (const int32_t*)seq_lengths_buf, output_data_desc, y_buf, dy_buf, input_data_desc, dx_buf, h_desc, h_0_buf, dh_n_buf, 
@@ -506,7 +520,7 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, output_data_desc, y_buf, weight_space_size, zeroed_dw_buf, d.workspace_size, workspace_buf, d.reserve_space_size, reserve_space_buf))); -#endif // JAX_GPU_HIP +#endif // JAX_GPU_HIP JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(h_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(c_desc))); @@ -539,5 +553,8 @@ void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque, } } +JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNForwardFfi, DnnRNNForward_); +JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNBackwardFfi, DnnRNNBackward_); + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index 5cc332e37..468c02eac 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" namespace jax { @@ -38,11 +39,10 @@ struct RnnDescriptor { }; // Return (workspace size, reserve space size). 
-absl::StatusOr<std::pair<int, int>> -RnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, - int num_layers, int batch_size, - int max_seq_length, float dropout, - bool bidirectional, bool cudnn_allow_tf32); +absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes( + int input_size, int hidden_size, int num_layers, int batch_size, + int max_seq_length, float dropout, bool bidirectional, + bool cudnn_allow_tf32); void RNNForward(gpuStream_t stream, void **buffers, const char *opaque, size_t opaque_len, XlaCustomCallStatus *status); @@ -50,7 +50,10 @@ void RNNForward(gpuStream_t stream, void **buffers, const char *opaque, void RNNBackward(gpuStream_t stream, void **buffers, const char *opaque, size_t opaque_len, XlaCustomCallStatus *status); -} // namespace JAX_GPU_NAMESPACE -} // namespace jax +XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNForwardFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNBackwardFfi); -#endif // JAXLIB_GPU_RNN_KERNELS_H_ +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_RNN_KERNELS_H_ diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index 080a7f8bb..d48f3f97e 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -33,7 +33,9 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: if _cuda_rnn: for _name, _value in _cuda_rnn.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform='CUDA') + api_version = 1 if _name.endswith("_ffi") else 0 + xla_client.register_custom_call_target(_name, _value, platform='CUDA', + api_version=api_version) compute_rnn_workspace_reserve_space_sizes = _cuda_rnn.compute_rnn_workspace_reserve_space_sizes @@ -47,7 +49,9 @@ for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: if _hip_rnn: for _name, _value in _hip_rnn.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform='ROCM') + api_version = 1 if _name.endswith("_ffi") else 0 + xla_client.register_custom_call_target(_name, _value, platform='ROCM', + api_version=api_version) 
compute_rnn_workspace_reserve_space_sizes = _hip_rnn.compute_rnn_workspace_reserve_space_sizes @@ -93,10 +97,10 @@ def _rnn_fwd_lowering(_rnn, platform, ctx, input, h_0, c_0, weights, seq_lengths out = hlo.CustomCallOp( [output_type, h_0.type, c_0.type, workspace_type, reserve_space_type], [input, h_0, c_0, weights, seq_lengths], - call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn"), + call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_ffi"), has_side_effect=ir.BoolAttr.get(False), - backend_config=ir.StringAttr.get(opaque), - api_version=ir.IntegerAttr.get(i32_type, 2), + backend_config=ir.DictAttr.get({"opaque": ir.StringAttr.get(opaque)}), + api_version=ir.IntegerAttr.get(i32_type, 4), called_computations=ir.ArrayAttr.get([]), ) return out.results[:-2] + out.results[-1:] # drop workspace output @@ -140,10 +144,10 @@ def _rnn_bwd_lowering(_rnn, platform, ctx, dy, dhn, dcn, x, h0, c0, w, y, dy, dhn, dcn, x, h0, c0, w, y, reserve_space, zeroed_dw, seq_lengths ], - call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_bwd"), + call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_bwd_ffi"), has_side_effect=ir.BoolAttr.get(False), - backend_config=ir.StringAttr.get(opaque), - api_version=ir.IntegerAttr.get(i32_type, 2), + backend_config=ir.DictAttr.get({"opaque": ir.StringAttr.get(opaque)}), + api_version=ir.IntegerAttr.get(i32_type, 4), called_computations=ir.ArrayAttr.get([]), output_operand_aliases=ir.ArrayAttr.get([ hlo.OutputOperandAlias.get( diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 9a0e58f20..91a8fb678 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -140,6 +140,7 @@ cc_library( srcs = ["//jaxlib/gpu:rnn_kernels.cc"], hdrs = ["//jaxlib/gpu:rnn_kernels.h"], deps = [ + ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", "//jaxlib:handle_pool", @@ -149,6 +150,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:miopen", "@local_config_rocm//rocm:rocm_headers", + "@xla//xla/ffi/api:ffi", 
"@xla//xla/service:custom_call_status", ], ) diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 103687db3..376a9b1a1 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -213,9 +213,7 @@ class RnnTest(jtu.JaxTestCase): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - self.assertIn('stablehlo.custom_call @cudnn_rnn(%0, %1, %2, %6, %5) ' - '{api_version = 2 : i32, backend_config = ' - '"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"}', + self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"', stablehlo) if __name__ == '__main__':