rocm_jax/jaxlib/ffi_helpers.h
Dan Foreman-Mackey 71a93d0c87 Port QR factorization GPU kernel to FFI.
The biggest change here is that we now ignore the `info` parameter that is returned by `getrf`. In the previous implementation, we would return an error in the batched implementation, or set the relevant matrix entries to NaN in the non-batched version if `info != 0`. But, since info is only used for shape checking (see LAPACK, cuBLAS and cuSolver docs), I argue that we will never see `info != 0`, because we're including all the shape checks in the kernel already.

PiperOrigin-RevId: 665307128
2024-08-20 05:07:04 -07:00

210 lines
8.2 KiB
C++

#ifndef JAXLIB_FFI_HELPERS_H_
#define JAXLIB_FFI_HELPERS_H_
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
namespace jax {
// Returns from the enclosing function if the argument is a failed ffi::Error.
// The argument expression is evaluated exactly once. The do/while(0) wrapper
// makes the macro expand to a single statement, so it composes safely with
// unbraced if/else.
#define FFI_RETURN_IF_ERROR(...) \
do { \
::xla::ffi::Error err = (__VA_ARGS__); \
if (ABSL_PREDICT_FALSE(err.failure())) { \
return err; \
} \
} while (0)
// Returns from the enclosing function with an ffi::Error (converted via
// ::jax::AsFfiError) if the argument is a non-OK absl::Status. The argument
// expression is evaluated exactly once; an OK status falls through with no
// effect.
#define FFI_RETURN_IF_ERROR_STATUS(...) \
do { \
::absl::Status status = (__VA_ARGS__); \
if (ABSL_PREDICT_FALSE(!status.ok())) { \
return ::jax::AsFfiError(status); \
} \
} while (0)
// Returns from the function with an ffi::Error if the RHS is an absl::Status,
// otherwise assigns to the LHS. Most of the complication here stems from the
// fact that we want to support having the LHS wrapped in parentheses (when
// unpacking a tuple, for example).
//
// The temporary StatusOr variable is given a unique name via __LINE__
// concatenation, so multiple uses in the same scope do not collide.
// NOTE(review): the expansion declares a variable, so this macro must appear
// where a declaration is legal (not as the lone body of an unbraced if/for).
#define FFI_ASSIGN_OR_RETURN(lhs, rhs) \
FFI_ASSIGN_OR_RETURN_IMPL_( \
FFI_ASSIGN_OR_RETURN_CONCAT_(_status_or_value, __LINE__), lhs, rhs)
// Implementation: evaluate `rhs` once into `statusor`; on error, convert and
// return; otherwise move-unwrap the value into `lhs` (unwrapping optional
// parentheses around `lhs` first).
#define FFI_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rhs) \
auto statusor = (rhs); \
if (ABSL_PREDICT_FALSE(!statusor.ok())) { \
return ::jax::AsFfiError(statusor.status()); \
} \
FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \
(*std::move(statusor))
// Two-level concatenation so that arguments like __LINE__ are expanded before
// being pasted.
#define FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) x##y
#define FFI_ASSIGN_OR_RETURN_CONCAT_(x, y) \
FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y)
// All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN
// where the LHS is wrapped in parentheses.
// Swallows its arguments entirely (used to "consume" a parenthesized LHS).
#define FFI_ASSIGN_OR_RETURN_EAT(...)
// Expands to its arguments unchanged (strips the surrounding parentheses when
// invoked as `FFI_ASSIGN_OR_RETURN_REM (a, b)`).
#define FFI_ASSIGN_OR_RETURN_REM(...) __VA_ARGS__
#define FFI_ASSIGN_OR_RETURN_EMPTY()
// Classic preprocessor emptiness probe: expands to 1 if called with no
// arguments, 0 otherwise, by shifting a sentinel through a fixed-arity helper.
#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(...) \
FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER((__VA_ARGS__, 0, 1))
#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER(args) \
FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I args
#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty
#define FFI_ASSIGN_OR_RETURN_IS_EMPTY(...) \
FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(__VA_ARGS__)
// NOTE(review): `, ##__VA_ARGS__` is a GNU/Clang extension (comma deletion for
// empty __VA_ARGS__); this header presumably targets those compilers.
#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(...) \
FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(_, ##__VA_ARGS__)
// Compile-time if: selects _Then when _Cond is 1, _Else when _Cond is 0.
#define FFI_ASSIGN_OR_RETURN_IF_1(_Then, _Else) _Then
#define FFI_ASSIGN_OR_RETURN_IF_0(_Then, _Else) _Else
#define FFI_ASSIGN_OR_RETURN_IF(_Cond, _Then, _Else) \
FFI_ASSIGN_OR_RETURN_CONCAT_(FFI_ASSIGN_OR_RETURN_IF_, _Cond)(_Then, _Else)
// If the argument is parenthesized, EAT consumes it and the result is empty
// (-> 1); otherwise EAT is left dangling next to the tokens (-> 0).
#define FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(...) \
FFI_ASSIGN_OR_RETURN_IS_EMPTY(FFI_ASSIGN_OR_RETURN_EAT __VA_ARGS__)
// Strips one level of parentheses from the argument if (and only if) it is
// parenthesized; otherwise passes it through unchanged.
#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(...) \
FFI_ASSIGN_OR_RETURN_IF(FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(__VA_ARGS__), \
FFI_ASSIGN_OR_RETURN_REM, \
FFI_ASSIGN_OR_RETURN_EMPTY()) \
__VA_ARGS__
// Casts an int64_t `value` to `T`, returning an InvalidArgument error if the
// value is not representable in `T`.
//
// `source` identifies the call site in the error message (defaults to this
// file; callers typically pass their own __FILE__).
template <typename T>
inline absl::StatusOr<T> MaybeCastNoOverflow(
    std::int64_t value, const std::string& source = __FILE__) {
  if constexpr (sizeof(T) == sizeof(std::int64_t)) {
    return value;
  } else {
    // Check both bounds: the previous upper-bound-only test let negative
    // values underflow silently when T is a narrower type (and wrap when T is
    // a narrower unsigned type, since e.g. int64 -1 compares <= UINT32_MAX).
    if (value > std::numeric_limits<T>::max() ||
        value < std::numeric_limits<T>::lowest()) [[unlikely]] {
      return absl::InvalidArgumentError(absl::StrFormat(
          "%s: Value (=%d) is outside the range representable by the "
          "desired type",
          source, value));
    }
    return static_cast<T>(value);
  }
}
// Converts an absl::Status into the equivalent ffi::Error: an OK status maps
// to Success(), anything else carries its code and message across (status
// codes are cast directly to the XLA FFI error-code enum).
inline ::xla::ffi::Error AsFfiError(const absl::Status& status) {
  if (ABSL_PREDICT_TRUE(status.ok())) {
    return ::xla::ffi::Error::Success();
  }
  return ::xla::ffi::Error(static_cast<XLA_FFI_Error_Code>(status.code()),
                           std::string(status.message()));
}
// Returns the total batch size implied by `dims`: the product of all entries
// (1 for an empty span).
inline int64_t GetBatchSize(::xla::ffi::Span<const int64_t> dims) {
  // The initial value must be an int64_t: std::accumulate (which
  // absl::c_accumulate wraps) deduces its accumulator type from the init
  // argument, so a bare `1` (int) would multiply in `int` and silently
  // overflow/truncate for large batches.
  return absl::c_accumulate(dims, int64_t{1}, std::multiplies<int64_t>());
}
// Splits `dims` into (batch, size): the product of all leading dimensions,
// and the trailing dimension. Errors if `dims` has no dimensions at all.
inline absl::StatusOr<std::pair<int64_t, int64_t>> SplitBatch1D(
    ::xla::ffi::Span<const int64_t> dims,
    const std::string& source = __FILE__) {
  if (dims.size() == 0) {
    return absl::InvalidArgumentError(
        absl::StrFormat("%s: Argument must have at least 1 dimension", source));
  }
  const int64_t batch = GetBatchSize(dims.first(dims.size() - 1));
  return std::make_pair(batch, dims.back());
}
// Splits `dims` into (batch, rows, cols): the product of all leading
// dimensions, plus the two trailing (matrix) dimensions. Errors if `dims`
// has fewer than two dimensions.
inline absl::StatusOr<std::tuple<int64_t, int64_t, int64_t>> SplitBatch2D(
    ::xla::ffi::Span<const int64_t> dims,
    const std::string& source = __FILE__) {
  if (dims.size() < 2) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "%s: Argument must have at least 2 dimensions", source));
  }
  const auto matrix_dims = dims.last(2);
  const int64_t batch = GetBatchSize(dims.first(dims.size() - 2));
  return std::make_tuple(batch, matrix_dims.front(), matrix_dims.back());
}
// Checks that the total batch size (product of all of `dimensions`) equals
// `expected_batch`. `name` and `op` only label the error message. Returns
// Success() on match, InvalidArgument otherwise.
inline ::xla::ffi::Error CheckShape(::xla::ffi::Span<const int64_t> dimensions,
                                    int64_t expected_batch,
                                    std::string_view name,
                                    std::string_view op) {
  const int64_t actual_batch = GetBatchSize(dimensions);
  if (actual_batch == expected_batch) {
    return ::xla::ffi::Error::Success();
  }
  return ::xla::ffi::Error::InvalidArgument(absl::StrFormat(
      "Invalid total batch size for input %s to %s. Expected %d, got %d.",
      name, op, expected_batch, actual_batch));
}
// Checks that `dimensions` matches the expected (batch, trailing-size) pair:
// the product of the leading dimensions must equal the expected batch and the
// trailing dimension must equal the expected size. `name` and `op` only label
// the error messages.
inline ::xla::ffi::Error CheckShape(::xla::ffi::Span<const int64_t> dimensions,
                                    std::tuple<int64_t, int64_t> shape,
                                    std::string_view name,
                                    std::string_view op) {
  FFI_ASSIGN_OR_RETURN((auto [batch, size]), SplitBatch1D(dimensions));
  auto [expected_batch, expected_size] = shape;
  if (batch != expected_batch) {
    return ::xla::ffi::Error::InvalidArgument(absl::StrFormat(
        "Invalid total batch size for input %s to %s. Expected %d, got %d.",
        name, op, expected_batch, batch));
  }
  // The batch already matched (the mismatch case returned above), so only the
  // trailing dimension remains to check; the previous version redundantly
  // re-tested `batch != expected_batch` here. This now mirrors the 2D
  // overload below.
  if (size != expected_size) {
    return ::xla::ffi::Error::InvalidArgument(
        absl::StrFormat("Invalid trailing dimension for input %s "
                        "to %s. Expected %d, got %d.",
                        name, op, expected_size, size));
  }
  return ::xla::ffi::Error::Success();
}
// Checks that `dimensions` matches the expected (batch, rows, cols) triple:
// the product of the leading dimensions must equal the expected batch, and
// the two trailing dimensions must equal the expected matrix shape. `name`
// and `op` only label the error messages.
inline ::xla::ffi::Error CheckShape(::xla::ffi::Span<const int64_t> dimensions,
                                    std::tuple<int64_t, int64_t, int64_t> shape,
                                    std::string_view name,
                                    std::string_view op) {
  const auto [want_batch, want_rows, want_cols] = shape;
  FFI_ASSIGN_OR_RETURN((auto [got_batch, got_rows, got_cols]),
                       SplitBatch2D(dimensions));
  if (got_batch != want_batch) {
    return ::xla::ffi::Error::InvalidArgument(absl::StrFormat(
        "Invalid total batch size for input %s to %s. Expected %d, got %d.",
        name, op, want_batch, got_batch));
  }
  if (got_rows != want_rows || got_cols != want_cols) {
    return ::xla::ffi::Error::InvalidArgument(
        absl::StrFormat("Invalid matrix dimensions for input %s to %s. "
                        "Expected (%d, %d), got (%d, %d).",
                        name, op, want_rows, want_cols, got_rows, got_cols));
  }
  return ::xla::ffi::Error::Success();
}
// Allocates a scratch buffer of `size` elements of the native C++ type
// corresponding to the FFI data type `dtype`, owned by a unique_ptr.
template <::xla::ffi::DataType dtype>
auto AllocateScratchMemory(std::size_t size)
    -> std::unique_ptr<std::remove_extent_t<::xla::ffi::NativeType<dtype>>[]> {
  using ValueType = std::remove_extent_t<::xla::ffi::NativeType<dtype>>;
  // Raw array-new default-initializes the elements (i.e. leaves trivial types
  // uninitialized), avoiding the zero-fill that std::make_unique would do.
  // TODO(paruzelp): use std::make_unique_for_overwrite when C++20 is available.
  return std::unique_ptr<ValueType[]>{new ValueType[size]};
}
} // namespace jax
#endif // JAXLIB_FFI_HELPERS_H_