From 3fa1033ac1970e2d25ea85a1d39cf3271a2af0a5 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 10 Jan 2024 17:50:56 +0200 Subject: [PATCH] Prevent silent overflow in lapack worker size calculations. Add -fexceptions to building lapack_kernels --- jaxlib/cpu/BUILD | 2 ++ jaxlib/cpu/lapack_kernels.cc | 42 ++++++++++++++++++++---------------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 10998d475..5d7a5f614 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -37,6 +37,8 @@ cc_library( "@xla//xla/service:custom_call_status", "@com_google_absl//absl/base:dynamic_annotations", ], + copts = ["-fexceptions"], + features = ["-use_header_modules"], ) cc_library( diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 211d11020..00b54bab0 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -22,6 +22,21 @@ limitations under the License. #include "absl/base/dynamic_annotations.h" +namespace { + +inline int64_t catch_lapack_int_overflow(const std::string& source, int64_t value) { + if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) { + return value; + } else { + if (value > std::numeric_limits::max()) { + throw std::overflow_error(source + "(=" + std::to_string(value) + ") exceeds maximum value of jax::lapack_int"); + } + return value; + } +} + +} + namespace jax { static_assert(sizeof(lapack_int) == sizeof(int32_t), @@ -252,9 +267,7 @@ static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { } lapack_int GesddIworkSize(int64_t m, int64_t n) { - // Avoid integer overflow; the LAPACK integer type is int32. - return std::min(std::numeric_limits::max(), - 8 * std::min(m, n)); + return catch_lapack_int_overflow("gesdd iwork", 8 * std::min(m, n)); } template @@ -320,13 +333,11 @@ int64_t RealGesdd::Workspace(lapack_int m, lapack_int n, lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) { int64_t mn = std::min(m, n); if (compute_uv == 0) { - return 7 * mn; + return catch_lapack_int_overflow("complex gesdd rwork", 7 * mn); } int64_t mx = std::max(m, n); - // Avoid integer overflow; the LAPACK integer type is int32. - return std::min( - std::numeric_limits::max(), - std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn)); + return catch_lapack_int_overflow("complex gesdd rwork", + std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn)); } template @@ -399,14 +410,11 @@ template struct ComplexGesdd>; // # Workspace sizes, taken from the LAPACK documentation. lapack_int SyevdWorkSize(int64_t n) { - // Avoids int32 overflow. - return std::min(std::numeric_limits::max(), - 1 + 6 * n + 2 * n * n); + return catch_lapack_int_overflow("syevd lwork", 1 + 6 * n + 2 * n * n); } lapack_int SyevdIworkSize(int64_t n) { - // Avoids int32 overflow. - return std::min(std::numeric_limits::max(), 3 + 5 * n); + return catch_lapack_int_overflow("syevd iwork", 3 + 5 * n); } template @@ -446,15 +454,11 @@ void RealSyevd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { // Workspace sizes, taken from the LAPACK documentation. lapack_int HeevdWorkSize(int64_t n) { - // Avoid int32 overflow. - return std::min(std::numeric_limits::max(), - 1 + 2 * n + n * n); + return catch_lapack_int_overflow("heevd work", 1 + 2 * n + n * n); } lapack_int HeevdRworkSize(int64_t n) { - // Avoid int32 overflow. - return std::min(std::numeric_limits::max(), - 1 + 5 * n + 2 * n * n); + return catch_lapack_int_overflow("heevd rwork", 1 + 5 * n + 2 * n * n); } template