Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to jaxlib/ffi_helpers.h.

Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helpers header.

PiperOrigin-RevId: 651166306
Dan Foreman-Mackey, 2024-07-10 15:08:58 -07:00, committed by jax authors
parent d0de7970d4
commit 33a9db3943
7 changed files with 61 additions and 79 deletions
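
For context, here is a minimal sketch of how an FFI kernel can use the consolidated helpers after this change. The kernel, its buffer argument, and the DoSomething helper are illustrative only and are not part of this commit:

#include <cstdint>

#include "absl/status/status.h"
#include "jaxlib/ffi_helpers.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Hypothetical helper returning absl::Status; stands in for any fallible call.
absl::Status DoSomething(int32_t n);

ffi::Error ExampleKernel(ffi::Buffer<ffi::DataType::F32> x) {
  // FFI_ASSIGN_OR_RETURN unwraps an absl::StatusOr, converting a failed
  // status into an ffi::Error via jax::AsFfiError.
  FFI_ASSIGN_OR_RETURN(auto n, jax::MaybeCastNoOverflow<int32_t>(
                                   x.dimensions().back()));
  // FFI_RETURN_IF_ERROR_STATUS does the same for a plain absl::Status.
  FFI_RETURN_IF_ERROR_STATUS(DoSomething(n));
  return ffi::Error::Success();
}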


@@ -128,6 +128,9 @@ cc_library(
     hdrs = ["ffi_helpers.h"],
     features = ["-use_header_modules"],
     deps = [
+        "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",


@@ -35,14 +35,12 @@ cc_library(
     copts = ["-fexceptions"],
     features = ["-use_header_modules"],
     deps = [
+        "//jaxlib:ffi_helpers",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",
         "@xla//xla/service:custom_call_status",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:dynamic_annotations",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:str_format",
     ],
 )


@@ -20,8 +20,8 @@ limitations under the License.
 #include <complex>
 #include <cstddef>
 #include <cstdint>
+#include <cstring>
 #include <functional>
-#include <limits>
 #include <optional>
 #include <stdexcept>
 #include <string>
@@ -30,55 +30,21 @@ limitations under the License.
 #include "absl/algorithm/container.h"
 #include "absl/base/dynamic_annotations.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
+#include "jaxlib/ffi_helpers.h"
 #include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_status.h"
 
 static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
               "Expected LAPACK integers to be 32-bit");
 
 namespace ffi = xla::ffi;
 
-// TODO(danfm): These macros and the casting functions should be moved to a
-// separate header for use in other FFI kernels.
-#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs)                                \
-  if (!rhs.ok()) {                                                          \
-    return ffi::Error(static_cast<XLA_FFI_Error_Code>(rhs.status().code()), \
-                      std::string(rhs.status().message()));                 \
-  }                                                                         \
-  lhs = rhs.value()
-
-#define RETURN_IF_FFI_ERROR(...)    \
-  do {                              \
-    ffi::Error err = (__VA_ARGS__); \
-    if (err.failure()) {            \
-      return err;                   \
-    }                               \
-  } while (0)
-
 namespace {
 
-template <typename T>
-inline absl::StatusOr<T> MaybeCastNoOverflow(
-    int64_t value, const std::string& source = __FILE__) {
-  if constexpr (sizeof(T) == sizeof(int64_t)) {
-    return value;
-  } else {
-    if (value > std::numeric_limits<T>::max()) [[unlikely]] {
-      return absl::InvalidArgumentError(
-          absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
-                          "value of the desired type",
-                          source, value));
-    }
-    return static_cast<T>(value);
-  }
-}
-
 template <typename T>
 inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
-  auto result = MaybeCastNoOverflow<T>(value, source);
+  auto result = jax::MaybeCastNoOverflow<T>(value, source);
   if (!result.ok()) {
     throw std::overflow_error{std::string(result.status().message())};
   }
@@ -237,7 +203,7 @@ ffi::Error LuDecomposition<dtype>::Kernel(
     ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
     ffi::ResultBuffer<LapackIntDtype> ipiv,
     ffi::ResultBuffer<LapackIntDtype> info) {
-  RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
+  FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions()));
   auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
   auto* x_out_data = x_out->typed_data();
   auto* ipiv_data = ipiv->typed_data();
@@ -245,10 +211,8 @@ ffi::Error LuDecomposition<dtype>::Kernel(
 
   CopyIfDiffBuffer(x, x_out);
 
-  ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
-                             MaybeCastNoOverflow<lapack_int>(x_rows));
-  ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
-                             MaybeCastNoOverflow<lapack_int>(x_cols));
+  FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
+  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
   auto x_leading_dim_v = x_rows_v;
 
   const int64_t x_out_step{x_rows * x_cols};
@@ -410,7 +374,7 @@ template <ffi::DataType dtype>
 ffi::Error CholeskyFactorization<dtype>::Kernel(
     ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
     ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
-  RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
+  FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions()));
   auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
   auto* x_out_data = x_out->typed_data();
   auto* info_data = info->typed_data();
@@ -418,8 +382,8 @@ ffi::Error CholeskyFactorization<dtype>::Kernel(
   CopyIfDiffBuffer(x, x_out);
 
   auto uplo_v = static_cast<char>(uplo);
-  ASSIGN_OR_RETURN_FFI_ERROR(
-      auto x_order_v, MaybeCastNoOverflow<lapack_int>(x.dimensions().back()));
+  FFI_ASSIGN_OR_RETURN(auto x_order_v,
+                       MaybeCastNoOverflow<lapack_int>(x.dimensions().back()));
   auto x_leading_dim_v = x_order_v;
 
   const int64_t x_out_step{x_rows * x_cols};
@@ -626,21 +590,18 @@ static ffi::Error SvdKernel(
 
   CopyIfDiffBuffer(x, x_out);
 
-  ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
-                             MaybeCastNoOverflow<lapack_int>(x_rows));
-  ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
-                             MaybeCastNoOverflow<lapack_int>(x_cols));
+  FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
+  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
   auto mode_v = static_cast<char>(mode);
-  ASSIGN_OR_RETURN_FFI_ERROR(
-      auto workspace_dim_v,
-      MaybeCastNoOverflow<lapack_int>(work->dimensions().back()));
+  FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
+                                                 work->dimensions().back()));
   auto x_leading_dim_v = x_rows_v;
   auto u_leading_dim_v = x_rows_v;
 
   auto u_dims = u->dimensions().last(2);
   auto vt_dims = vt->dimensions().last(2);
-  ASSIGN_OR_RETURN_FFI_ERROR(auto vt_leading_dim_v,
-                             MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
+  FFI_ASSIGN_OR_RETURN(auto vt_leading_dim_v,
+                       MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
 
   const int64_t x_out_step{x_rows * x_cols};
   const int64_t singular_values_step{singular_values->dimensions().back()};
@@ -1282,6 +1243,3 @@ template struct Sytrd<std::complex<float>>;
 template struct Sytrd<std::complex<double>>;
 
 }  // namespace jax
-
-#undef ASSIGN_OR_RETURN_FFI_ERROR
-#undef RETURN_IF_FFI_ERROR


@@ -343,6 +343,7 @@ cc_library(
         ":cuda_gpu_kernel_helpers",
         ":cuda_prng_kernels_impl",
        ":cuda_vendor",
+        "//jaxlib:ffi_helpers",
         "//jaxlib:kernel_helpers",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",


@@ -5,12 +5,46 @@
 #include <limits>
 #include <string>
 
+#include "absl/base/optimization.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
 
 namespace jax {
 
+#define FFI_ASSIGN_OR_RETURN(lhs, rhs)      \
+  if (ABSL_PREDICT_FALSE(!rhs.ok())) {      \
+    return ::jax::AsFfiError(rhs.status()); \
+  }                                         \
+  lhs = rhs.value()
+
+#define FFI_RETURN_IF_ERROR(...)             \
+  do {                                       \
+    ::xla::ffi::Error err = (__VA_ARGS__);   \
+    if (ABSL_PREDICT_FALSE(err.failure())) { \
+      return err;                            \
+    }                                        \
+  } while (0)
+
+#define FFI_RETURN_IF_ERROR_STATUS(...)     \
+  do {                                      \
+    ::absl::Status status = (__VA_ARGS__);  \
+    if (ABSL_PREDICT_FALSE(!status.ok())) { \
+      return ::jax::AsFfiError(status);     \
+    }                                       \
+  } while (0)
+
+inline xla::ffi::Error AsFfiError(const absl::Status& status) {
+  if (ABSL_PREDICT_FALSE(!status.ok())) {
+    return xla::ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
+                           std::string(status.message()));
+  } else {
+    return xla::ffi::Error::Success();
+  }
+}
+
 template <typename T>
 inline absl::StatusOr<T> MaybeCastNoOverflow(
     std::int64_t value, const std::string& source = __FILE__) {

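
To make the behavior of the new macros concrete, here is a hand-written expansion of a call like the ones in the kernels in this commit (whitespace approximate; note that the second macro argument is substituted textually, so the expression is evaluated more than once):

// FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int32_t>(dims.back()));
// expands roughly to:
if (ABSL_PREDICT_FALSE(!MaybeCastNoOverflow<int32_t>(dims.back()).ok())) {
  return ::jax::AsFfiError(MaybeCastNoOverflow<int32_t>(dims.back()).status());
}
auto n = MaybeCastNoOverflow<int32_t>(dims.back()).value();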

@@ -30,7 +30,6 @@ limitations under the License.
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_helpers.h"
-#include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_status.h"
@@ -71,13 +70,8 @@ ffi::Error LuPivotsToPermutationImpl(
     return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                       "pivots must have at least one dimension");
   }
-  auto maybe_pivot_size = MaybeCastNoOverflow<std::int32_t>(dims.back());
-  if (!maybe_pivot_size.ok()) {
-    return ffi::Error(
-        static_cast<XLA_FFI_Error_Code>(maybe_pivot_size.status().code()),
-        std::string(maybe_pivot_size.status().message()));
-  }
-  std::int32_t pivot_size = maybe_pivot_size.value();
+  FFI_ASSIGN_OR_RETURN(std::int32_t pivot_size,
+                       MaybeCastNoOverflow<std::int32_t>(dims.back()));
   std::int64_t batch_size = 1;
   if (dims.size() >= 2) {
     batch_size =
@@ -86,10 +80,7 @@ ffi::Error LuPivotsToPermutationImpl(
   LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size,
                                     permutation_size, pivots.typed_data(),
                                     permutation->typed_data());
-  if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
-    return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
-                      std::string(status.message()));
-  }
+  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
   return ffi::Error::Success();
 }
 }  // namespace


@@ -17,14 +17,14 @@ limitations under the License.
 #include <cstdint>
 #include <functional>
 #include <string>
 #include <string_view>
 
 #include "absl/algorithm/container.h"
 #include "absl/status/status.h"
 
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/vendor.h"
+#include "jaxlib/ffi_helpers.h"
 #include "jaxlib/kernel_helpers.h"
-#include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_status.h"
@@ -70,10 +70,7 @@ ffi::Error ThreeFry2x32Impl(gpuStream_t stream,
   LaunchThreeFry2x32KernelFfi(stream, n, keys0.typed_data(), keys1.typed_data(),
                               data0.typed_data(), data1.typed_data(),
                               out0->typed_data(), out1->typed_data());
-  if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
-    return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
-                      std::string(status.message()));
-  }
+  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
   return ffi::Error::Success();
 }
 }  // namespace