mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #19288 from pearu:pearu/int32-overflow
PiperOrigin-RevId: 608701959
This commit is contained in:
commit
16b29a6930
@ -33,6 +33,8 @@ cc_library(
|
||||
name = "lapack_kernels",
|
||||
srcs = ["lapack_kernels.cc"],
|
||||
hdrs = ["lapack_kernels.h"],
|
||||
copts = ["-fexceptions"],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
|
@ -22,6 +22,21 @@ limitations under the License.
|
||||
|
||||
#include "absl/base/dynamic_annotations.h"
|
||||
|
||||
namespace {
|
||||
|
||||
inline int64_t catch_lapack_int_overflow(const std::string& source, int64_t value) {
|
||||
if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) {
|
||||
return value;
|
||||
} else {
|
||||
if (value > std::numeric_limits<jax::lapack_int>::max()) {
|
||||
throw std::overflow_error(source + "(=" + std::to_string(value) + ") exceeds maximum value of jax::lapack_int");
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
namespace jax {
|
||||
|
||||
static_assert(sizeof(lapack_int) == sizeof(int32_t),
|
||||
@ -252,9 +267,7 @@ static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
|
||||
}
|
||||
|
||||
lapack_int GesddIworkSize(int64_t m, int64_t n) {
|
||||
// Avoid integer overflow; the LAPACK integer type is int32.
|
||||
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(),
|
||||
8 * std::min(m, n));
|
||||
return catch_lapack_int_overflow("gesdd iwork", 8 * std::min(m, n));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -320,12 +333,10 @@ int64_t RealGesdd<T>::Workspace(lapack_int m, lapack_int n,
|
||||
lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) {
|
||||
int64_t mn = std::min(m, n);
|
||||
if (compute_uv == 0) {
|
||||
return 7 * mn;
|
||||
return catch_lapack_int_overflow("complex gesdd rwork", 7 * mn);
|
||||
}
|
||||
int64_t mx = std::max(m, n);
|
||||
// Avoid integer overflow; the LAPACK integer type is int32.
|
||||
return std::min<int64_t>(
|
||||
std::numeric_limits<lapack_int>::max(),
|
||||
return catch_lapack_int_overflow("complex gesdd rwork",
|
||||
std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn));
|
||||
}
|
||||
|
||||
@ -399,14 +410,11 @@ template struct ComplexGesdd<std::complex<double>>;
|
||||
|
||||
// # Workspace sizes, taken from the LAPACK documentation.
|
||||
lapack_int SyevdWorkSize(int64_t n) {
|
||||
// Avoids int32 overflow.
|
||||
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(),
|
||||
1 + 6 * n + 2 * n * n);
|
||||
return catch_lapack_int_overflow("syevd lwork", 1 + 6 * n + 2 * n * n);
|
||||
}
|
||||
|
||||
lapack_int SyevdIworkSize(int64_t n) {
|
||||
// Avoids int32 overflow.
|
||||
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(), 3 + 5 * n);
|
||||
return catch_lapack_int_overflow("syevd iwork", 3 + 5 * n);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -446,15 +454,11 @@ void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
|
||||
|
||||
// Workspace sizes, taken from the LAPACK documentation.
|
||||
lapack_int HeevdWorkSize(int64_t n) {
|
||||
// Avoid int32 overflow.
|
||||
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(),
|
||||
1 + 2 * n + n * n);
|
||||
return catch_lapack_int_overflow("heevd work", 1 + 2 * n + n * n);
|
||||
}
|
||||
|
||||
lapack_int HeevdRworkSize(int64_t n) {
|
||||
// Avoid int32 overflow.
|
||||
return std::min<int64_t>(std::numeric_limits<lapack_int>::max(),
|
||||
1 + 5 * n + 2 * n * n);
|
||||
return catch_lapack_int_overflow("heevd rwork", 1 + 5 * n + 2 * n * n);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
Loading…
x
Reference in New Issue
Block a user