mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Avoid throwing exceptions in LAPACK CPU kernels.
When an FFI kernel is executed, there isn't any global try/except block (I think!) so it's probably a good idea to avoid throwing. Instead, it should be safer to handle mapping failures to ffi::Error manually. PiperOrigin-RevId: 647348889
This commit is contained in:
parent
61185a21ee
commit
98b87540a7
@ -40,6 +40,8 @@ cc_library(
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
@ -30,6 +30,8 @@ limitations under the License.
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/base/dynamic_annotations.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
@ -39,28 +41,61 @@ static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
|
||||
|
||||
namespace ffi = xla::ffi;
|
||||
|
||||
// TODO(danfm): These macros and the casting functions should be moved to a
|
||||
// separate header for use in other FFI kernels.
|
||||
#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs) \
|
||||
if (!rhs.ok()) { \
|
||||
return ffi::Error(static_cast<XLA_FFI_Error_Code>(rhs.status().code()), \
|
||||
std::string(rhs.status().message())); \
|
||||
} \
|
||||
lhs = rhs.value()
|
||||
|
||||
#define RETURN_IF_FFI_ERROR(...) \
|
||||
do { \
|
||||
ffi::Error err = (__VA_ARGS__); \
|
||||
if (err.failure()) { \
|
||||
return err; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
|
||||
inline absl::StatusOr<T> MaybeCastNoOverflow(
|
||||
int64_t value, const std::string& source = __FILE__) {
|
||||
if constexpr (sizeof(T) == sizeof(int64_t)) {
|
||||
return value;
|
||||
} else {
|
||||
if (value > std::numeric_limits<T>::max()) [[unlikely]] {
|
||||
throw std::overflow_error{
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
|
||||
"value of the desired type",
|
||||
source, value)};
|
||||
source, value));
|
||||
}
|
||||
return static_cast<T>(value);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
|
||||
if (dims.size() < 2) {
|
||||
throw std::invalid_argument("Matrix must have at least 2 dimensions");
|
||||
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
|
||||
auto result = MaybeCastNoOverflow<T>(value, source);
|
||||
if (!result.ok()) {
|
||||
throw std::overflow_error{std::string(result.status().message())};
|
||||
}
|
||||
return result.value();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ffi::Error CheckMatrixDimensions(ffi::Span<T> dims) {
|
||||
if (dims.size() < 2) {
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"Matrix must have at least 2 dimensions");
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
|
||||
auto matrix_dims = dims.last(2);
|
||||
return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1,
|
||||
std::multiplies<int64_t>()),
|
||||
@ -201,6 +236,7 @@ ffi::Error LuDecomposition<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<LapackIntDtype> ipiv,
|
||||
ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions));
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
|
||||
auto* x_out_data = x_out->data;
|
||||
auto* ipiv_data = ipiv->data;
|
||||
@ -208,8 +244,10 @@ ffi::Error LuDecomposition<dtype>::Kernel(
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto x_rows_v = CastNoOverflow<lapack_int>(x_rows);
|
||||
auto x_cols_v = CastNoOverflow<lapack_int>(x_cols);
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
auto x_leading_dim_v = x_rows_v;
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
@ -371,6 +409,7 @@ template <ffi::DataType dtype>
|
||||
ffi::Error CholeskyFactorization<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions));
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
|
||||
auto* x_out_data = x_out->data;
|
||||
auto* info_data = info->data;
|
||||
@ -378,7 +417,8 @@ ffi::Error CholeskyFactorization<dtype>::Kernel(
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto uplo_v = static_cast<char>(uplo);
|
||||
auto x_order_v = CastNoOverflow<lapack_int>(x.dimensions.back());
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(
|
||||
auto x_order_v, MaybeCastNoOverflow<lapack_int>(x.dimensions.back()));
|
||||
auto x_leading_dim_v = x_order_v;
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
@ -1077,3 +1117,6 @@ template struct Sytrd<std::complex<float>>;
|
||||
template struct Sytrd<std::complex<double>>;
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#undef ASSIGN_OR_RETURN_FFI_ERROR
|
||||
#undef RETURN_IF_FFI_ERROR
|
||||
|
Loading…
x
Reference in New Issue
Block a user