mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Suppress memorysanitizer for code that calls LAPACK kernels.
PiperOrigin-RevId: 420325456
This commit is contained in:
parent
712ab66f28
commit
548b9446ef
@ -147,6 +147,7 @@ cc_library(
|
||||
hdrs = ["lapack_kernels.h"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <limits>
|
||||
|
||||
#include "absl/base/attributes.h"
|
||||
#include "absl/base/dynamic_annotations.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
@ -533,7 +534,8 @@ typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;
|
||||
template <typename T>
|
||||
void RealGeev<T>::Kernel(void* out_tuple, void** data) {
|
||||
int b = *(reinterpret_cast<int32_t*>(data[0]));
|
||||
int n = *(reinterpret_cast<int32_t*>(data[1]));
|
||||
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
|
||||
int64_t n = n_int;
|
||||
char jobvl = *(reinterpret_cast<uint8_t*>(data[2]));
|
||||
char jobvr = *(reinterpret_cast<uint8_t*>(data[3]));
|
||||
|
||||
@ -553,26 +555,33 @@ void RealGeev<T>::Kernel(void* out_tuple, void** data) {
|
||||
// TODO(phawkins): preallocate workspace using XLA.
|
||||
T work_query;
|
||||
int lwork = -1;
|
||||
fn(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n, vr_work, &n,
|
||||
&work_query, &lwork, info_out);
|
||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int,
|
||||
vr_work, &n_int, &work_query, &lwork, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
|
||||
lwork = static_cast<int>(work_query);
|
||||
T* work = new T[lwork];
|
||||
|
||||
for (int i = 0; i < b; ++i) {
|
||||
std::memcpy(a_work, a_in,
|
||||
static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T));
|
||||
fn(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n, vr_work, &n,
|
||||
work, &lwork, info_out);
|
||||
size_t a_size = n * n * sizeof(T);
|
||||
std::memcpy(a_work, a_in, a_size);
|
||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int,
|
||||
vr_work, &n_int, work, &lwork, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||
if (info_out[0] == 0) {
|
||||
UnpackEigenvectors(n, wi_out, vl_work, vl_out);
|
||||
UnpackEigenvectors(n, wi_out, vr_work, vr_out);
|
||||
}
|
||||
|
||||
a_in += static_cast<int64_t>(n) * n;
|
||||
a_in += n * n;
|
||||
wr_out += n;
|
||||
wi_out += n;
|
||||
vl_out += static_cast<int64_t>(n) * n;
|
||||
vr_out += static_cast<int64_t>(n) * n;
|
||||
vl_out += n * n;
|
||||
vr_out += n * n;
|
||||
++info_out;
|
||||
}
|
||||
delete[] work;
|
||||
@ -584,7 +593,8 @@ typename ComplexGeev<T>::FnType* ComplexGeev<T>::fn = nullptr;
|
||||
template <typename T>
|
||||
void ComplexGeev<T>::Kernel(void* out_tuple, void** data) {
|
||||
int b = *(reinterpret_cast<int32_t*>(data[0]));
|
||||
int n = *(reinterpret_cast<int32_t*>(data[1]));
|
||||
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
|
||||
int64_t n = n_int;
|
||||
char jobvl = *(reinterpret_cast<uint8_t*>(data[2]));
|
||||
char jobvr = *(reinterpret_cast<uint8_t*>(data[3]));
|
||||
|
||||
@ -603,21 +613,26 @@ void ComplexGeev<T>::Kernel(void* out_tuple, void** data) {
|
||||
// TODO(phawkins): preallocate workspace using XLA.
|
||||
T work_query;
|
||||
int lwork = -1;
|
||||
fn(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n, &work_query,
|
||||
&lwork, r_work, info_out);
|
||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
|
||||
&n_int, &work_query, &lwork, r_work, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
|
||||
lwork = static_cast<int>(work_query.real());
|
||||
T* work = new T[lwork];
|
||||
|
||||
for (int i = 0; i < b; ++i) {
|
||||
std::memcpy(a_work, a_in,
|
||||
static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T));
|
||||
fn(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n, work,
|
||||
&lwork, r_work, info_out);
|
||||
|
||||
a_in += static_cast<int64_t>(n) * n;
|
||||
size_t a_size = n * n * sizeof(T);
|
||||
std::memcpy(a_work, a_in, a_size);
|
||||
fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
|
||||
&n_int, work, &lwork, r_work, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||
a_in += n * n;
|
||||
w_out += n;
|
||||
vl_out += static_cast<int64_t>(n) * n;
|
||||
vr_out += static_cast<int64_t>(n) * n;
|
||||
vl_out += n * n;
|
||||
vr_out += n * n;
|
||||
info_out += 1;
|
||||
}
|
||||
delete[] work;
|
||||
@ -636,7 +651,8 @@ typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;
|
||||
template <typename T>
|
||||
void RealGees<T>::Kernel(void* out_tuple, void** data) {
|
||||
int b = *(reinterpret_cast<int32_t*>(data[0]));
|
||||
int n = *(reinterpret_cast<int32_t*>(data[1]));
|
||||
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
|
||||
int64_t n = n_int;
|
||||
char jobvs = *(reinterpret_cast<uint8_t*>(data[2]));
|
||||
char sort = *(reinterpret_cast<uint8_t*>(data[3]));
|
||||
|
||||
@ -659,22 +675,29 @@ void RealGees<T>::Kernel(void* out_tuple, void** data) {
|
||||
|
||||
T work_query;
|
||||
int lwork = -1;
|
||||
fn(&jobvs, &sort, select, &n, a_work, &n, sdim_out, wr_out, wi_out, vs_out,
|
||||
&n, &work_query, &lwork, b_work, info_out);
|
||||
fn(&jobvs, &sort, select, &n_int, a_work, &n_int, sdim_out, wr_out, wi_out,
|
||||
vs_out, &n_int, &work_query, &lwork, b_work, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
|
||||
lwork = static_cast<int>(work_query);
|
||||
T* work = new T[lwork];
|
||||
|
||||
for (int i = 0; i < b; ++i) {
|
||||
std::memcpy(a_work, a_in,
|
||||
static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T));
|
||||
fn(&jobvs, &sort, select, &n, a_work, &n, sdim_out, wr_out, wi_out, vs_out,
|
||||
&n, work, &lwork, b_work, info_out);
|
||||
size_t a_size = n * n * sizeof(T);
|
||||
std::memcpy(a_work, a_in, a_size);
|
||||
fn(&jobvs, &sort, select, &n_int, a_work, &n_int, sdim_out, wr_out, wi_out,
|
||||
vs_out, &n_int, work, &lwork, b_work, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||
|
||||
a_in += static_cast<int64_t>(n) * n;
|
||||
a_work += static_cast<int64_t>(n) * n;
|
||||
a_in += n * n;
|
||||
a_work += n * n;
|
||||
wr_out += n;
|
||||
wi_out += n;
|
||||
vs_out += static_cast<int64_t>(n) * n;
|
||||
vs_out += n * n;
|
||||
++sdim_out;
|
||||
++info_out;
|
||||
}
|
||||
@ -688,7 +711,8 @@ typename ComplexGees<T>::FnType* ComplexGees<T>::fn = nullptr;
|
||||
template <typename T>
|
||||
void ComplexGees<T>::Kernel(void* out_tuple, void** data) {
|
||||
int b = *(reinterpret_cast<int32_t*>(data[0]));
|
||||
int n = *(reinterpret_cast<int32_t*>(data[1]));
|
||||
int n_int = *(reinterpret_cast<int32_t*>(data[1]));
|
||||
int64_t n = n_int;
|
||||
char jobvs = *(reinterpret_cast<uint8_t*>(data[2]));
|
||||
char sort = *(reinterpret_cast<uint8_t*>(data[3]));
|
||||
|
||||
@ -706,26 +730,31 @@ void ComplexGees<T>::Kernel(void* out_tuple, void** data) {
|
||||
int* sdim_out = reinterpret_cast<int*>(out[4]);
|
||||
int* info_out = reinterpret_cast<int*>(out[5]);
|
||||
|
||||
bool* b_work;
|
||||
if (sort == 'N') b_work = new bool[n];
|
||||
bool* b_work = nullptr;
|
||||
if (sort != 'N') b_work = new bool[n];
|
||||
|
||||
T work_query;
|
||||
int lwork = -1;
|
||||
fn(&jobvs, &sort, select, &n, a_work, &n, sdim_out, w_out, vs_out, &n,
|
||||
&work_query, &lwork, r_work, b_work, info_out);
|
||||
fn(&jobvs, &sort, select, &n_int, a_work, &n_int, sdim_out, w_out, vs_out,
|
||||
&n_int, &work_query, &lwork, r_work, b_work, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
|
||||
lwork = static_cast<int>(work_query.real());
|
||||
T* work = new T[lwork];
|
||||
|
||||
for (int i = 0; i < b; ++i) {
|
||||
std::memcpy(a_work, a_in,
|
||||
static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T));
|
||||
fn(&jobvs, &sort, select, &n, a_work, &n, sdim_out, w_out, vs_out, &n, work,
|
||||
&lwork, r_work, b_work, info_out);
|
||||
size_t a_size = n * n * sizeof(T);
|
||||
std::memcpy(a_work, a_in, a_size);
|
||||
fn(&jobvs, &sort, select, &n_int, a_work, &n_int, sdim_out, w_out, vs_out,
|
||||
&n_int, work, &lwork, r_work, b_work, info_out);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int));
|
||||
|
||||
a_in += static_cast<int64_t>(n) * n;
|
||||
a_work += static_cast<int64_t>(n) * n;
|
||||
a_in += n * n;
|
||||
a_work += n * n;
|
||||
w_out += n;
|
||||
vs_out += static_cast<int64_t>(n) * n;
|
||||
vs_out += n * n;
|
||||
++info_out;
|
||||
++sdim_out;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user