Avoid throwing exceptions in LAPACK CPU kernels.

When an FFI kernel is executed, there isn't any global try/catch block around it (I think!), so it's probably a good idea to avoid throwing exceptions.
Instead, it should be safer to map failures to ffi::Error manually.

PiperOrigin-RevId: 647348889
This commit is contained in:
Dan Foreman-Mackey 2024-06-27 09:40:33 -07:00 committed by jax authors
parent 61185a21ee
commit 98b87540a7
2 changed files with 54 additions and 9 deletions

View File

@ -40,6 +40,8 @@ cc_library(
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
],
)

View File

@ -30,6 +30,8 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
@ -39,28 +41,61 @@ static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
namespace ffi = xla::ffi;
// TODO(danfm): These macros and the casting functions should be moved to a
// separate header for use in other FFI kernels.
// ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs)
//
// Evaluates `rhs` exactly once. `rhs` must yield an object exposing `ok()`,
// `status()` (with `code()` and `message()`) and `value()`, e.g. an
// absl::StatusOr<T>. If the result is not OK, the enclosing function returns
// an ffi::Error built from the status code and message; otherwise the
// contained value is moved into `lhs`. `lhs` may be a declaration
// (`auto x`) or an existing lvalue.
//
// Notes:
//  * The status code is converted with a plain static_cast, which assumes the
//    numeric values of XLA_FFI_Error_Code mirror those of the status code
//    enum -- TODO confirm against the XLA FFI C API.
//  * The macro expands to several statements (it has to, because `lhs` may
//    declare a variable in the caller's scope), so it must not be used as the
//    unbraced body of an `if`/`for`/`while`.
//  * The temporary's name is derived from __LINE__, so use the macro at most
//    once per source line.
#define ASSIGN_OR_RETURN_FFI_ERROR_CONCAT_INNER(x, y) x##y
#define ASSIGN_OR_RETURN_FFI_ERROR_CONCAT(x, y) \
  ASSIGN_OR_RETURN_FFI_ERROR_CONCAT_INNER(x, y)
#define ASSIGN_OR_RETURN_FFI_ERROR_IMPL(statusor, lhs, rhs)        \
  auto statusor = (rhs);                                           \
  if (!statusor.ok()) {                                            \
    return ffi::Error(                                             \
        static_cast<XLA_FFI_Error_Code>(statusor.status().code()), \
        std::string(statusor.status().message()));                 \
  }                                                                \
  lhs = std::move(statusor).value()
#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs)                             \
  ASSIGN_OR_RETURN_FFI_ERROR_IMPL(                                       \
      ASSIGN_OR_RETURN_FFI_ERROR_CONCAT(ffi_status_or_value_, __LINE__), \
      lhs, rhs)
// RETURN_IF_FFI_ERROR(expr)
//
// Evaluates `expr` once, storing the result in an ffi::Error. If that error
// reports failure(), it is returned from the enclosing function; otherwise
// execution continues after the macro. Wrapped in do/while(0) so that it
// behaves as a single statement and requires a trailing semicolon.
#define RETURN_IF_FFI_ERROR(...)                               \
  do {                                                         \
    if (ffi::Error err = (__VA_ARGS__); err.failure()) {       \
      return err;                                              \
    }                                                          \
  } while (0)
namespace {
// Converts `value` to `T` without throwing.
//
// When `T` has the same size as int64_t the value is returned unchanged.
// Otherwise the value is range-checked against std::numeric_limits<T> and an
// absl::InvalidArgumentError is returned if it does not fit; on success the
// value is converted with static_cast.
//
// `source` is prepended to the error message; its default is __FILE__ as
// expanded at this declaration (i.e. this file, not the caller's).
template <typename T>
inline absl::StatusOr<T> MaybeCastNoOverflow(
    int64_t value, const std::string& source = __FILE__) {
  if constexpr (sizeof(T) == sizeof(int64_t)) {
    return value;
  } else {
    if (value > std::numeric_limits<T>::max()) [[unlikely]] {
      return absl::InvalidArgumentError(
          absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
                          "value of the desired type",
                          source, value));
    }
    // Guard the lower bound as well: without this check a value below the
    // minimum of T would silently wrap in the static_cast below.
    if (value < std::numeric_limits<T>::lowest()) [[unlikely]] {
      return absl::InvalidArgumentError(
          absl::StrFormat("%s: Value (=%d) is below the minimum representable "
                          "value of the desired type",
                          source, value));
    }
    return static_cast<T>(value);
  }
}
template <typename T>
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
if (dims.size() < 2) {
throw std::invalid_argument("Matrix must have at least 2 dimensions");
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
auto result = MaybeCastNoOverflow<T>(value, source);
if (!result.ok()) {
throw std::overflow_error{std::string(result.status().message())};
}
return result.value();
}
// Verifies that `dims` has at least two entries.
//
// Returns ffi::Error::Success() when dims.size() >= 2, and a
// kInvalidArgument ffi::Error otherwise.
template <typename T>
ffi::Error CheckMatrixDimensions(ffi::Span<T> dims) {
  const bool has_matrix_dims = dims.size() >= 2;
  return has_matrix_dims
             ? ffi::Error::Success()
             : ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "Matrix must have at least 2 dimensions");
}
template <typename T>
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
auto matrix_dims = dims.last(2);
return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1,
std::multiplies<int64_t>()),
@ -201,6 +236,7 @@ ffi::Error LuDecomposition<dtype>::Kernel(
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<LapackIntDtype> ipiv,
ffi::ResultBuffer<LapackIntDtype> info) {
RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions));
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
auto* x_out_data = x_out->data;
auto* ipiv_data = ipiv->data;
@ -208,8 +244,10 @@ ffi::Error LuDecomposition<dtype>::Kernel(
CopyIfDiffBuffer(x, x_out);
auto x_rows_v = CastNoOverflow<lapack_int>(x_rows);
auto x_cols_v = CastNoOverflow<lapack_int>(x_cols);
ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
MaybeCastNoOverflow<lapack_int>(x_rows));
ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
MaybeCastNoOverflow<lapack_int>(x_cols));
auto x_leading_dim_v = x_rows_v;
const int64_t x_out_step{x_rows * x_cols};
@ -371,6 +409,7 @@ template <ffi::DataType dtype>
ffi::Error CholeskyFactorization<dtype>::Kernel(
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions));
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
auto* x_out_data = x_out->data;
auto* info_data = info->data;
@ -378,7 +417,8 @@ ffi::Error CholeskyFactorization<dtype>::Kernel(
CopyIfDiffBuffer(x, x_out);
auto uplo_v = static_cast<char>(uplo);
auto x_order_v = CastNoOverflow<lapack_int>(x.dimensions.back());
ASSIGN_OR_RETURN_FFI_ERROR(
auto x_order_v, MaybeCastNoOverflow<lapack_int>(x.dimensions.back()));
auto x_leading_dim_v = x_order_v;
const int64_t x_out_step{x_rows * x_cols};
@ -1077,3 +1117,6 @@ template struct Sytrd<std::complex<float>>;
template struct Sytrd<std::complex<double>>;
} // namespace jax
#undef ASSIGN_OR_RETURN_FFI_ERROR
#undef RETURN_IF_FFI_ERROR