Merge pull request #19288 from pearu:pearu/int32-overflow

PiperOrigin-RevId: 608701959
This commit is contained in:
jax authors 2024-02-20 12:43:16 -08:00
commit 16b29a6930
2 changed files with 25 additions and 19 deletions

View File

@ -33,6 +33,8 @@ cc_library(
name = "lapack_kernels",
srcs = ["lapack_kernels.cc"],
hdrs = ["lapack_kernels.h"],
copts = ["-fexceptions"],
features = ["-use_header_modules"],
deps = [
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/base:dynamic_annotations",

View File

@ -22,6 +22,21 @@ limitations under the License.
#include "absl/base/dynamic_annotations.h"
namespace {
inline int64_t catch_lapack_int_overflow(const std::string& source, int64_t value) {
if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) {
return value;
} else {
if (value > std::numeric_limits<jax::lapack_int>::max()) {
throw std::overflow_error(source + "(=" + std::to_string(value) + ") exceeds maximum value of jax::lapack_int");
}
return value;
}
}
}
namespace jax {
static_assert(sizeof(lapack_int) == sizeof(int32_t),
@ -252,9 +267,7 @@ static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
}
lapack_int GesddIworkSize(int64_t m, int64_t n) {
// Avoid integer overflow; the LAPACK integer type is int32.
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(),
8 * std::min(m, n));
return catch_lapack_int_overflow("gesdd iwork", 8 * std::min(m, n));
}
template <typename T>
@ -320,12 +333,10 @@ int64_t RealGesdd<T>::Workspace(lapack_int m, lapack_int n,
lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) {
int64_t mn = std::min(m, n);
if (compute_uv == 0) {
return 7 * mn;
return catch_lapack_int_overflow("complex gesdd rwork", 7 * mn);
}
int64_t mx = std::max(m, n);
// Avoid integer overflow; the LAPACK integer type is int32.
return std::min<int64_t>(
std::numeric_limits<lapack_int>::max(),
return catch_lapack_int_overflow("complex gesdd rwork",
std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn));
}
@ -399,14 +410,11 @@ template struct ComplexGesdd<std::complex<double>>;
// # Workspace sizes, taken from the LAPACK documentation.
lapack_int SyevdWorkSize(int64_t n) {
// Avoids int32 overflow.
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(),
1 + 6 * n + 2 * n * n);
return catch_lapack_int_overflow("syevd lwork", 1 + 6 * n + 2 * n * n);
}
lapack_int SyevdIworkSize(int64_t n) {
// Avoids int32 overflow.
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(), 3 + 5 * n);
return catch_lapack_int_overflow("syevd iwork", 3 + 5 * n);
}
template <typename T>
@ -446,15 +454,11 @@ void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
// Workspace sizes, taken from the LAPACK documentation.
lapack_int HeevdWorkSize(int64_t n) {
// Avoid int32 overflow.
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(),
1 + 2 * n + n * n);
return catch_lapack_int_overflow("heevd work", 1 + 2 * n + n * n);
}
lapack_int HeevdRworkSize(int64_t n) {
// Avoid int32 overflow.
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(),
1 + 5 * n + 2 * n * n);
return catch_lapack_int_overflow("heevd rwork", 1 + 5 * n + 2 * n * n);
}
template <typename T>