From 33a9db3943b8bc287832c01ba064608245221367 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Wed, 10 Jul 2024 15:08:58 -0700
Subject: [PATCH] Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to
 jaxlib/ffi_helpers.h.

Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out
to be useful for other FFI calls. This change consolidates these macros in
the ffi_helpers header.

PiperOrigin-RevId: 651166306
---
 jaxlib/BUILD                 |  3 ++
 jaxlib/cpu/BUILD             |  4 +-
 jaxlib/cpu/lapack_kernels.cc | 74 ++++++++----------------------------
 jaxlib/cuda/BUILD            |  1 +
 jaxlib/ffi_helpers.h         | 34 +++++++++++++++++
 jaxlib/gpu/linalg_kernels.cc | 15 ++------
 jaxlib/gpu/prng_kernels.cc   |  9 ++---
 7 files changed, 61 insertions(+), 79 deletions(-)

diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index bdb8ca29b..467258406 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -128,6 +128,9 @@ cc_library(
     hdrs = ["ffi_helpers.h"],
     features = ["-use_header_modules"],
     deps = [
+        "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD
index 3a50a2916..d6dba6e31 100644
--- a/jaxlib/cpu/BUILD
+++ b/jaxlib/cpu/BUILD
@@ -35,14 +35,12 @@ cc_library(
     copts = ["-fexceptions"],
     features = ["-use_header_modules"],
     deps = [
+        "//jaxlib:ffi_helpers",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",
         "@xla//xla/service:custom_call_status",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:dynamic_annotations",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:str_format",
     ],
 )
diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc
index 6fe788e1d..df533c913 100644
--- a/jaxlib/cpu/lapack_kernels.cc
+++ b/jaxlib/cpu/lapack_kernels.cc
@@ -20,8 +20,8 @@ limitations under the License.
 #include 
 #include 
 #include 
+#include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -30,55 +30,21 @@ limitations under the License.
 #include "absl/algorithm/container.h"
 #include "absl/base/dynamic_annotations.h"
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/strings/str_format.h"
+#include "jaxlib/ffi_helpers.h"
 #include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/ffi.h"
+#include "xla/service/custom_call_status.h"

 static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
               "Expected LAPACK integers to be 32-bit");

 namespace ffi = xla::ffi;

-// TODO(danfm): These macros and the casting functions should be moved to a
-// separate header for use in other FFI kernels.
-#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs)                                 \
-  if (!rhs.ok()) {                                                          \
-    return ffi::Error(static_cast<XLA_FFI_Error_Code>(rhs.status().code()), \
-                      std::string(rhs.status().message()));                 \
-  }                                                                         \
-  lhs = rhs.value()
-
-#define RETURN_IF_FFI_ERROR(...)                                            \
-  do {                                                                     \
-    ffi::Error err = (__VA_ARGS__);                                        \
-    if (err.failure()) {                                                   \
-      return err;                                                          \
-    }                                                                      \
-  } while (0)
-
 namespace {

-template <typename T>
-inline absl::StatusOr<T> MaybeCastNoOverflow(
-    int64_t value, const std::string& source = __FILE__) {
-  if constexpr (sizeof(T) == sizeof(int64_t)) {
-    return value;
-  } else {
-    if (value > std::numeric_limits<T>::max()) [[unlikely]] {
-      return absl::InvalidArgumentError(
-          absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
-                          "value of the desired type",
-                          source, value));
-    }
-    return static_cast<T>(value);
-  }
-}
-
 template <typename T>
 inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
-  auto result = MaybeCastNoOverflow<T>(value, source);
+  auto result = jax::MaybeCastNoOverflow<T>(value, source);
   if (!result.ok()) {
     throw std::overflow_error{std::string(result.status().message())};
   }
@@ -237,7 +203,7 @@ ffi::Error LuDecomposition<dtype>::Kernel(
     ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
     ffi::ResultBuffer<LapackIntDtype> ipiv,
     ffi::ResultBuffer<LapackIntDtype> info) {
-  RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
+  FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions()));
   auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
   auto* x_out_data = x_out->typed_data();
   auto* ipiv_data = ipiv->typed_data();
@@ -245,10 +211,8 @@ ffi::Error LuDecomposition<dtype>::Kernel(

   CopyIfDiffBuffer(x, x_out);

-  ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
-                             MaybeCastNoOverflow<lapack_int>(x_rows));
-  ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
-                             MaybeCastNoOverflow<lapack_int>(x_cols));
+  FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
+  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
   auto x_leading_dim_v = x_rows_v;

   const int64_t x_out_step{x_rows * x_cols};
@@ -410,7 +374,7 @@ template <ffi::DataType dtype>
 ffi::Error CholeskyFactorization<dtype>::Kernel(
     ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
     ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
-  RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
+  FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions()));
   auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
   auto* x_out_data = x_out->typed_data();
   auto* info_data = info->typed_data();
@@ -418,8 +382,8 @@ ffi::Error CholeskyFactorization<dtype>::Kernel(
   CopyIfDiffBuffer(x, x_out);

   auto uplo_v = static_cast<char>(uplo);
-  ASSIGN_OR_RETURN_FFI_ERROR(
-      auto x_order_v, MaybeCastNoOverflow<lapack_int>(x.dimensions().back()));
+  FFI_ASSIGN_OR_RETURN(auto x_order_v,
+                       MaybeCastNoOverflow<lapack_int>(x.dimensions().back()));
   auto x_leading_dim_v = x_order_v;

   const int64_t x_out_step{x_rows * x_cols};
@@ -626,21 +590,18 @@ static ffi::Error SvdKernel(

   CopyIfDiffBuffer(x, x_out);

-  ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
-                             MaybeCastNoOverflow<lapack_int>(x_rows));
-  ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
-                             MaybeCastNoOverflow<lapack_int>(x_cols));
+  FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
+  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
   auto mode_v = static_cast<char>(mode);
-  ASSIGN_OR_RETURN_FFI_ERROR(
-      auto workspace_dim_v,
-      MaybeCastNoOverflow<lapack_int>(work->dimensions().back()));
+  FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
+                                                 work->dimensions().back()));
   auto x_leading_dim_v = x_rows_v;
   auto u_leading_dim_v = x_rows_v;

   auto u_dims = u->dimensions().last(2);
   auto vt_dims = vt->dimensions().last(2);
-  ASSIGN_OR_RETURN_FFI_ERROR(auto vt_leading_dim_v,
-                             MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
+  FFI_ASSIGN_OR_RETURN(auto vt_leading_dim_v,
+                       MaybeCastNoOverflow<lapack_int>(vt_dims.front()));

   const int64_t x_out_step{x_rows * x_cols};
   const int64_t singular_values_step{singular_values->dimensions().back()};
@@ -1282,6 +1243,3 @@ template struct Sytrd<std::complex<float>>;
 template struct Sytrd<std::complex<double>>;

 }  // namespace jax
-
-#undef ASSIGN_OR_RETURN_FFI_ERROR
-#undef RETURN_IF_FFI_ERROR
diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index 083a6cf7e..63e300a64 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -343,6 +343,7 @@ cc_library(
         ":cuda_gpu_kernel_helpers",
         ":cuda_prng_kernels_impl",
        ":cuda_vendor",
+        "//jaxlib:ffi_helpers",
         "//jaxlib:kernel_helpers",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/ffi/api:ffi",
diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h
index 8bec08cf6..3acafa62a 100644
--- a/jaxlib/ffi_helpers.h
+++ b/jaxlib/ffi_helpers.h
@@ -5,12 +5,46 @@
 #include 
 #include 

+#include "absl/base/optimization.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"

 namespace jax {

+#define FFI_ASSIGN_OR_RETURN(lhs, rhs)        \
+  if (ABSL_PREDICT_FALSE(!rhs.ok())) {        \
+    return ::jax::AsFfiError(rhs.status());   \
+  }                                           \
+  lhs = rhs.value()
+
+#define FFI_RETURN_IF_ERROR(...)              \
+  do {                                        \
+    ::xla::ffi::Error err = (__VA_ARGS__);    \
+    if (ABSL_PREDICT_FALSE(err.failure())) {  \
+      return err;                             \
+    }                                         \
+  } while (0)
+
+#define FFI_RETURN_IF_ERROR_STATUS(...)       \
+  do {                                        \
+    ::absl::Status status = (__VA_ARGS__);    \
+    if (ABSL_PREDICT_FALSE(!status.ok())) {   \
+      return ::jax::AsFfiError(status);       \
+    }                                         \
+  } while (0)
+
+inline xla::ffi::Error AsFfiError(const absl::Status& status) {
+  if (ABSL_PREDICT_FALSE(!status.ok())) {
+    return xla::ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
+                           std::string(status.message()));
+  } else {
+    return xla::ffi::Error::Success();
+  }
+}
+
 template <typename T>
 inline absl::StatusOr<T> MaybeCastNoOverflow(
     std::int64_t value, const std::string& source = __FILE__) {
diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc
index 43b1d5344..6636f5654 100644
--- a/jaxlib/gpu/linalg_kernels.cc
+++ b/jaxlib/gpu/linalg_kernels.cc
@@ -30,7 +30,6 @@ limitations under the License.
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
 #include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_helpers.h"
-#include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_status.h"
@@ -71,13 +70,8 @@ ffi::Error LuPivotsToPermutationImpl(
     return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                       "pivots must have at least one dimension");
   }
-  auto maybe_pivot_size = MaybeCastNoOverflow<std::int32_t>(dims.back());
-  if (!maybe_pivot_size.ok()) {
-    return ffi::Error(
-        static_cast<XLA_FFI_Error_Code>(maybe_pivot_size.status().code()),
-        std::string(maybe_pivot_size.status().message()));
-  }
-  std::int32_t pivot_size = maybe_pivot_size.value();
+  FFI_ASSIGN_OR_RETURN(std::int32_t pivot_size,
+                       MaybeCastNoOverflow<std::int32_t>(dims.back()));
   std::int64_t batch_size = 1;
   if (dims.size() >= 2) {
     batch_size =
@@ -86,10 +80,7 @@ ffi::Error LuPivotsToPermutationImpl(
   LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size,
                                     permutation_size, pivots.typed_data(),
                                     permutation->typed_data());
-  if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
-    return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
-                      std::string(status.message()));
-  }
+  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
   return ffi::Error::Success();
 }
 }  // namespace
diff --git a/jaxlib/gpu/prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc
index 6bd1f47c1..609c3364c 100644
--- a/jaxlib/gpu/prng_kernels.cc
+++ b/jaxlib/gpu/prng_kernels.cc
@@ -17,14 +17,14 @@ limitations under the License.
 #include 
 #include 
-#include 
 #include 

 #include "absl/algorithm/container.h"
 #include "absl/status/status.h"
 #include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/vendor.h"
+#include "jaxlib/ffi_helpers.h"
 #include "jaxlib/kernel_helpers.h"
-#include "xla/ffi/api/c_api.h"
 #include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_status.h"
@@ -70,10 +70,7 @@ ffi::Error ThreeFry2x32Impl(gpuStream_t stream,
   LaunchThreeFry2x32KernelFfi(stream, n, keys0.typed_data(),
                               keys1.typed_data(), data0.typed_data(),
                               data1.typed_data(), out0->typed_data(),
                               out1->typed_data());
-  if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
-    return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
-                      std::string(status.message()));
-  }
+  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
   return ffi::Error::Success();
 }
 }  // namespace
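
For context, here is a minimal sketch of how a downstream FFI kernel might use
the consolidated helpers after this change. The kernel itself (ExampleKernel,
its F32 buffers, and the copy loop) is a hypothetical illustration, not part
of this patch; only FFI_ASSIGN_OR_RETURN, jax::MaybeCastNoOverflow, and
jax::AsFfiError come from the jaxlib/ffi_helpers.h shown above.

#include <cstdint>

#include "jaxlib/ffi_helpers.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Hypothetical kernel: copies the trailing dimension of x into y (assumed to
// have been allocated with at least that extent), after narrowing the extent
// from int64_t to int32_t with the shared helper.
ffi::Error ExampleKernel(ffi::Buffer<ffi::DataType::F32> x,
                         ffi::ResultBuffer<ffi::DataType::F32> y) {
  auto dims = x.dimensions();
  if (dims.size() < 1) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "x must have at least one dimension");
  }
  // On failure, FFI_ASSIGN_OR_RETURN converts the absl::Status carried by the
  // StatusOr into an ffi::Error via jax::AsFfiError and returns it here.
  FFI_ASSIGN_OR_RETURN(std::int32_t n, jax::MaybeCastNoOverflow<std::int32_t>(
                                           dims.back()));
  for (std::int32_t i = 0; i < n; ++i) {
    y->typed_data()[i] = x.typed_data()[i];
  }
  return ffi::Error::Success();
}

Because FFI_ASSIGN_OR_RETURN expands to a plain if statement that returns an
ffi::Error, it can only be used inside functions whose return type is
ffi::Error, which is the convention followed by all of the FFI kernels touched
by this patch.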