mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Determine LAPACK workspace during Eigenvalue Kernels runtime
PiperOrigin-RevId: 666285759
This commit is contained in:
parent
a72d46c549
commit
4786930a4c
@ -49,11 +49,7 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr
|
||||
|
||||
|
||||
# FFI Kernel LAPACK Workspace Size Queries
|
||||
def heevd_rwork_size_ffi(n: int) -> int: ...
|
||||
def heevd_work_size_ffi(n: int) -> int: ...
|
||||
def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def syevd_iwork_size_ffi(n: int) -> int: ...
|
||||
def syevd_work_size_ffi(n: int) -> int: ...
|
||||
|
@ -37,14 +37,6 @@ svd::ComputationMode GetSvdComputationMode(bool job_opt_compute_uv,
|
||||
return svd::ComputationMode::kComputeFullUVt;
|
||||
}
|
||||
|
||||
// Due to enforced kComputeEigenvectors, this assumes a larger workspace size.
|
||||
// Could be improved to more accurately estimate the expected size based on the
|
||||
// eig::ComputationMode value.
|
||||
template <lapack_int (&f)(int64_t, eig::ComputationMode)>
|
||||
inline constexpr auto BoundWithEigvecs = +[](lapack_int n) {
|
||||
return f(n, eig::ComputationMode::kComputeEigenvectors);
|
||||
};
|
||||
|
||||
void GetLapackKernelsFromScipy() {
|
||||
static bool initialized = false; // Protected by GIL
|
||||
if (initialized) return;
|
||||
@ -348,14 +340,6 @@ NB_MODULE(_lapack, m) {
|
||||
m.def("lapack_zungqr_workspace_ffi",
|
||||
&OrthogonalQr<DataType::C128>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("k"));
|
||||
m.def("syevd_work_size_ffi", BoundWithEigvecs<eig::GetWorkspaceSize>,
|
||||
nb::arg("n"));
|
||||
m.def("syevd_iwork_size_ffi", BoundWithEigvecs<eig::GetIntWorkspaceSize>,
|
||||
nb::arg("n"));
|
||||
m.def("heevd_work_size_ffi", BoundWithEigvecs<eig::GetComplexWorkspaceSize>,
|
||||
nb::arg("n"));
|
||||
m.def("heevd_rwork_size_ffi", BoundWithEigvecs<eig::GetRealWorkspaceSize>,
|
||||
nb::arg("n"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -954,21 +954,24 @@ template struct ComplexHeevd<std::complex<double>>;
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
lapack_int eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) {
|
||||
absl::StatusOr<lapack_int> eig::GetWorkspaceSize(int64_t x_cols,
|
||||
ComputationMode mode) {
|
||||
switch (mode) {
|
||||
case ComputationMode::kNoEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(2 * x_cols + 1);
|
||||
return MaybeCastNoOverflow<lapack_int>(2 * x_cols + 1);
|
||||
case ComputationMode::kComputeEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(1 + 6 * x_cols + 2 * x_cols * x_cols);
|
||||
return MaybeCastNoOverflow<lapack_int>(1 + 6 * x_cols +
|
||||
2 * x_cols * x_cols);
|
||||
}
|
||||
}
|
||||
|
||||
lapack_int eig::GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode) {
|
||||
absl::StatusOr<lapack_int> eig::GetIntWorkspaceSize(int64_t x_cols,
|
||||
ComputationMode mode) {
|
||||
switch (mode) {
|
||||
case ComputationMode::kNoEigenvectors:
|
||||
return 1;
|
||||
case ComputationMode::kComputeEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(3 + 5 * x_cols);
|
||||
return MaybeCastNoOverflow<lapack_int>(3 + 5 * x_cols);
|
||||
}
|
||||
}
|
||||
|
||||
@ -976,34 +979,34 @@ template <ffi::DataType dtype>
|
||||
ffi::Error EigenvalueDecompositionSymmetric<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> eigenvalues,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> work,
|
||||
ffi::ResultBuffer<LapackIntDtype> iwork, eig::ComputationMode mode) {
|
||||
ffi::ResultBuffer<LapackIntDtype> info, eig::ComputationMode mode) {
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* eigenvalues_data = eigenvalues->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
auto* work_data = work->typed_data();
|
||||
auto* iwork_data = iwork->typed_data();
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto uplo_v = static_cast<char>(uplo);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
work->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
iwork->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
// Prepare LAPACK workspaces.
|
||||
FFI_ASSIGN_OR_RETURN(lapack_int work_size_v,
|
||||
eig::GetWorkspaceSize(x_cols, mode));
|
||||
FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v,
|
||||
eig::GetIntWorkspaceSize(x_cols, mode));
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size_v);
|
||||
auto iwork_data = AllocateScratchMemory<LapackIntDtype>(iwork_size_v);
|
||||
|
||||
const int64_t x_out_step{x_cols * x_cols};
|
||||
const int64_t eigenvalues_step{x_cols};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v,
|
||||
eigenvalues_data, work_data, &workspace_dim_v, iwork_data,
|
||||
&iworkspace_dim_v, info_data);
|
||||
eigenvalues_data, work_data.get(), &work_size_v, iwork_data.get(),
|
||||
&iwork_size_v, info_data);
|
||||
x_out_data += x_out_step;
|
||||
eigenvalues_data += eigenvalues_step;
|
||||
++info_data;
|
||||
@ -1013,21 +1016,24 @@ ffi::Error EigenvalueDecompositionSymmetric<dtype>::Kernel(
|
||||
|
||||
namespace eig {
|
||||
|
||||
lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode) {
|
||||
absl::StatusOr<lapack_int> GetComplexWorkspaceSize(int64_t x_cols,
|
||||
ComputationMode mode) {
|
||||
switch (mode) {
|
||||
case ComputationMode::kNoEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(x_cols + 1);
|
||||
return MaybeCastNoOverflow<lapack_int>(x_cols + 1);
|
||||
case ComputationMode::kComputeEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(2 * x_cols + x_cols * x_cols);
|
||||
return MaybeCastNoOverflow<lapack_int>(2 * x_cols + x_cols * x_cols);
|
||||
}
|
||||
}
|
||||
|
||||
lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode) {
|
||||
absl::StatusOr<lapack_int> GetRealWorkspaceSize(int64_t x_cols,
|
||||
ComputationMode mode) {
|
||||
switch (mode) {
|
||||
case ComputationMode::kNoEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(std::max(x_cols, int64_t{1}));
|
||||
return MaybeCastNoOverflow<lapack_int>(std::max(x_cols, int64_t{1}));
|
||||
case ComputationMode::kComputeEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(1 + 5 * x_cols + 2 * x_cols * x_cols);
|
||||
return MaybeCastNoOverflow<lapack_int>(1 + 5 * x_cols +
|
||||
2 * x_cols * x_cols);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1038,37 +1044,37 @@ ffi::Error EigenvalueDecompositionHermitian<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> eigenvalues,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> work,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> rwork,
|
||||
ffi::ResultBuffer<LapackIntDtype> iwork, eig::ComputationMode mode) {
|
||||
ffi::ResultBuffer<LapackIntDtype> info, eig::ComputationMode mode) {
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* eigenvalues_data = eigenvalues->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
auto* work_data = work->typed_data();
|
||||
auto* iwork_data = iwork->typed_data();
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto uplo_v = static_cast<char>(uplo);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
work->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto rworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
rwork->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
iwork->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
// Prepare LAPACK workspaces.
|
||||
FFI_ASSIGN_OR_RETURN(lapack_int work_size_v,
|
||||
eig::GetComplexWorkspaceSize(x_cols, mode));
|
||||
FFI_ASSIGN_OR_RETURN(lapack_int rwork_size_v,
|
||||
eig::GetRealWorkspaceSize(x_cols, mode));
|
||||
FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v,
|
||||
eig::GetIntWorkspaceSize(x_cols, mode));
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size_v);
|
||||
auto iwork_data = AllocateScratchMemory<LapackIntDtype>(iwork_size_v);
|
||||
auto rwork_data = AllocateScratchMemory<ffi::ToReal(dtype)>(rwork_size_v);
|
||||
|
||||
const int64_t x_out_step{x_cols * x_cols};
|
||||
const int64_t eigenvalues_step{x_cols};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v,
|
||||
eigenvalues_data, work_data, &workspace_dim_v, rwork->typed_data(),
|
||||
&rworkspace_dim_v, iwork_data, &iworkspace_dim_v, info_data);
|
||||
eigenvalues_data, work_data.get(), &work_size_v, rwork_data.get(),
|
||||
&rwork_size_v, iwork_data.get(), &iwork_size_v, info_data);
|
||||
x_out_data += x_out_step;
|
||||
eigenvalues_data += eigenvalues_step;
|
||||
++info_data;
|
||||
@ -1265,16 +1271,11 @@ ffi::Error EigenvalueDecomposition<dtype>::Kernel(
|
||||
ffi::ResultBuffer<dtype> eigvals_imag,
|
||||
ffi::ResultBuffer<ffi::ToComplex(dtype)> eigvecs_left,
|
||||
ffi::ResultBuffer<ffi::ToComplex(dtype)> eigvecs_right,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> x_work,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> work_eigvecs_left,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> work_eigvecs_right) {
|
||||
ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
|
||||
const auto* x_data = x.typed_data();
|
||||
auto* x_work_data = x_work->typed_data();
|
||||
auto* work_eigvecs_left_data = work_eigvecs_left->typed_data();
|
||||
auto* work_eigvecs_right_data = work_eigvecs_right->typed_data();
|
||||
auto* eigvecs_left_data = eigvecs_left->typed_data();
|
||||
auto* eigvecs_right_data = eigvecs_right->typed_data();
|
||||
auto* eigvals_real_data = eigvals_real->typed_data();
|
||||
@ -1284,43 +1285,45 @@ ffi::Error EigenvalueDecomposition<dtype>::Kernel(
|
||||
auto compute_left_v = static_cast<char>(compute_left);
|
||||
auto compute_right_v = static_cast<char>(compute_right);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
// Prepare LAPACK workspaces.
|
||||
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right);
|
||||
FFI_ASSIGN_OR_RETURN(auto work_size_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work_size));
|
||||
// TODO(phawkins): preallocate workspace using XLA.
|
||||
auto work = std::make_unique<ValueType[]>(work_size);
|
||||
auto* work_data = work.get();
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size);
|
||||
const int64_t x_size{x_cols * x_cols};
|
||||
auto x_copy = AllocateScratchMemory<dtype>(x_size);
|
||||
auto work_eigvecs_left = AllocateScratchMemory<dtype>(x_size);
|
||||
auto work_eigvecs_right = AllocateScratchMemory<dtype>(x_size);
|
||||
|
||||
const auto is_finite = [](ValueType* data, int64_t size) {
|
||||
return absl::c_all_of(absl::MakeSpan(data, size),
|
||||
[](ValueType value) { return std::isfinite(value); });
|
||||
};
|
||||
|
||||
const int64_t x_size{x_cols * x_cols};
|
||||
[[maybe_unused]] const auto x_size_bytes =
|
||||
static_cast<unsigned long>(x_size) * sizeof(ValueType);
|
||||
[[maybe_unused]] const auto x_cols_bytes =
|
||||
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
std::copy_n(x_data, x_size, x_work_data);
|
||||
if (is_finite(x_work_data, x_size)) {
|
||||
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v,
|
||||
eigvals_real_data, eigvals_imag_data, work_eigvecs_left_data,
|
||||
&x_cols_v, work_eigvecs_right_data, &x_cols_v, work_data, &work_size_v,
|
||||
info_data);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes);
|
||||
std::copy_n(x_data, x_size, x_copy.get());
|
||||
if (is_finite(x_copy.get(), x_size)) {
|
||||
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v,
|
||||
eigvals_real_data, eigvals_imag_data, work_eigvecs_left.get(),
|
||||
&x_cols_v, work_eigvecs_right.get(), &x_cols_v, work_data.get(),
|
||||
&work_size_v, info_data);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right_data,
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left.get(),
|
||||
x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right.get(),
|
||||
x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
|
||||
if (info_data[0] == 0) {
|
||||
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left_data,
|
||||
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left.get(),
|
||||
eigvecs_left_data);
|
||||
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_right_data,
|
||||
eigvecs_right_data);
|
||||
UnpackEigenvectors(x_cols_v, eigvals_imag_data,
|
||||
work_eigvecs_right.get(), eigvecs_right_data);
|
||||
}
|
||||
} else {
|
||||
info_data[0] = -4;
|
||||
@ -1341,12 +1344,10 @@ ffi::Error EigenvalueDecompositionComplex<dtype>::Kernel(
|
||||
eig::ComputationMode compute_right, ffi::ResultBuffer<dtype> eigvals,
|
||||
ffi::ResultBuffer<dtype> eigvecs_left,
|
||||
ffi::ResultBuffer<dtype> eigvecs_right,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> x_work,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> rwork) {
|
||||
ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
const auto* x_data = x.typed_data();
|
||||
auto* x_work_data = x_work->typed_data();
|
||||
auto* eigvecs_left_data = eigvecs_left->typed_data();
|
||||
auto* eigvecs_right_data = eigvecs_right->typed_data();
|
||||
auto* eigvals_data = eigvals->typed_data();
|
||||
@ -1355,13 +1356,14 @@ ffi::Error EigenvalueDecompositionComplex<dtype>::Kernel(
|
||||
auto compute_left_v = static_cast<char>(compute_left);
|
||||
auto compute_right_v = static_cast<char>(compute_right);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
// Prepare LAPACK workspaces.
|
||||
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right);
|
||||
FFI_ASSIGN_OR_RETURN(auto work_size_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work_size));
|
||||
// TODO(phawkins): preallocate workspace using XLA.
|
||||
auto work = std::make_unique<ValueType[]>(work_size);
|
||||
auto* work_data = work.get();
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size);
|
||||
const int64_t x_size{x_cols * x_cols};
|
||||
auto x_copy = AllocateScratchMemory<dtype>(x_size);
|
||||
auto rwork_data = AllocateScratchMemory<ffi::ToReal(dtype)>(2 * x_cols);
|
||||
|
||||
const auto is_finite = [](ValueType* data, int64_t size) {
|
||||
return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) {
|
||||
@ -1369,18 +1371,18 @@ ffi::Error EigenvalueDecompositionComplex<dtype>::Kernel(
|
||||
});
|
||||
};
|
||||
|
||||
const int64_t x_size{x_cols * x_cols};
|
||||
[[maybe_unused]] const auto x_size_bytes =
|
||||
static_cast<unsigned long>(x_size) * sizeof(ValueType);
|
||||
[[maybe_unused]] const auto x_cols_bytes =
|
||||
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
std::copy_n(x_data, x_size, x_work_data);
|
||||
if (is_finite(x_work_data, x_size)) {
|
||||
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v,
|
||||
std::copy_n(x_data, x_size, x_copy.get());
|
||||
if (is_finite(x_copy.get(), x_size)) {
|
||||
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v,
|
||||
eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data,
|
||||
&x_cols_v, work_data, &work_size_v, rwork->typed_data(), info_data);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes);
|
||||
&x_cols_v, work_data.get(), &work_size_v, rwork_data.get(),
|
||||
info_data);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_right_data, x_size_bytes);
|
||||
@ -1766,8 +1768,6 @@ template struct Sytrd<std::complex<double>>;
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigenvalues*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
|
||||
.Attr<eig::ComputationMode>("mode"))
|
||||
|
||||
#define JAX_CPU_DEFINE_HEEVD(name, data_type) \
|
||||
@ -1780,9 +1780,6 @@ template struct Sytrd<std::complex<double>>;
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
|
||||
/*eigenvalues*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
|
||||
.Attr<eig::ComputationMode>("mode"))
|
||||
|
||||
#define JAX_CPU_DEFINE_GEEV(name, data_type) \
|
||||
@ -1798,12 +1795,7 @@ template struct Sytrd<std::complex<double>>;
|
||||
/*eigvecs_left*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \
|
||||
/*eigvecs_right*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_work*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
|
||||
/*work_eigvecs_left*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
|
||||
/*work_eigvecs_right*/))
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
#define JAX_CPU_DEFINE_GEEV_COMPLEX(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
@ -1815,9 +1807,7 @@ template struct Sytrd<std::complex<double>>;
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_left*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_work*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/))
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
// FFI Handlers
|
||||
|
||||
|
@ -395,12 +395,16 @@ struct ComplexHeevd {
|
||||
namespace eig {
|
||||
|
||||
// Eigenvalue Decomposition
|
||||
lapack_int GetWorkspaceSize(int64_t x_cols, ComputationMode mode);
|
||||
lapack_int GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode);
|
||||
absl::StatusOr<lapack_int> GetWorkspaceSize(int64_t x_cols,
|
||||
ComputationMode mode);
|
||||
absl::StatusOr<lapack_int> GetIntWorkspaceSize(int64_t x_cols,
|
||||
ComputationMode mode);
|
||||
|
||||
// Hermitian Eigenvalue Decomposition
|
||||
lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode);
|
||||
lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode);
|
||||
absl::StatusOr<lapack_int> GetComplexWorkspaceSize(int64_t x_cols,
|
||||
ComputationMode mode);
|
||||
absl::StatusOr<lapack_int> GetRealWorkspaceSize(int64_t x_cols,
|
||||
ComputationMode mode);
|
||||
|
||||
} // namespace eig
|
||||
|
||||
@ -417,13 +421,11 @@ struct EigenvalueDecompositionSymmetric {
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
static ::xla::ffi::Error Kernel(::xla::ffi::Buffer<dtype> x,
|
||||
MatrixParams::UpLo uplo,
|
||||
::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<dtype> eigenvalues,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> work,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> iwork,
|
||||
eig::ComputationMode mode);
|
||||
};
|
||||
|
||||
@ -445,9 +447,6 @@ struct EigenvalueDecompositionHermitian {
|
||||
::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> work,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> iwork,
|
||||
eig::ComputationMode mode);
|
||||
};
|
||||
|
||||
@ -496,10 +495,7 @@ struct EigenvalueDecomposition {
|
||||
::xla::ffi::ResultBuffer<dtype> eigvals_imag,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_left,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_right,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> x_work,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_left,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_right);
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_cols,
|
||||
eig::ComputationMode compute_left,
|
||||
@ -526,9 +522,7 @@ struct EigenvalueDecompositionComplex {
|
||||
::xla::ffi::ResultBuffer<dtype> eigvals,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvecs_left,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvecs_right,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> x_work,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork);
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_cols,
|
||||
eig::ComputationMode compute_left,
|
||||
|
Loading…
x
Reference in New Issue
Block a user