Suppress memorysanitizer for code that calls LAPACK kernels.

PiperOrigin-RevId: 420325456
This commit is contained in:
Peter Hawkins 2022-01-07 10:47:32 -08:00 committed by jax authors
parent 712ab66f28
commit 548b9446ef
2 changed files with 73 additions and 43 deletions

View File

@ -147,6 +147,7 @@ cc_library(
hdrs = ["lapack_kernels.h"],
deps = [
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:dynamic_annotations",
],
)

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <limits>
#include "absl/base/attributes.h"
#include "absl/base/dynamic_annotations.h"
namespace jax {
@ -533,7 +534,8 @@ typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;
template <typename T>
void RealGeev<T>::Kernel(void* out_tuple, void** data) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int n = *(reinterpret_cast<int32_t*>(data[1]));
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
int64_t n = n_int;
char jobvl = *(reinterpret_cast<uint8_t*>(data[2]));
char jobvr = *(reinterpret_cast<uint8_t*>(data[3]));
@ -553,26 +555,33 @@ void RealGeev<T>::Kernel(void* out_tuple, void** data) {
// TODO(phawkins): preallocate workspace using XLA.
T work_query;
int lwork = -1;
fn(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n, vr_work, &n,
&work_query, &lwork, info_out);
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int,
vr_work, &n_int, &work_query, &lwork, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
lwork = static_cast<int>(work_query);
T* work = new T[lwork];
for (int i = 0; i < b; ++i) {
std::memcpy(a_work, a_in,
static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T));
fn(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n, vr_work, &n,
work, &lwork, info_out);
size_t a_size = n * n * sizeof(T);
std::memcpy(a_work, a_in, a_size);
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int,
vr_work, &n_int, work, &lwork, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
if (info_out[0] == 0) {
UnpackEigenvectors(n, wi_out, vl_work, vl_out);
UnpackEigenvectors(n, wi_out, vr_work, vr_out);
}
a_in += static_cast<int64_t>(n) * n;
a_in += n * n;
wr_out += n;
wi_out += n;
vl_out += static_cast<int64_t>(n) * n;
vr_out += static_cast<int64_t>(n) * n;
vl_out += n * n;
vr_out += n * n;
++info_out;
}
delete[] work;
@ -584,7 +593,8 @@ typename ComplexGeev<T>::FnType* ComplexGeev<T>::fn = nullptr;
template <typename T>
void ComplexGeev<T>::Kernel(void* out_tuple, void** data) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int n = *(reinterpret_cast<int32_t*>(data[1]));
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
int64_t n = n_int;
char jobvl = *(reinterpret_cast<uint8_t*>(data[2]));
char jobvr = *(reinterpret_cast<uint8_t*>(data[3]));
@ -603,21 +613,26 @@ void ComplexGeev<T>::Kernel(void* out_tuple, void** data) {
// TODO(phawkins): preallocate workspace using XLA.
T work_query;
int lwork = -1;
fn(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n, &work_query,
&lwork, r_work, info_out);
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
&n_int, &work_query, &lwork, r_work, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
lwork = static_cast<int>(work_query.real());
T* work = new T[lwork];
for (int i = 0; i < b; ++i) {
std::memcpy(a_work, a_in,
static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T));
fn(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n, work,
&lwork, r_work, info_out);
a_in += static_cast<int64_t>(n) * n;
size_t a_size = n * n * sizeof(T);
std::memcpy(a_work, a_in, a_size);
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
&n_int, work, &lwork, r_work, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
a_in += n * n;
w_out += n;
vl_out += static_cast<int64_t>(n) * n;
vr_out += static_cast<int64_t>(n) * n;
vl_out += n * n;
vr_out += n * n;
info_out += 1;
}
delete[] work;
@ -636,7 +651,8 @@ typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;
template <typename T>
void RealGees<T>::Kernel(void* out_tuple, void** data) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int n = *(reinterpret_cast<int32_t*>(data[1]));
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
int64_t n = n_int;
char jobvs = *(reinterpret_cast<uint8_t*>(data[2]));
char sort = *(reinterpret_cast<uint8_t*>(data[3]));
@ -659,22 +675,29 @@ void RealGees<T>::Kernel(void* out_tuple, void** data) {
T work_query;
int lwork = -1;
fn(&jobvs, &sort, select, &n, a_work, &n, sdim_out, wr_out, wi_out, vs_out,
&n, &work_query, &lwork, b_work, info_out);
fn(&jobvs, &sort, select, &n_int, a_work, &n_int, sdim_out, wr_out, wi_out,
vs_out, &n_int, &work_query, &lwork, b_work, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
lwork = static_cast<int>(work_query);
T* work = new T[lwork];
for (int i = 0; i < b; ++i) {
std::memcpy(a_work, a_in,
static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T));
fn(&jobvs, &sort, select, &n, a_work, &n, sdim_out, wr_out, wi_out, vs_out,
&n, work, &lwork, b_work, info_out);
size_t a_size = n * n * sizeof(T);
std::memcpy(a_work, a_in, a_size);
fn(&jobvs, &sort, select, &n_int, a_work, &n_int, sdim_out, wr_out, wi_out,
vs_out, &n_int, work, &lwork, b_work, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
a_in += static_cast<int64_t>(n) * n;
a_work += static_cast<int64_t>(n) * n;
a_in += n * n;
a_work += n * n;
wr_out += n;
wi_out += n;
vs_out += static_cast<int64_t>(n) * n;
vs_out += n * n;
++sdim_out;
++info_out;
}
@ -688,7 +711,8 @@ typename ComplexGees<T>::FnType* ComplexGees<T>::fn = nullptr;
template <typename T>
void ComplexGees<T>::Kernel(void* out_tuple, void** data) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int n = *(reinterpret_cast<int32_t*>(data[1]));
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
int64_t n = n_int;
char jobvs = *(reinterpret_cast<uint8_t*>(data[2]));
char sort = *(reinterpret_cast<uint8_t*>(data[3]));
@ -706,26 +730,31 @@ void ComplexGees<T>::Kernel(void* out_tuple, void** data) {
int* sdim_out = reinterpret_cast<int*>(out[4]);
int* info_out = reinterpret_cast<int*>(out[5]);
bool* b_work;
if (sort == 'N') b_work = new bool[n];
bool* b_work = nullptr;
if (sort != 'N') b_work = new bool[n];
T work_query;
int lwork = -1;
fn(&jobvs, &sort, select, &n, a_work, &n, sdim_out, w_out, vs_out, &n,
&work_query, &lwork, r_work, b_work, info_out);
fn(&jobvs, &sort, select, &n_int, a_work, &n_int, sdim_out, w_out, vs_out,
&n_int, &work_query, &lwork, r_work, b_work, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
lwork = static_cast<int>(work_query.real());
T* work = new T[lwork];
for (int i = 0; i < b; ++i) {
std::memcpy(a_work, a_in,
static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T));
fn(&jobvs, &sort, select, &n, a_work, &n, sdim_out, w_out, vs_out, &n, work,
&lwork, r_work, b_work, info_out);
size_t a_size = n * n * sizeof(T);
std::memcpy(a_work, a_in, a_size);
fn(&jobvs, &sort, select, &n_int, a_work, &n_int, sdim_out, w_out, vs_out,
&n_int, work, &lwork, r_work, b_work, info_out);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int));
a_in += static_cast<int64_t>(n) * n;
a_work += static_cast<int64_t>(n) * n;
a_in += n * n;
a_work += n * n;
w_out += n;
vs_out += static_cast<int64_t>(n) * n;
vs_out += n * n;
++info_out;
++sdim_out;
}