mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to a jaxlib/ffi_helpers.h.
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to be useful for other FFI calls. This change consolidates these macros in the ffi_helper header. PiperOrigin-RevId: 651166306
This commit is contained in:
parent
d0de7970d4
commit
33a9db3943
@ -128,6 +128,9 @@ cc_library(
|
||||
hdrs = ["ffi_helpers.h"],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
@ -35,14 +35,12 @@ cc_library(
|
||||
copts = ["-fexceptions"],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"//jaxlib:ffi_helpers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -20,8 +20,8 @@ limitations under the License.
|
||||
#include <complex>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
@ -30,55 +30,21 @@ limitations under the License.
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/base/dynamic_annotations.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "jaxlib/ffi_helpers.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
|
||||
"Expected LAPACK integers to be 32-bit");
|
||||
|
||||
namespace ffi = xla::ffi;
|
||||
|
||||
// TODO(danfm): These macros and the casting functions should be moved to a
|
||||
// separate header for use in other FFI kernels.
|
||||
#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs) \
|
||||
if (!rhs.ok()) { \
|
||||
return ffi::Error(static_cast<XLA_FFI_Error_Code>(rhs.status().code()), \
|
||||
std::string(rhs.status().message())); \
|
||||
} \
|
||||
lhs = rhs.value()
|
||||
|
||||
#define RETURN_IF_FFI_ERROR(...) \
|
||||
do { \
|
||||
ffi::Error err = (__VA_ARGS__); \
|
||||
if (err.failure()) { \
|
||||
return err; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
inline absl::StatusOr<T> MaybeCastNoOverflow(
|
||||
int64_t value, const std::string& source = __FILE__) {
|
||||
if constexpr (sizeof(T) == sizeof(int64_t)) {
|
||||
return value;
|
||||
} else {
|
||||
if (value > std::numeric_limits<T>::max()) [[unlikely]] {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
|
||||
"value of the desired type",
|
||||
source, value));
|
||||
}
|
||||
return static_cast<T>(value);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
|
||||
auto result = MaybeCastNoOverflow<T>(value, source);
|
||||
auto result = jax::MaybeCastNoOverflow<T>(value, source);
|
||||
if (!result.ok()) {
|
||||
throw std::overflow_error{std::string(result.status().message())};
|
||||
}
|
||||
@ -237,7 +203,7 @@ ffi::Error LuDecomposition<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<LapackIntDtype> ipiv,
|
||||
ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
|
||||
FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions()));
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* ipiv_data = ipiv->typed_data();
|
||||
@ -245,10 +211,8 @@ ffi::Error LuDecomposition<dtype>::Kernel(
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
auto x_leading_dim_v = x_rows_v;
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
@ -410,7 +374,7 @@ template <ffi::DataType dtype>
|
||||
ffi::Error CholeskyFactorization<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
|
||||
FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions()));
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
@ -418,8 +382,8 @@ ffi::Error CholeskyFactorization<dtype>::Kernel(
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto uplo_v = static_cast<char>(uplo);
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(
|
||||
auto x_order_v, MaybeCastNoOverflow<lapack_int>(x.dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_order_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x.dimensions().back()));
|
||||
auto x_leading_dim_v = x_order_v;
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
@ -626,21 +590,18 @@ static ffi::Error SvdKernel(
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(
|
||||
auto workspace_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
work->dimensions().back()));
|
||||
auto x_leading_dim_v = x_rows_v;
|
||||
auto u_leading_dim_v = x_rows_v;
|
||||
|
||||
auto u_dims = u->dimensions().last(2);
|
||||
auto vt_dims = vt->dimensions().last(2);
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto vt_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
|
||||
FFI_ASSIGN_OR_RETURN(auto vt_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
const int64_t singular_values_step{singular_values->dimensions().back()};
|
||||
@ -1282,6 +1243,3 @@ template struct Sytrd<std::complex<float>>;
|
||||
template struct Sytrd<std::complex<double>>;
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#undef ASSIGN_OR_RETURN_FFI_ERROR
|
||||
#undef RETURN_IF_FFI_ERROR
|
||||
|
@ -343,6 +343,7 @@ cc_library(
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_prng_kernels_impl",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:ffi_helpers",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
|
@ -5,12 +5,46 @@
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
#define FFI_ASSIGN_OR_RETURN(lhs, rhs) \
|
||||
if (ABSL_PREDICT_FALSE(!rhs.ok())) { \
|
||||
return ::jax::AsFfiError(rhs.status()); \
|
||||
} \
|
||||
lhs = rhs.value()
|
||||
|
||||
#define FFI_RETURN_IF_ERROR(...) \
|
||||
do { \
|
||||
::xla::ffi::Error err = (__VA_ARGS__); \
|
||||
if (ABSL_PREDICT_FALSE(err.failure())) { \
|
||||
return err; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define FFI_RETURN_IF_ERROR_STATUS(...) \
|
||||
do { \
|
||||
::absl::Status status = (__VA_ARGS__); \
|
||||
if (ABSL_PREDICT_FALSE(!status.ok())) { \
|
||||
return ::jax::AsFfiError(status); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
inline xla::ffi::Error AsFfiError(const absl::Status& status) {
|
||||
if (ABSL_PREDICT_FALSE(!status.ok())) {
|
||||
return xla::ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
|
||||
std::string(status.message()));
|
||||
} else {
|
||||
return xla::ffi::Error::Success();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline absl::StatusOr<T> MaybeCastNoOverflow(
|
||||
std::int64_t value, const std::string& source = __FILE__) {
|
||||
|
@ -30,7 +30,6 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
@ -71,13 +70,8 @@ ffi::Error LuPivotsToPermutationImpl(
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"pivots must have at least one dimension");
|
||||
}
|
||||
auto maybe_pivot_size = MaybeCastNoOverflow<std::int32_t>(dims.back());
|
||||
if (!maybe_pivot_size.ok()) {
|
||||
return ffi::Error(
|
||||
static_cast<XLA_FFI_Error_Code>(maybe_pivot_size.status().code()),
|
||||
std::string(maybe_pivot_size.status().message()));
|
||||
}
|
||||
std::int32_t pivot_size = maybe_pivot_size.value();
|
||||
FFI_ASSIGN_OR_RETURN(std::int32_t pivot_size,
|
||||
MaybeCastNoOverflow<std::int32_t>(dims.back()));
|
||||
std::int64_t batch_size = 1;
|
||||
if (dims.size() >= 2) {
|
||||
batch_size =
|
||||
@ -86,10 +80,7 @@ ffi::Error LuPivotsToPermutationImpl(
|
||||
LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size,
|
||||
permutation_size, pivots.typed_data(),
|
||||
permutation->typed_data());
|
||||
if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
|
||||
return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
|
||||
std::string(status.message()));
|
||||
}
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -17,14 +17,14 @@ limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/ffi_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
@ -70,10 +70,7 @@ ffi::Error ThreeFry2x32Impl(gpuStream_t stream,
|
||||
LaunchThreeFry2x32KernelFfi(stream, n, keys0.typed_data(), keys1.typed_data(),
|
||||
data0.typed_data(), data1.typed_data(),
|
||||
out0->typed_data(), out1->typed_data());
|
||||
if (auto status = JAX_AS_STATUS(gpuGetLastError()); !status.ok()) {
|
||||
return ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
|
||||
std::string(status.message()));
|
||||
}
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user