/* Copyright 2021 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/cpu/lapack_kernels.h"

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>  // std::memcpy, used throughout this file
#include <functional>
#include <limits>
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>

#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
              "Expected LAPACK integers to be 32-bit");

namespace ffi = xla::ffi;

// TODO(danfm): These macros and the casting functions should be moved to a
// separate header for use in other FFI kernels.
#define ASSIGN_OR_RETURN_FFI_ERROR(lhs, rhs)                                 \
  if (!rhs.ok()) {                                                           \
    return ffi::Error(static_cast<XLA_FFI_Error_Code>(rhs.status().code()),  \
                      std::string(rhs.status().message()));                  \
  }                                                                          \
  lhs = rhs.value()

#define RETURN_IF_FFI_ERROR(...)      \
  do {                                \
    ffi::Error err = (__VA_ARGS__);   \
    if (err.failure()) {              \
      return err;                     \
    }                                 \
  } while (0)
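
// A sketch of how the two macros above are used by the FFI kernels later in
// this file (illustrative only; `some_dim` stands in for any int64_t extent
// that must be narrowed to a LAPACK integer):
//
//   RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
//   ASSIGN_OR_RETURN_FFI_ERROR(auto n,
//                              MaybeCastNoOverflow<lapack_int>(some_dim));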

namespace {

template <typename T>
inline absl::StatusOr<T> MaybeCastNoOverflow(
    int64_t value, const std::string& source = __FILE__) {
  if constexpr (sizeof(T) == sizeof(int64_t)) {
    return value;
  } else {
    if (value > std::numeric_limits<T>::max()) [[unlikely]] {
      return absl::InvalidArgumentError(
          absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
                          "value of the desired type",
                          source, value));
    }
    return static_cast<T>(value);
  }
}

template <typename T>
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
  auto result = MaybeCastNoOverflow<T>(value, source);
  if (!result.ok()) {
    throw std::overflow_error{std::string(result.status().message())};
  }
  return result.value();
}

template <typename T>
ffi::Error CheckMatrixDimensions(ffi::Span<T> dims) {
  if (dims.size() < 2) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "Matrix must have at least 2 dimensions");
  }
  return ffi::Error::Success();
}

template <typename T>
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
  auto matrix_dims = dims.last(2);
  return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1,
                                            std::multiplies<int64_t>()),
                         matrix_dims.front(), matrix_dims.back());
}
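
// For example, dimensions {2, 3, 4, 5} describe 2 * 3 = 6 stacked 4x5
// matrices, so SplitBatch2D returns (6, 4, 5); a plain 2D shape {4, 5} has an
// empty batch prefix and yields (1, 4, 5).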

template <ffi::DataType dtype>
void CopyIfDiffBuffer(ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out) {
  auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
  if (x.typed_data() != x_out->typed_data()) {
    const auto x_size = batch_count * x_rows * x_cols;
    std::copy_n(x.typed_data(), x_size, x_out->typed_data());
  }
}
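
// LAPACK routines overwrite their input in place, while XLA may or may not
// alias the operand buffer with the result buffer; the guard above makes the
// copy only when the two buffers are actually distinct.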

}  // namespace

#define REGISTER_CHAR_ENUM_ATTR_DECODING(type)                                \
  std::optional<type> xla::ffi::AttrDecoding<type>::Decode(                   \
      XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \
    if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] {                  \
      return diagnostic.Emit("Wrong attribute type: expected ")               \
             << XLA_FFI_AttrType_SCALAR << " but got " << attr_type;          \
    }                                                                         \
    auto* scalar = reinterpret_cast<XLA_FFI_Scalar*>(attr);                   \
    if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] {                  \
      return diagnostic.Emit("Wrong scalar data type: expected ")             \
             << XLA_FFI_DataType_U8 << " but got " << scalar->dtype;          \
    }                                                                         \
    auto underlying =                                                         \
        *reinterpret_cast<std::underlying_type_t<type>*>(scalar->value);      \
    return static_cast<type>(underlying);                                     \
  }
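
// The registered enums arrive from the caller as u8 scalar FFI attributes.
// Judging by the static_cast<char> conversions later in this file, their
// underlying values are the single-character LAPACK flags (e.g. 'L'/'U' for
// UpLo), so decoding reduces to reinterpreting the scalar's byte.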

REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);

#undef REGISTER_CHAR_ENUM_ATTR_DECODING

namespace jax {

//== Triangular System Solver ==//

// lapack trsm
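
// ?trsm solves op(A) * X = alpha * B (or X * op(A) = alpha * B for a
// right-side solve) with a triangular A, overwriting B with the solution X.
// The batched buffers are stacked contiguously, one column-major matrix after
// another, as the per-iteration pointer strides below assume.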

template <typename T>
typename Trsm<T>::FnType* Trsm<T>::fn = nullptr;

template <typename T>
void Trsm<T>::Kernel(void* out, void** data, XlaCustomCallStatus*) {
  int32_t left_side = *reinterpret_cast<int32_t*>(data[0]);
  int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
  int32_t trans_a = *reinterpret_cast<int32_t*>(data[2]);
  int32_t diag = *reinterpret_cast<int32_t*>(data[3]);
  int m = *reinterpret_cast<int32_t*>(data[4]);
  int n = *reinterpret_cast<int32_t*>(data[5]);
  int batch = *reinterpret_cast<int32_t*>(data[6]);
  T* alpha = reinterpret_cast<T*>(data[7]);
  T* a = reinterpret_cast<T*>(data[8]);
  T* b = reinterpret_cast<T*>(data[9]);

  T* x = reinterpret_cast<T*>(out);
  if (x != b) {
    std::memcpy(x, b,
                static_cast<int64_t>(batch) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  char cside = left_side ? 'L' : 'R';
  char cuplo = lower ? 'L' : 'U';
  char ctransa = 'N';
  if (trans_a == 1) {
    ctransa = 'T';
  } else if (trans_a == 2) {
    ctransa = 'C';
  }
  char cdiag = diag ? 'U' : 'N';
  int lda = left_side ? m : n;
  int ldb = m;

  int64_t x_plus = static_cast<int64_t>(m) * static_cast<int64_t>(n);
  int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(lda);

  for (int i = 0; i < batch; ++i) {
    fn(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb);
    x += x_plus;
    a += a_plus;
  }
}

template struct Trsm<float>;
template struct Trsm<double>;
template struct Trsm<std::complex<float>>;
template struct Trsm<std::complex<double>>;

//== LU Decomposition ==//

// lapack getrf
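
// ?getrf computes the LU decomposition with partial pivoting, P * A = L * U,
// overwriting A with L and U; ipiv receives min(m, n) one-based pivot indices,
// and a positive info value flags an exactly singular U.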

template <typename T>
typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;

template <typename T>
void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int m = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  const T* a_in = reinterpret_cast<T*>(data[3]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  int* ipiv = reinterpret_cast<int*>(out[1]);
  int* info = reinterpret_cast<int*>(out[2]);
  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }
  for (int i = 0; i < b; ++i) {
    fn(&m, &n, a_out, &m, ipiv, info);
    a_out += static_cast<int64_t>(m) * static_cast<int64_t>(n);
    ipiv += std::min(m, n);
    ++info;
  }
}

template struct Getrf<float>;
template struct Getrf<double>;
template struct Getrf<std::complex<float>>;
template struct Getrf<std::complex<double>>;

// FFI Kernel

template <ffi::DataType dtype>
ffi::Error LuDecomposition<dtype>::Kernel(
    ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
    ffi::ResultBuffer<LapackIntDtype> ipiv,
    ffi::ResultBuffer<LapackIntDtype> info) {
  RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
  auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
  auto* x_out_data = x_out->typed_data();
  auto* ipiv_data = ipiv->typed_data();
  auto* info_data = info->typed_data();

  CopyIfDiffBuffer(x, x_out);

  ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
                             MaybeCastNoOverflow<lapack_int>(x_rows));
  ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
                             MaybeCastNoOverflow<lapack_int>(x_cols));
  auto x_leading_dim_v = x_rows_v;

  const int64_t x_out_step{x_rows * x_cols};
  const int64_t ipiv_step{std::min(x_rows, x_cols)};
  for (int64_t i = 0; i < batch_count; ++i) {
    fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, ipiv_data,
       info_data);
    x_out_data += x_out_step;
    ipiv_data += ipiv_step;
    ++info_data;
  }
  return ffi::Error::Success();
}

template struct LuDecomposition<ffi::DataType::F32>;
template struct LuDecomposition<ffi::DataType::F64>;
template struct LuDecomposition<ffi::DataType::C64>;
template struct LuDecomposition<ffi::DataType::C128>;

//== QR Factorization ==//

// lapack geqrf
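
// ?geqrf overwrites A with R (in and above the diagonal) and the Householder
// reflectors that implicitly represent Q (below the diagonal); tau holds the
// min(m, n) scalar factors of those reflectors.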

template <typename T>
typename Geqrf<T>::FnType* Geqrf<T>::fn = nullptr;

template <typename T>
void Geqrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int m = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  int lwork = *(reinterpret_cast<int32_t*>(data[3]));
  const T* a_in = reinterpret_cast<T*>(data[4]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  T* tau = reinterpret_cast<T*>(out[1]);
  int* info = reinterpret_cast<int*>(out[2]);
  T* work = reinterpret_cast<T*>(out[3]);

  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  for (int i = 0; i < b; ++i) {
    fn(&m, &n, a_out, &m, tau, work, &lwork, info);
    a_out += static_cast<int64_t>(m) * static_cast<int64_t>(n);
    tau += std::min(m, n);
    ++info;
  }
}

template <typename T>
int64_t Geqrf<T>::Workspace(lapack_int m, lapack_int n) {
  T work = 0;
  lapack_int lwork = -1;
  lapack_int info = 0;
  fn(&m, &n, nullptr, &m, nullptr, &work, &lwork, &info);
  return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
}
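
// Calling a LAPACK routine with lwork == -1 is the standard workspace query:
// the routine performs no factorization and instead writes the optimal
// workspace size into `work`, which Workspace() forwards to the caller. The
// other Workspace() helpers in this file follow the same convention.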

template struct Geqrf<float>;
template struct Geqrf<double>;
template struct Geqrf<std::complex<float>>;
template struct Geqrf<std::complex<double>>;

//== Orthogonal QR ==//
//== Computes orthogonal matrix Q from QR Decomposition ==//

// lapack orgqr

template <typename T>
typename Orgqr<T>::FnType* Orgqr<T>::fn = nullptr;

template <typename T>
void Orgqr<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int m = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  int k = *(reinterpret_cast<int32_t*>(data[3]));
  int lwork = *(reinterpret_cast<int32_t*>(data[4]));
  const T* a_in = reinterpret_cast<T*>(data[5]);
  T* tau = reinterpret_cast<T*>(data[6]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  int* info = reinterpret_cast<int*>(out[1]);
  T* work = reinterpret_cast<T*>(out[2]);

  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  for (int i = 0; i < b; ++i) {
    fn(&m, &n, &k, a_out, &m, tau, work, &lwork, info);
    a_out += static_cast<int64_t>(m) * static_cast<int64_t>(n);
    tau += k;
    ++info;
  }
}

template <typename T>
int64_t Orgqr<T>::Workspace(int m, int n, int k) {
  T work = 0;
  int lwork = -1;
  int info = 0;
  fn(&m, &n, &k, nullptr, &m, nullptr, &work, &lwork, &info);
  return info ? -1 : static_cast<int64_t>(std::real(work));
}

template struct Orgqr<float>;
template struct Orgqr<double>;
template struct Orgqr<std::complex<float>>;
template struct Orgqr<std::complex<double>>;

//== Cholesky Factorization ==//

// lapack potrf
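
// ?potrf computes the Cholesky factorization A = L * L^H (or U^H * U),
// reading only the requested triangle; a positive info value means the
// leading minor of that order is not positive definite.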

template <typename T>
typename Potrf<T>::FnType* Potrf<T>::fn = nullptr;

template <typename T>
void Potrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
  int b = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  const T* a_in = reinterpret_cast<T*>(data[3]);
  char uplo = lower ? 'L' : 'U';

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  int* info = reinterpret_cast<int*>(out[1]);
  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(n) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  for (int i = 0; i < b; ++i) {
    fn(&uplo, &n, a_out, &n, info);
    a_out += static_cast<int64_t>(n) * static_cast<int64_t>(n);
    ++info;
  }
}

template struct Potrf<float>;
template struct Potrf<double>;
template struct Potrf<std::complex<float>>;
template struct Potrf<std::complex<double>>;

// FFI Kernel

template <ffi::DataType dtype>
ffi::Error CholeskyFactorization<dtype>::Kernel(
    ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
    ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
  RETURN_IF_FFI_ERROR(CheckMatrixDimensions(x.dimensions()));
  auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
  auto* x_out_data = x_out->typed_data();
  auto* info_data = info->typed_data();

  CopyIfDiffBuffer(x, x_out);

  auto uplo_v = static_cast<char>(uplo);
  ASSIGN_OR_RETURN_FFI_ERROR(
      auto x_order_v, MaybeCastNoOverflow<lapack_int>(x.dimensions().back()));
  auto x_leading_dim_v = x_order_v;

  const int64_t x_out_step{x_rows * x_cols};
  for (int64_t i = 0; i < batch_count; ++i) {
    fn(&uplo_v, &x_order_v, x_out_data, &x_leading_dim_v, info_data);
    x_out_data += x_out_step;
    ++info_data;
  }
  return ffi::Error::Success();
}

template struct CholeskyFactorization<ffi::DataType::F32>;
template struct CholeskyFactorization<ffi::DataType::F64>;
template struct CholeskyFactorization<ffi::DataType::C64>;
template struct CholeskyFactorization<ffi::DataType::C128>;

//== Singular Value Decomposition (SVD) ==//
//== using a divide and conquer method ==//

// lapack gesdd

static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
  if (!job_opt_compute_uv) {
    return 'N';
  } else if (!job_opt_full_matrices) {
    return 'S';
  }
  return 'A';
}
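
// jobz selects how much of the SVD ?gesdd computes: 'N' for singular values
// only, 'S' for the economy-size U and V^T, and 'A' for the full square
// factors.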

lapack_int GesddIworkSize(int64_t m, int64_t n) {
  return CastNoOverflow<lapack_int>(8 * std::min(m, n), "gesdd iwork");
}

template <typename T>
typename RealGesdd<T>::FnType* RealGesdd<T>::fn = nullptr;

template <typename T>
void RealGesdd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int32_t job_opt_full_matrices = *(reinterpret_cast<int32_t*>(data[0]));
  int32_t job_opt_compute_uv = *(reinterpret_cast<int32_t*>(data[1]));
  int b = *(reinterpret_cast<int32_t*>(data[2]));
  int m = *(reinterpret_cast<int32_t*>(data[3]));
  int n = *(reinterpret_cast<int32_t*>(data[4]));
  int lwork = *(reinterpret_cast<int32_t*>(data[5]));
  T* a_in = reinterpret_cast<T*>(data[6]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  T* s = reinterpret_cast<T*>(out[1]);
  T* u = reinterpret_cast<T*>(out[2]);
  T* vt = reinterpret_cast<T*>(out[3]);
  int* info = reinterpret_cast<int*>(out[4]);
  int* iwork = reinterpret_cast<int*>(out[5]);
  T* work = reinterpret_cast<T*>(out[6]);

  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices);

  int lda = m;
  int ldu = m;
  int tdu = job_opt_full_matrices ? m : std::min(m, n);
  int ldvt = job_opt_full_matrices ? n : std::min(m, n);

  for (int i = 0; i < b; ++i) {
    fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork,
       info);
    a_out += static_cast<int64_t>(m) * n;
    s += std::min(m, n);
    u += static_cast<int64_t>(m) * tdu;
    vt += static_cast<int64_t>(ldvt) * n;
    ++info;
  }
}

template <typename T>
int64_t RealGesdd<T>::Workspace(lapack_int m, lapack_int n,
                                bool job_opt_compute_uv,
                                bool job_opt_full_matrices) {
  T work = 0;
  int lwork = -1;
  int info = 0;
  int ldvt = job_opt_full_matrices ? n : std::min(m, n);
  char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices);
  fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work,
     &lwork, nullptr, &info);
  return info ? -1 : static_cast<int>(work);
}

lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) {
  int64_t mn = std::min(m, n);
  if (compute_uv == 0) {
    return CastNoOverflow<lapack_int>(7 * mn, "complex gesdd rwork");
  }
  int64_t mx = std::max(m, n);
  return CastNoOverflow<lapack_int>(
      std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn),
      "complex gesdd rwork");
}

template <typename T>
typename ComplexGesdd<T>::FnType* ComplexGesdd<T>::fn = nullptr;

template <typename T>
void ComplexGesdd<T>::Kernel(void* out_tuple, void** data,
                             XlaCustomCallStatus*) {
  int32_t job_opt_full_matrices = *(reinterpret_cast<int32_t*>(data[0]));
  int32_t job_opt_compute_uv = *(reinterpret_cast<int32_t*>(data[1]));
  int b = *(reinterpret_cast<int32_t*>(data[2]));
  int m = *(reinterpret_cast<int32_t*>(data[3]));
  int n = *(reinterpret_cast<int32_t*>(data[4]));
  int lwork = *(reinterpret_cast<int32_t*>(data[5]));
  T* a_in = reinterpret_cast<T*>(data[6]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  typename T::value_type* s = reinterpret_cast<typename T::value_type*>(out[1]);
  T* u = reinterpret_cast<T*>(out[2]);
  T* vt = reinterpret_cast<T*>(out[3]);
  int* info = reinterpret_cast<int*>(out[4]);
  int* iwork = reinterpret_cast<int*>(out[5]);
  typename T::value_type* rwork =
      reinterpret_cast<typename T::value_type*>(out[6]);
  T* work = reinterpret_cast<T*>(out[7]);

  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices);

  int lda = m;
  int ldu = m;
  int tdu = job_opt_full_matrices ? m : std::min(m, n);
  int ldvt = job_opt_full_matrices ? n : std::min(m, n);

  for (int i = 0; i < b; ++i) {
    fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork,
       iwork, info);
    a_out += static_cast<int64_t>(m) * n;
    s += std::min(m, n);
    u += static_cast<int64_t>(m) * tdu;
    vt += static_cast<int64_t>(ldvt) * n;
    ++info;
  }
}

template <typename T>
int64_t ComplexGesdd<T>::Workspace(lapack_int m, lapack_int n,
                                   bool job_opt_compute_uv,
                                   bool job_opt_full_matrices) {
  T work = 0;
  int lwork = -1;
  int info = 0;
  int ldvt = job_opt_full_matrices ? n : std::min(m, n);
  char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices);
  fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work,
     &lwork, nullptr, nullptr, &info);
  return info ? -1 : static_cast<int>(work.real());
}

template struct RealGesdd<float>;
template struct RealGesdd<double>;
template struct ComplexGesdd<std::complex<float>>;
template struct ComplexGesdd<std::complex<double>>;

// FFI Kernel

namespace internal {

template <ffi::DataType dtype>
using RealBufferForComplexOrNull =
    std::conditional_t<ffi::IsComplexType<dtype>(),
                       ffi::ResultBuffer<ffi::ToReal(dtype)>, std::nullptr_t>;
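
// Complex-valued ?gesdd takes an extra real-valued rwork buffer that the real
// variants do not; this alias lets the single SvdKernel template below accept
// either an rwork ResultBuffer (complex case) or a plain nullptr (real case).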

template <ffi::DataType dtype>
static ffi::Error SvdKernel(
    ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
    ffi::ResultBuffer<ffi::ToReal(dtype)> singular_values,
    ffi::ResultBuffer<dtype> u, ffi::ResultBuffer<dtype> vt,
    ffi::ResultBuffer<LapackIntDtype> info,
    ffi::ResultBuffer<LapackIntDtype> iwork, ffi::ResultBuffer<dtype> work,
    svd::ComputationMode mode, RealBufferForComplexOrNull<dtype> rwork) {
  if (mode == svd::ComputationMode::kComputeVtOverwriteXPartialU) [[unlikely]] {
    return ffi::Error(
        XLA_FFI_Error_Code_UNIMPLEMENTED,
        "Current implementation does not support this computation mode");
  }
  auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
  auto* x_out_data = x_out->typed_data();
  auto* singular_values_data = singular_values->typed_data();
  auto* u_data = u->typed_data();
  auto* vt_data = vt->typed_data();
  auto* info_data = info->typed_data();
  auto* iwork_data = iwork->typed_data();
  auto* work_data = work->typed_data();

  CopyIfDiffBuffer(x, x_out);

  ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
                             MaybeCastNoOverflow<lapack_int>(x_rows));
  ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
                             MaybeCastNoOverflow<lapack_int>(x_cols));
  auto mode_v = static_cast<char>(mode);
  ASSIGN_OR_RETURN_FFI_ERROR(
      auto workspace_dim_v,
      MaybeCastNoOverflow<lapack_int>(work->dimensions().back()));
  auto x_leading_dim_v = x_rows_v;
  auto u_leading_dim_v = x_rows_v;

  auto u_dims = u->dimensions().last(2);
  auto vt_dims = vt->dimensions().last(2);
  ASSIGN_OR_RETURN_FFI_ERROR(auto vt_leading_dim_v,
                             MaybeCastNoOverflow<lapack_int>(vt_dims.front()));

  const int64_t x_out_step{x_rows * x_cols};
  const int64_t singular_values_step{singular_values->dimensions().back()};
  const int64_t u_step{u_dims.front() * u_dims.back()};
  const int64_t vt_step{vt_leading_dim_v * vt_dims.back()};

  for (int64_t i = 0; i < batch_count; ++i) {
    if constexpr (ffi::IsComplexType<dtype>()) {
      svd::SVDType<dtype>::fn(&mode_v, &x_rows_v, &x_cols_v, x_out_data,
                              &x_leading_dim_v, singular_values_data, u_data,
                              &u_leading_dim_v, vt_data, &vt_leading_dim_v,
                              work_data, &workspace_dim_v, rwork->typed_data(),
                              iwork_data, info_data);
    } else {
      svd::SVDType<dtype>::fn(&mode_v, &x_rows_v, &x_cols_v, x_out_data,
                              &x_leading_dim_v, singular_values_data, u_data,
                              &u_leading_dim_v, vt_data, &vt_leading_dim_v,
                              work_data, &workspace_dim_v, iwork_data,
                              info_data);
    }
    x_out_data += x_out_step;
    singular_values_data += singular_values_step;
    u_data += u_step;
    vt_data += vt_step;
    ++info_data;
  }
  return ffi::Error::Success();
}

template <ffi::DataType dtype>
static int64_t SvdGetWorkspaceSize(lapack_int x_rows, lapack_int x_cols,
                                   svd::ComputationMode mode) {
  ffi::NativeType<dtype> optimal_size = {};
  lapack_int info = 0;
  lapack_int workspace_query = -1;

  auto mode_v = static_cast<char>(mode);
  auto x_leading_dim_v = x_rows;
  auto u_leading_dim_v = x_rows;
  auto vt_leading_dim_v = mode == svd::ComputationMode::kComputeFullUVt
                              ? x_cols
                              : std::min(x_rows, x_cols);
  if constexpr (ffi::IsComplexType<dtype>()) {
    svd::SVDType<dtype>::fn(
        &mode_v, &x_rows, &x_cols, nullptr, &x_leading_dim_v, nullptr, nullptr,
        &u_leading_dim_v, nullptr, &vt_leading_dim_v, &optimal_size,
        &workspace_query, nullptr, nullptr, &info);
  } else {
    svd::SVDType<dtype>::fn(&mode_v, &x_rows, &x_cols, nullptr,
                            &x_leading_dim_v, nullptr, nullptr,
                            &u_leading_dim_v, nullptr, &vt_leading_dim_v,
                            &optimal_size, &workspace_query, nullptr, &info);
  }
  return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
}

}  // namespace internal

template <ffi::DataType dtype>
ffi::Error SingularValueDecomposition<dtype>::Kernel(
    ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
    ffi::ResultBuffer<dtype> singular_values, ffi::ResultBuffer<dtype> u,
    ffi::ResultBuffer<dtype> vt, ffi::ResultBuffer<LapackIntDtype> info,
    ffi::ResultBuffer<LapackIntDtype> iwork, ffi::ResultBuffer<dtype> work,
    svd::ComputationMode mode) {
  return internal::SvdKernel<dtype>(x, x_out, singular_values, u, vt, info,
                                    iwork, work, mode, nullptr);
}

template <ffi::DataType dtype>
ffi::Error SingularValueDecompositionComplex<dtype>::Kernel(
    ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
    ffi::ResultBuffer<ffi::ToReal(dtype)> singular_values,
    ffi::ResultBuffer<dtype> u, ffi::ResultBuffer<dtype> vt,
    ffi::ResultBuffer<LapackIntDtype> info,
    ffi::ResultBuffer<ffi::ToReal(dtype)> rwork,
    ffi::ResultBuffer<LapackIntDtype> iwork, ffi::ResultBuffer<dtype> work,
    svd::ComputationMode mode) {
  return internal::SvdKernel<dtype>(x, x_out, singular_values, u, vt, info,
                                    iwork, work, mode, rwork);
}

template <ffi::DataType dtype>
int64_t SingularValueDecomposition<dtype>::GetWorkspaceSize(
    lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) {
  return internal::SvdGetWorkspaceSize<dtype>(x_rows, x_cols, mode);
}

template <ffi::DataType dtype>
int64_t SingularValueDecompositionComplex<dtype>::GetWorkspaceSize(
    lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) {
  return internal::SvdGetWorkspaceSize<dtype>(x_rows, x_cols, mode);
}

lapack_int svd::GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols,
                                     svd::ComputationMode mode) {
  const auto min_dim = std::min(x_rows, x_cols);
  if (!ComputesUV(mode)) {
    return CastNoOverflow<lapack_int>(7 * min_dim);
  }
  const auto max_dim = std::max(x_rows, x_cols);
  return CastNoOverflow<lapack_int>(
      std::max(5 * min_dim * min_dim + 5 * min_dim,
               2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim));
}

lapack_int svd::GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols) {
  return CastNoOverflow<lapack_int>(8 * std::min(x_rows, x_cols));
}

template struct SingularValueDecomposition<ffi::DataType::F32>;
template struct SingularValueDecomposition<ffi::DataType::F64>;
template struct SingularValueDecompositionComplex<ffi::DataType::C64>;
template struct SingularValueDecompositionComplex<ffi::DataType::C128>;

//== Eigenvalues and eigenvectors ==//

// lapack syevd/heevd

// Workspace sizes, taken from the LAPACK documentation.
lapack_int SyevdWorkSize(int64_t n) {
  return CastNoOverflow<lapack_int>(1 + 6 * n + 2 * n * n, "syevd lwork");
}

lapack_int SyevdIworkSize(int64_t n) {
  return CastNoOverflow<lapack_int>(3 + 5 * n, "syevd iwork");
}

template <typename T>
typename RealSyevd<T>::FnType* RealSyevd<T>::fn = nullptr;

template <typename T>
void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
  int b = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  const T* a_in = reinterpret_cast<T*>(data[3]);
  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  T* w_out = reinterpret_cast<T*>(out[1]);
  int* info_out = reinterpret_cast<int*>(out[2]);
  T* work = reinterpret_cast<T*>(out[3]);
  int* iwork = reinterpret_cast<int*>(out[4]);
  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(n) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  char jobz = 'V';
  char uplo = lower ? 'L' : 'U';

  lapack_int lwork = SyevdWorkSize(n);
  lapack_int liwork = SyevdIworkSize(n);
  for (int i = 0; i < b; ++i) {
    fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork,
       info_out);
    a_out += static_cast<int64_t>(n) * n;
    w_out += n;
    ++info_out;
  }
}

// Workspace sizes, taken from the LAPACK documentation.
lapack_int HeevdWorkSize(int64_t n) {
  return CastNoOverflow<lapack_int>(1 + 2 * n + n * n, "heevd work");
}

lapack_int HeevdRworkSize(int64_t n) {
  return CastNoOverflow<lapack_int>(1 + 5 * n + 2 * n * n, "heevd rwork");
}

template <typename T>
typename ComplexHeevd<T>::FnType* ComplexHeevd<T>::fn = nullptr;

template <typename T>
void ComplexHeevd<T>::Kernel(void* out_tuple, void** data,
                             XlaCustomCallStatus*) {
  int32_t lower = *(reinterpret_cast<int32_t*>(data[0]));
  int b = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  const T* a_in = reinterpret_cast<T*>(data[3]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  typename T::value_type* w_out =
      reinterpret_cast<typename T::value_type*>(out[1]);
  int* info_out = reinterpret_cast<int*>(out[2]);
  T* work = reinterpret_cast<T*>(out[3]);
  typename T::value_type* rwork =
      reinterpret_cast<typename T::value_type*>(out[4]);
  int* iwork = reinterpret_cast<int*>(out[5]);
  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(n) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  char jobz = 'V';
  char uplo = lower ? 'L' : 'U';

  lapack_int lwork = HeevdWorkSize(n);
  lapack_int lrwork = HeevdRworkSize(n);
  lapack_int liwork = SyevdIworkSize(n);
  for (int i = 0; i < b; ++i) {
    fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork, iwork,
       &liwork, info_out);
    a_out += static_cast<int64_t>(n) * n;
    w_out += n;
    ++info_out;
  }
}

template struct RealSyevd<float>;
template struct RealSyevd<double>;
template struct ComplexHeevd<std::complex<float>>;
template struct ComplexHeevd<std::complex<double>>;

// LAPACK uses a packed representation to represent a mixture of real
// eigenvectors and complex conjugate pairs. This helper unpacks the
// representation into regular complex matrices.
template <typename T>
static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed,
                               std::complex<T>* unpacked) {
  T re, im;
  int j = 0;
  while (j < n) {
    if (im_eigenvalues[j] == 0. || std::isnan(im_eigenvalues[j])) {
      for (int k = 0; k < n; ++k) {
        unpacked[j * n + k] = {packed[j * n + k], 0.};
      }
      ++j;
    } else {
      for (int k = 0; k < n; ++k) {
        re = packed[j * n + k];
        im = packed[(j + 1) * n + k];
        unpacked[j * n + k] = {re, im};
        unpacked[(j + 1) * n + k] = {re, -im};
      }
      j += 2;
    }
  }
}
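
// Concretely: if column j of `packed` holds Re(v) and column j + 1 holds
// Im(v) for a conjugate pair (signalled by a nonzero imaginary eigenvalue),
// the unpacked output stores v in column j and conj(v) in column j + 1; a
// real eigenvalue's column is copied with a zero imaginary part.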

// lapack geev
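
// ?geev computes the eigenvalues and, on request, the left and right
// eigenvectors of a general square matrix. The real variant returns the
// eigenvalues as separate real (wr) and imaginary (wi) arrays, and its
// eigenvectors arrive packed, which is why the kernel below runs them through
// UnpackEigenvectors.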

template <typename T>
typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;

template <typename T>
void RealGeev<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int n_int = *(reinterpret_cast<int32_t*>(data[1]));
  int64_t n = n_int;
  char jobvl = *(reinterpret_cast<uint8_t*>(data[2]));
  char jobvr = *(reinterpret_cast<uint8_t*>(data[3]));

  const T* a_in = reinterpret_cast<T*>(data[4]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_work = reinterpret_cast<T*>(out[0]);
  T* vl_work = reinterpret_cast<T*>(out[1]);
  T* vr_work = reinterpret_cast<T*>(out[2]);

  T* wr_out = reinterpret_cast<T*>(out[3]);
  T* wi_out = reinterpret_cast<T*>(out[4]);
  std::complex<T>* vl_out = reinterpret_cast<std::complex<T>*>(out[5]);
  std::complex<T>* vr_out = reinterpret_cast<std::complex<T>*>(out[6]);
  int* info_out = reinterpret_cast<int*>(out[7]);

  // TODO(phawkins): preallocate workspace using XLA.
  T work_query;
  int lwork = -1;
  fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int,
     vr_work, &n_int, &work_query, &lwork, info_out);
  ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
  lwork = static_cast<int>(work_query);
  T* work = new T[lwork];

  auto is_finite = [](T* a_work, int64_t n) {
    for (int64_t j = 0; j < n; ++j) {
      for (int64_t k = 0; k < n; ++k) {
        if (!std::isfinite(a_work[j * n + k])) {
          return false;
        }
      }
    }
    return true;
  };
  for (int i = 0; i < b; ++i) {
    size_t a_size = n * n * sizeof(T);
    std::memcpy(a_work, a_in, a_size);
    if (is_finite(a_work, n)) {
      fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work,
         &n_int, vr_work, &n_int, work, &lwork, info_out);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
      if (info_out[0] == 0) {
        UnpackEigenvectors(n, wi_out, vl_work, vl_out);
        UnpackEigenvectors(n, wi_out, vr_work, vr_out);
      }
    } else {
      *info_out = -4;
    }
    a_in += n * n;
    wr_out += n;
    wi_out += n;
    vl_out += n * n;
    vr_out += n * n;
    ++info_out;
  }
  delete[] work;
}

template <typename T>
typename ComplexGeev<T>::FnType* ComplexGeev<T>::fn = nullptr;

template <typename T>
void ComplexGeev<T>::Kernel(void* out_tuple, void** data,
                            XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int n_int = *(reinterpret_cast<int32_t*>(data[1]));
  int64_t n = n_int;
  char jobvl = *(reinterpret_cast<uint8_t*>(data[2]));
  char jobvr = *(reinterpret_cast<uint8_t*>(data[3]));

  const T* a_in = reinterpret_cast<T*>(data[4]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_work = reinterpret_cast<T*>(out[0]);
  typename T::value_type* r_work =
      reinterpret_cast<typename T::value_type*>(out[1]);

  T* w_out = reinterpret_cast<T*>(out[2]);
  T* vl_out = reinterpret_cast<T*>(out[3]);
  T* vr_out = reinterpret_cast<T*>(out[4]);
  int* info_out = reinterpret_cast<int*>(out[5]);

  // TODO(phawkins): preallocate workspace using XLA.
  T work_query;
  int lwork = -1;
  fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
     &n_int, &work_query, &lwork, r_work, info_out);
  ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
  lwork = static_cast<int>(work_query.real());
  T* work = new T[lwork];

  auto is_finite = [](T* a_work, int64_t n) {
    for (int64_t j = 0; j < n; ++j) {
      for (int64_t k = 0; k < n; ++k) {
        T v = a_work[j * n + k];
        if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) {
          return false;
        }
      }
    }
    return true;
  };

  for (int i = 0; i < b; ++i) {
    size_t a_size = n * n * sizeof(T);
    std::memcpy(a_work, a_in, a_size);
    if (is_finite(a_work, n)) {
      fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out,
         &n_int, work, &lwork, r_work, info_out);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n);
      ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
    } else {
      *info_out = -4;
    }
    a_in += n * n;
    w_out += n;
    vl_out += n * n;
    vr_out += n * n;
    info_out += 1;
  }
  delete[] work;
}

template struct RealGeev<float>;
template struct RealGeev<double>;
template struct ComplexGeev<std::complex<float>>;
template struct ComplexGeev<std::complex<double>>;

//== Schur Decomposition ==//

// lapack gees
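
// ?gees computes the Schur decomposition A = Z * T * Z^H, returning the
// (quasi-)triangular Schur form T, the unitary Schur vectors Z, and the
// eigenvalues. LAPACK's select callback and sdim output support reordering
// the decomposition, but the kernels below always pass a null callback.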

template <typename T>
typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;

template <typename T>
void RealGees<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int n_int = *(reinterpret_cast<int32_t*>(data[1]));
  int64_t n = n_int;
  char jobvs = *(reinterpret_cast<uint8_t*>(data[2]));
  char sort = *(reinterpret_cast<uint8_t*>(data[3]));

  const T* a_in = reinterpret_cast<T*>(data[4]);

  // A select callback could be passed in via data[5]; this kernel always uses
  // a null callback, so no eigenvalue ordering is performed.
  bool (*select)(T, T) = nullptr;

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);

  T* wr_out = reinterpret_cast<T*>(out[1]);
  T* wi_out = reinterpret_cast<T*>(out[2]);
  T* vs_out = reinterpret_cast<T*>(out[3]);
  int* sdim_out = reinterpret_cast<int*>(out[4]);
  int* info_out = reinterpret_cast<int*>(out[5]);

  bool* b_work = (sort != 'N') ? (new bool[n]) : nullptr;

  T work_query;
  int lwork = -1;
  fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, wr_out, wi_out,
     vs_out, &n_int, &work_query, &lwork, b_work, info_out);
  ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
  lwork = static_cast<int>(work_query);
  T* work = new T[lwork];

  size_t a_size = static_cast<int64_t>(n) * static_cast<int64_t>(n) * sizeof(T);
  if (a_out != a_in) {
    std::memcpy(a_out, a_in, static_cast<int64_t>(b) * a_size);
  }

  for (int i = 0; i < b; ++i) {
    fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, wr_out, wi_out,
       vs_out, &n_int, work, &lwork, b_work, info_out);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_out, a_size);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int));
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));

    a_in += n * n;
    a_out += n * n;
    wr_out += n;
    wi_out += n;
    vs_out += n * n;
    ++sdim_out;
    ++info_out;
  }
  delete[] work;
  delete[] b_work;
}

template <typename T>
typename ComplexGees<T>::FnType* ComplexGees<T>::fn = nullptr;

template <typename T>
void ComplexGees<T>::Kernel(void* out_tuple, void** data,
                            XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int n_int = *(reinterpret_cast<int32_t*>(data[1]));
  int64_t n = n_int;
  char jobvs = *(reinterpret_cast<uint8_t*>(data[2]));
  char sort = *(reinterpret_cast<uint8_t*>(data[3]));

  const T* a_in = reinterpret_cast<T*>(data[4]);

  // As in RealGees, the select callback is always null here, so no eigenvalue
  // ordering is performed; the complex variant's callback takes one argument.
  bool (*select)(T) = nullptr;

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  typename T::value_type* r_work =
      reinterpret_cast<typename T::value_type*>(out[1]);
  T* w_out = reinterpret_cast<T*>(out[2]);
  T* vs_out = reinterpret_cast<T*>(out[3]);
  int* sdim_out = reinterpret_cast<int*>(out[4]);
  int* info_out = reinterpret_cast<int*>(out[5]);

  bool* b_work = (sort != 'N') ? (new bool[n]) : nullptr;

  T work_query;
  int lwork = -1;
  fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, w_out, vs_out,
     &n_int, &work_query, &lwork, r_work, b_work, info_out);
  ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query));
  lwork = static_cast<int>(work_query.real());
  T* work = new T[lwork];

  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(n) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  for (int i = 0; i < b; ++i) {
    fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, w_out, vs_out,
       &n_int, work, &lwork, r_work, b_work, info_out);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n);
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int));
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int));

    a_in += n * n;
    a_out += n * n;
    w_out += n;
    vs_out += n * n;
    ++info_out;
    ++sdim_out;
  }
  delete[] work;
  delete[] b_work;
}

template struct RealGees<float>;
template struct RealGees<double>;
template struct ComplexGees<std::complex<float>>;
template struct ComplexGees<std::complex<double>>;

//== Hessenberg Decomposition ==//

// lapack gehrd
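
// ?gehrd reduces a square matrix to upper Hessenberg form, A = Q * H * Q^H,
// storing H in place and the reflectors defining Q below the first
// subdiagonal; tau holds the n - 1 reflector scale factors.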

template <typename T>
typename Gehrd<T>::FnType* Gehrd<T>::fn = nullptr;

template <typename T>
void Gehrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int32_t n = *reinterpret_cast<int32_t*>(data[0]);
  int32_t ilo = *reinterpret_cast<int32_t*>(data[1]);
  int32_t ihi = *reinterpret_cast<int32_t*>(data[2]);
  int32_t lda = *reinterpret_cast<int32_t*>(data[3]);
  int32_t batch = *reinterpret_cast<int32_t*>(data[4]);
  int32_t lwork = *reinterpret_cast<int32_t*>(data[5]);
  T* a = reinterpret_cast<T*>(data[6]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  T* tau = reinterpret_cast<T*>(out[1]);
  int* info = reinterpret_cast<int*>(out[2]);
  T* work = reinterpret_cast<T*>(out[3]);

  if (a_out != a) {
    std::memcpy(a_out, a,
                static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);

  for (int i = 0; i < batch; ++i) {
    fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info);
    a_out += a_plus;
    tau += n - 1;
    ++info;
  }
}

template <typename T>
int64_t Gehrd<T>::Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
                            lapack_int ihi) {
  T work = 0;
  lapack_int lwork = -1;
  lapack_int info = 0;
  fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &work, &lwork, &info);
  return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
}

template struct Gehrd<float>;
template struct Gehrd<double>;
template struct Gehrd<std::complex<float>>;
template struct Gehrd<std::complex<double>>;

//== Tridiagonal Reduction ==//

// lapack sytrd/hetrd
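
// ?sytrd/?hetrd reduce a symmetric (Hermitian) matrix to real tridiagonal
// form, A = Q * T * Q^H; d receives the n diagonal entries of T, e its n - 1
// off-diagonal entries, and tau the n - 1 reflector scale factors.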

template <typename T>
typename Sytrd<T>::FnType* Sytrd<T>::fn = nullptr;

template <typename T>
void Sytrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int32_t n = *reinterpret_cast<int32_t*>(data[0]);
  int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
  int32_t lda = *reinterpret_cast<int32_t*>(data[2]);
  int32_t batch = *reinterpret_cast<int32_t*>(data[3]);
  int32_t lwork = *reinterpret_cast<int32_t*>(data[4]);
  T* a = reinterpret_cast<T*>(data[5]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  typedef typename real_type<T>::type Real;
  Real* d = reinterpret_cast<Real*>(out[1]);
  Real* e = reinterpret_cast<Real*>(out[2]);
  T* tau = reinterpret_cast<T*>(out[3]);
  int* info = reinterpret_cast<int*>(out[4]);
  T* work = reinterpret_cast<T*>(out[5]);

  if (a_out != a) {
    std::memcpy(a_out, a,
                static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
                    static_cast<int64_t>(n) * sizeof(T));
  }

  char cuplo = lower ? 'L' : 'U';

  int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);

  for (int i = 0; i < batch; ++i) {
    fn(&cuplo, &n, a_out, &lda, d, e, tau, work, &lwork, info);
    a_out += a_plus;
    d += n;
    e += n - 1;
    tau += n - 1;
    ++info;
  }
}

template <typename T>
int64_t Sytrd<T>::Workspace(lapack_int lda, lapack_int n) {
  char cuplo = 'L';
  T work = 0;
  lapack_int lwork = -1;
  lapack_int info = 0;
  fn(&cuplo, &n, nullptr, &lda, nullptr, nullptr, nullptr, &work, &lwork,
     &info);
  return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
}

template struct Sytrd<float>;
template struct Sytrd<double>;
template struct Sytrd<std::complex<float>>;
template struct Sytrd<std::complex<double>>;

}  // namespace jax

#undef ASSIGN_OR_RETURN_FFI_ERROR
#undef RETURN_IF_FFI_ERROR