#ifndef JAXLIB_FFI_HELPERS_H_ #define JAXLIB_FFI_HELPERS_H_ #include #include #include #include #include #include #include #include #include #include #include "absl/algorithm/container.h" #include "absl/base/optimization.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" namespace jax { // Returns from the function if the argument is an ffi::Error. #define FFI_RETURN_IF_ERROR(...) \ do { \ ::xla::ffi::Error err = (__VA_ARGS__); \ if (ABSL_PREDICT_FALSE(err.failure())) { \ return err; \ } \ } while (0) // Returns from the function with an ffi::Error if the argument is an // absl::Status. #define FFI_RETURN_IF_ERROR_STATUS(...) \ do { \ ::absl::Status status = (__VA_ARGS__); \ if (ABSL_PREDICT_FALSE(!status.ok())) { \ return ::jax::AsFfiError(status); \ } \ } while (0) // Returns from the function with an ffi::Error if the RHS is an absl::Status, // otherwise assigns to the LHS. Most of the complication here stems from the // fact that we want to support having the LHS wrapped in parentheses (when // unpacking a tuple, for example). #define FFI_ASSIGN_OR_RETURN(lhs, rhs) \ FFI_ASSIGN_OR_RETURN_IMPL_( \ FFI_ASSIGN_OR_RETURN_CONCAT_(_status_or_value, __LINE__), lhs, rhs) #define FFI_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rhs) \ auto statusor = (rhs); \ if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ return ::jax::AsFfiError(statusor.status()); \ } \ FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \ (*std::move(statusor)) #define FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) x##y #define FFI_ASSIGN_OR_RETURN_CONCAT_(x, y) \ FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) // All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN // where the LHS is wrapped in parentheses. See a more detailed discussion at // https://stackoverflow.com/a/62984543 #define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \ FFI_ASSIGN_OR_RETURN_ESCAPE(FFI_ASSIGN_OR_RETURN_EMPTY X) #define FFI_ASSIGN_OR_RETURN_EMPTY(...) FFI_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__ #define FFI_ASSIGN_OR_RETURN_ESCAPE(...) \ FFI_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__) #define FFI_ASSIGN_OR_RETURN_ESCAPE_(...) FFI_ASSIGN_OR_RETURN_##__VA_ARGS__ #define FFI_ASSIGN_OR_RETURN_FFI_ASSIGN_OR_RETURN_EMPTY template inline absl::StatusOr MaybeCastNoOverflow( std::int64_t value, const std::string& source = __FILE__) { if constexpr (sizeof(T) == sizeof(std::int64_t)) { return value; } else { if (value > std::numeric_limits::max()) [[unlikely]] { return absl::InvalidArgumentError(absl::StrFormat( "%s: Value (=%d) exceeds the maximum representable value of the " "desired type", source, value)); } return static_cast(value); } } inline ::xla::ffi::Error AsFfiError(const absl::Status& status) { if (ABSL_PREDICT_FALSE(!status.ok())) { return ::xla::ffi::Error(static_cast(status.code()), std::string(status.message())); } else { return ::xla::ffi::Error::Success(); } } inline int64_t GetBatchSize(::xla::ffi::Span dims) { return absl::c_accumulate(dims, 1, std::multiplies()); } inline absl::StatusOr> SplitBatch1D( ::xla::ffi::Span dims, const std::string& source = __FILE__) { if (dims.size() < 1) { return absl::InvalidArgumentError( absl::StrFormat("%s: Argument must have at least 1 dimension", source)); } return std::make_pair(GetBatchSize(dims.first(dims.size() - 1)), dims.back()); } inline absl::StatusOr> SplitBatch2D( ::xla::ffi::Span dims, const std::string& source = __FILE__) { if (dims.size() < 2) { return absl::InvalidArgumentError(absl::StrFormat( "%s: Argument must have at least 2 dimensions", source)); } auto trailingDims = dims.last(2); return std::make_tuple(GetBatchSize(dims.first(dims.size() - 2)), trailingDims.front(), trailingDims.back()); } inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, int64_t expected_batch, std::string_view name, std::string_view op) { auto batch = GetBatchSize(dimensions); if (batch != expected_batch) { return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( "Invalid total batch size for input %s to %s. Expected %d, got %d.", name, op, expected_batch, batch)); } return ::xla::ffi::Error::Success(); } inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, std::tuple shape, std::string_view name, std::string_view op) { FFI_ASSIGN_OR_RETURN((auto [batch, size]), SplitBatch1D(dimensions)); auto [expected_batch, expected_size] = shape; if (batch != expected_batch) { return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( "Invalid total batch size for input %s to %s. Expected %d, got %d.", name, op, expected_batch, batch)); } if (batch != expected_batch || size != expected_size) { return ::xla::ffi::Error::InvalidArgument( absl::StrFormat("Invalid trailing dimension for input %s " "to %s. Expected %d, got %d.", name, op, expected_size, size)); } return ::xla::ffi::Error::Success(); } inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, std::tuple shape, std::string_view name, std::string_view op) { FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), SplitBatch2D(dimensions)); auto [expected_batch, expected_rows, expected_cols] = shape; if (batch != expected_batch) { return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( "Invalid total batch size for input %s to %s. Expected %d, got %d.", name, op, expected_batch, batch)); } if (rows != expected_rows || cols != expected_cols) { return ::xla::ffi::Error::InvalidArgument( absl::StrFormat("Invalid matrix dimensions for input %s to %s. " "Expected (%d, %d), got (%d, %d).", name, op, expected_rows, expected_cols, rows, cols)); } return ::xla::ffi::Error::Success(); } template <::xla::ffi::DataType dtype> auto AllocateScratchMemory(std::size_t size) -> std::unique_ptr>[]> { // TODO(paruzelp): use std::make_unique_for_overwrite when C++20 is available. using ValueType = std::remove_extent_t<::xla::ffi::NativeType>; return std::unique_ptr(new ValueType[size]); } template inline absl::StatusOr AllocateWorkspace( ::xla::ffi::ScratchAllocator& scratch, int64_t size, std::string_view name) { auto maybe_workspace = scratch.Allocate(sizeof(T) * size); if (!maybe_workspace.has_value()) { return absl::Status( absl::StatusCode::kResourceExhausted, absl::StrFormat("Unable to allocate workspace for %s", name)); } return static_cast(maybe_workspace.value()); } } // namespace jax #endif // JAXLIB_FFI_HELPERS_H_