Refactor kernel function assignment

PiperOrigin-RevId: 641255192
This commit is contained in:
Paweł Paruzel 2024-06-07 08:17:58 -07:00 committed by jax authors
parent f51af87fc5
commit 5fcd50b7fa
3 changed files with 118 additions and 150 deletions

View File

@ -35,12 +35,11 @@ void GetLapackKernelsFromScipy() {
auto blas_ptr = [&](const char* name) {
return nb::cast<nb::capsule>(blas_capi[name]).data();
};
Trsm<float>::fn = reinterpret_cast<Trsm<float>::FnType*>(blas_ptr("strsm"));
Trsm<double>::fn = reinterpret_cast<Trsm<double>::FnType*>(blas_ptr("dtrsm"));
Trsm<std::complex<float>>::fn =
reinterpret_cast<Trsm<std::complex<float>>::FnType*>(blas_ptr("ctrsm"));
Trsm<std::complex<double>>::fn =
reinterpret_cast<Trsm<std::complex<double>>::FnType*>(blas_ptr("ztrsm"));
AssignKernelFn<Trsm<float>>(blas_ptr("strsm"));
AssignKernelFn<Trsm<double>>(blas_ptr("dtrsm"));
AssignKernelFn<Trsm<std::complex<float>>>(blas_ptr("ctrsm"));
AssignKernelFn<Trsm<std::complex<double>>>(blas_ptr("ztrsm"));
nb::module_ cython_lapack =
nb::module_::import_("scipy.linalg.cython_lapack");
@ -48,106 +47,55 @@ void GetLapackKernelsFromScipy() {
auto lapack_ptr = [&](const char* name) {
return nb::cast<nb::capsule>(lapack_capi[name]).data();
};
Getrf<float>::fn =
reinterpret_cast<Getrf<float>::FnType*>(lapack_ptr("sgetrf"));
Getrf<double>::fn =
reinterpret_cast<Getrf<double>::FnType*>(lapack_ptr("dgetrf"));
Getrf<std::complex<float>>::fn =
reinterpret_cast<Getrf<std::complex<float>>::FnType*>(
lapack_ptr("cgetrf"));
Getrf<std::complex<double>>::fn =
reinterpret_cast<Getrf<std::complex<double>>::FnType*>(
lapack_ptr("zgetrf"));
Geqrf<float>::fn =
reinterpret_cast<Geqrf<float>::FnType*>(lapack_ptr("sgeqrf"));
Geqrf<double>::fn =
reinterpret_cast<Geqrf<double>::FnType*>(lapack_ptr("dgeqrf"));
Geqrf<std::complex<float>>::fn =
reinterpret_cast<Geqrf<std::complex<float>>::FnType*>(
lapack_ptr("cgeqrf"));
Geqrf<std::complex<double>>::fn =
reinterpret_cast<Geqrf<std::complex<double>>::FnType*>(
lapack_ptr("zgeqrf"));
Orgqr<float>::fn =
reinterpret_cast<Orgqr<float>::FnType*>(lapack_ptr("sorgqr"));
Orgqr<double>::fn =
reinterpret_cast<Orgqr<double>::FnType*>(lapack_ptr("dorgqr"));
Orgqr<std::complex<float>>::fn =
reinterpret_cast<Orgqr<std::complex<float>>::FnType*>(
lapack_ptr("cungqr"));
Orgqr<std::complex<double>>::fn =
reinterpret_cast<Orgqr<std::complex<double>>::FnType*>(
lapack_ptr("zungqr"));
Potrf<float>::fn =
reinterpret_cast<Potrf<float>::FnType*>(lapack_ptr("spotrf"));
Potrf<double>::fn =
reinterpret_cast<Potrf<double>::FnType*>(lapack_ptr("dpotrf"));
Potrf<std::complex<float>>::fn =
reinterpret_cast<Potrf<std::complex<float>>::FnType*>(
lapack_ptr("cpotrf"));
Potrf<std::complex<double>>::fn =
reinterpret_cast<Potrf<std::complex<double>>::FnType*>(
lapack_ptr("zpotrf"));
RealGesdd<float>::fn =
reinterpret_cast<RealGesdd<float>::FnType*>(lapack_ptr("sgesdd"));
RealGesdd<double>::fn =
reinterpret_cast<RealGesdd<double>::FnType*>(lapack_ptr("dgesdd"));
ComplexGesdd<std::complex<float>>::fn =
reinterpret_cast<ComplexGesdd<std::complex<float>>::FnType*>(
lapack_ptr("cgesdd"));
ComplexGesdd<std::complex<double>>::fn =
reinterpret_cast<ComplexGesdd<std::complex<double>>::FnType*>(
lapack_ptr("zgesdd"));
RealSyevd<float>::fn =
reinterpret_cast<RealSyevd<float>::FnType*>(lapack_ptr("ssyevd"));
RealSyevd<double>::fn =
reinterpret_cast<RealSyevd<double>::FnType*>(lapack_ptr("dsyevd"));
ComplexHeevd<std::complex<float>>::fn =
reinterpret_cast<ComplexHeevd<std::complex<float>>::FnType*>(
lapack_ptr("cheevd"));
ComplexHeevd<std::complex<double>>::fn =
reinterpret_cast<ComplexHeevd<std::complex<double>>::FnType*>(
lapack_ptr("zheevd"));
RealGeev<float>::fn =
reinterpret_cast<RealGeev<float>::FnType*>(lapack_ptr("sgeev"));
RealGeev<double>::fn =
reinterpret_cast<RealGeev<double>::FnType*>(lapack_ptr("dgeev"));
ComplexGeev<std::complex<float>>::fn =
reinterpret_cast<ComplexGeev<std::complex<float>>::FnType*>(
lapack_ptr("cgeev"));
ComplexGeev<std::complex<double>>::fn =
reinterpret_cast<ComplexGeev<std::complex<double>>::FnType*>(
lapack_ptr("zgeev"));
RealGees<float>::fn =
reinterpret_cast<RealGees<float>::FnType*>(lapack_ptr("sgees"));
RealGees<double>::fn =
reinterpret_cast<RealGees<double>::FnType*>(lapack_ptr("dgees"));
ComplexGees<std::complex<float>>::fn =
reinterpret_cast<ComplexGees<std::complex<float>>::FnType*>(
lapack_ptr("cgees"));
ComplexGees<std::complex<double>>::fn =
reinterpret_cast<ComplexGees<std::complex<double>>::FnType*>(
lapack_ptr("zgees"));
Gehrd<float>::fn =
reinterpret_cast<Gehrd<float>::FnType*>(lapack_ptr("sgehrd"));
Gehrd<double>::fn =
reinterpret_cast<Gehrd<double>::FnType*>(lapack_ptr("dgehrd"));
Gehrd<std::complex<float>>::fn =
reinterpret_cast<Gehrd<std::complex<float>>::FnType*>(
lapack_ptr("cgehrd"));
Gehrd<std::complex<double>>::fn =
reinterpret_cast<Gehrd<std::complex<double>>::FnType*>(
lapack_ptr("zgehrd"));
Sytrd<float>::fn =
reinterpret_cast<Sytrd<float>::FnType*>(lapack_ptr("ssytrd"));
Sytrd<double>::fn =
reinterpret_cast<Sytrd<double>::FnType*>(lapack_ptr("dsytrd"));
Sytrd<std::complex<float>>::fn =
reinterpret_cast<Sytrd<std::complex<float>>::FnType*>(
lapack_ptr("chetrd"));
Sytrd<std::complex<double>>::fn =
reinterpret_cast<Sytrd<std::complex<double>>::FnType*>(
lapack_ptr("zhetrd"));
AssignKernelFn<Getrf<float>>(lapack_ptr("sgetrf"));
AssignKernelFn<Getrf<double>>(lapack_ptr("dgetrf"));
AssignKernelFn<Getrf<std::complex<float>>>(lapack_ptr("cgetrf"));
AssignKernelFn<Getrf<std::complex<double>>>(lapack_ptr("zgetrf"));
AssignKernelFn<Geqrf<float>>(lapack_ptr("sgeqrf"));
AssignKernelFn<Geqrf<double>>(lapack_ptr("dgeqrf"));
AssignKernelFn<Geqrf<std::complex<float>>>(lapack_ptr("cgeqrf"));
AssignKernelFn<Geqrf<std::complex<double>>>(lapack_ptr("zgeqrf"));
AssignKernelFn<Orgqr<float>>(lapack_ptr("sorgqr"));
AssignKernelFn<Orgqr<double>>(lapack_ptr("dorgqr"));
AssignKernelFn<Orgqr<std::complex<float>>>(lapack_ptr("cungqr"));
AssignKernelFn<Orgqr<std::complex<double>>>(lapack_ptr("zungqr"));
AssignKernelFn<Potrf<float>>(lapack_ptr("spotrf"));
AssignKernelFn<Potrf<double>>(lapack_ptr("dpotrf"));
AssignKernelFn<Potrf<std::complex<float>>>(lapack_ptr("cpotrf"));
AssignKernelFn<Potrf<std::complex<double>>>(lapack_ptr("zpotrf"));
AssignKernelFn<RealGesdd<float>>(lapack_ptr("sgesdd"));
AssignKernelFn<RealGesdd<double>>(lapack_ptr("dgesdd"));
AssignKernelFn<ComplexGesdd<std::complex<float>>>(lapack_ptr("cgesdd"));
AssignKernelFn<ComplexGesdd<std::complex<double>>>(lapack_ptr("zgesdd"));
AssignKernelFn<RealSyevd<float>>(lapack_ptr("ssyevd"));
AssignKernelFn<RealSyevd<double>>(lapack_ptr("dsyevd"));
AssignKernelFn<ComplexHeevd<std::complex<float>>>(lapack_ptr("cheevd"));
AssignKernelFn<ComplexHeevd<std::complex<double>>>(lapack_ptr("zheevd"));
AssignKernelFn<RealGeev<float>>(lapack_ptr("sgeev"));
AssignKernelFn<RealGeev<double>>(lapack_ptr("dgeev"));
AssignKernelFn<ComplexGeev<std::complex<float>>>(lapack_ptr("cgeev"));
AssignKernelFn<ComplexGeev<std::complex<double>>>(lapack_ptr("zgeev"));
AssignKernelFn<RealGees<float>>(lapack_ptr("sgees"));
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
AssignKernelFn<ComplexGees<std::complex<float>>>(lapack_ptr("cgees"));
AssignKernelFn<ComplexGees<std::complex<double>>>(lapack_ptr("zgees"));
AssignKernelFn<Gehrd<float>>(lapack_ptr("sgehrd"));
AssignKernelFn<Gehrd<double>>(lapack_ptr("dgehrd"));
AssignKernelFn<Gehrd<std::complex<float>>>(lapack_ptr("cgehrd"));
AssignKernelFn<Gehrd<std::complex<double>>>(lapack_ptr("zgehrd"));
AssignKernelFn<Sytrd<float>>(lapack_ptr("ssytrd"));
AssignKernelFn<Sytrd<double>>(lapack_ptr("dsytrd"));
AssignKernelFn<Sytrd<std::complex<float>>>(lapack_ptr("chetrd"));
AssignKernelFn<Sytrd<std::complex<double>>>(lapack_ptr("zhetrd"));
initialized = true;
}

View File

@ -29,6 +29,16 @@ limitations under the License.
namespace jax {
typedef int lapack_int;
// Stores `func` in the static member `KernelType::fn`, reinterpreting the
// untyped pointer as `KernelType::FnType*`. Nothing here checks that `func` is
// non-null or that it really points at a function with that signature.
template <typename KernelType>
void AssignKernelFn(void* func) {
KernelType::fn = reinterpret_cast<typename KernelType::FnType*>(func);
}
// Overload for a pointer that already has type `KernelType::FnType*`: stores it
// in the static member `KernelType::fn` directly, with no cast.
template <typename KernelType>
void AssignKernelFn(typename KernelType::FnType* func) {
KernelType::fn = func;
}
template <typename T>
struct Trsm {

View File

@ -81,50 +81,60 @@ jax::Sytrd<std::complex<double>>::FnType zhetrd_;
namespace jax {
// Immediately-invoked lambda: its return value (0) initializes `init`, and
// running it stores one routine in the `fn` member of every kernel type below.
// NOTE(review): this hunk appears to show both the pre-change direct
// assignments and the post-change AssignKernelFn calls back to back; confirm
// which set is present in the actual file.
static auto init = []() -> int {
// Direct `fn = symbol` assignments (presumably the lines this commit removes).
Trsm<float>::fn = strsm_;
Trsm<double>::fn = dtrsm_;
Trsm<std::complex<float>>::fn = ctrsm_;
Trsm<std::complex<double>>::fn = ztrsm_;
Getrf<float>::fn = sgetrf_;
Getrf<double>::fn = dgetrf_;
Getrf<std::complex<float>>::fn = cgetrf_;
Getrf<std::complex<double>>::fn = zgetrf_;
Geqrf<float>::fn = sgeqrf_;
Geqrf<double>::fn = dgeqrf_;
Geqrf<std::complex<float>>::fn = cgeqrf_;
Geqrf<std::complex<double>>::fn = zgeqrf_;
Orgqr<float>::fn = sorgqr_;
Orgqr<double>::fn = dorgqr_;
Orgqr<std::complex<float>>::fn = cungqr_;
Orgqr<std::complex<double>>::fn = zungqr_;
Potrf<float>::fn = spotrf_;
Potrf<double>::fn = dpotrf_;
Potrf<std::complex<float>>::fn = cpotrf_;
Potrf<std::complex<double>>::fn = zpotrf_;
RealGesdd<float>::fn = sgesdd_;
RealGesdd<double>::fn = dgesdd_;
ComplexGesdd<std::complex<float>>::fn = cgesdd_;
ComplexGesdd<std::complex<double>>::fn = zgesdd_;
RealSyevd<float>::fn = ssyevd_;
RealSyevd<double>::fn = dsyevd_;
ComplexHeevd<std::complex<float>>::fn = cheevd_;
ComplexHeevd<std::complex<double>>::fn = zheevd_;
RealGeev<float>::fn = sgeev_;
RealGeev<double>::fn = dgeev_;
ComplexGeev<std::complex<float>>::fn = cgeev_;
ComplexGeev<std::complex<double>>::fn = zgeev_;
RealGees<float>::fn = sgees_;
RealGees<double>::fn = dgees_;
ComplexGees<std::complex<float>>::fn = cgees_;
ComplexGees<std::complex<double>>::fn = zgees_;
Gehrd<float>::fn = sgehrd_;
Gehrd<double>::fn = dgehrd_;
Gehrd<std::complex<float>>::fn = cgehrd_;
Gehrd<std::complex<double>>::fn = zgehrd_;
Sytrd<float>::fn = ssytrd_;
Sytrd<double>::fn = dsytrd_;
Sytrd<std::complex<float>>::fn = chetrd_;
Sytrd<std::complex<double>>::fn = zhetrd_;
// The same kernel/symbol pairs routed through AssignKernelFn (presumably the
// lines this commit adds).
AssignKernelFn<Trsm<float>>(strsm_);
AssignKernelFn<Trsm<double>>(dtrsm_);
AssignKernelFn<Trsm<std::complex<float>>>(ctrsm_);
AssignKernelFn<Trsm<std::complex<double>>>(ztrsm_);
AssignKernelFn<Getrf<float>>(sgetrf_);
AssignKernelFn<Getrf<double>>(dgetrf_);
AssignKernelFn<Getrf<std::complex<float>>>(cgetrf_);
AssignKernelFn<Getrf<std::complex<double>>>(zgetrf_);
AssignKernelFn<Geqrf<float>>(sgeqrf_);
AssignKernelFn<Geqrf<double>>(dgeqrf_);
AssignKernelFn<Geqrf<std::complex<float>>>(cgeqrf_);
AssignKernelFn<Geqrf<std::complex<double>>>(zgeqrf_);
AssignKernelFn<Orgqr<float>>(sorgqr_);
AssignKernelFn<Orgqr<double>>(dorgqr_);
AssignKernelFn<Orgqr<std::complex<float>>>(cungqr_);
AssignKernelFn<Orgqr<std::complex<double>>>(zungqr_);
AssignKernelFn<Potrf<float>>(spotrf_);
AssignKernelFn<Potrf<double>>(dpotrf_);
AssignKernelFn<Potrf<std::complex<float>>>(cpotrf_);
AssignKernelFn<Potrf<std::complex<double>>>(zpotrf_);
AssignKernelFn<RealGesdd<float>>(sgesdd_);
AssignKernelFn<RealGesdd<double>>(dgesdd_);
AssignKernelFn<ComplexGesdd<std::complex<float>>>(cgesdd_);
AssignKernelFn<ComplexGesdd<std::complex<double>>>(zgesdd_);
AssignKernelFn<RealSyevd<float>>(ssyevd_);
AssignKernelFn<RealSyevd<double>>(dsyevd_);
AssignKernelFn<ComplexHeevd<std::complex<float>>>(cheevd_);
AssignKernelFn<ComplexHeevd<std::complex<double>>>(zheevd_);
AssignKernelFn<RealGeev<float>>(sgeev_);
AssignKernelFn<RealGeev<double>>(dgeev_);
AssignKernelFn<ComplexGeev<std::complex<float>>>(cgeev_);
AssignKernelFn<ComplexGeev<std::complex<double>>>(zgeev_);
AssignKernelFn<RealGees<float>>(sgees_);
AssignKernelFn<RealGees<double>>(dgees_);
AssignKernelFn<ComplexGees<std::complex<float>>>(cgees_);
AssignKernelFn<ComplexGees<std::complex<double>>>(zgees_);
AssignKernelFn<Gehrd<float>>(sgehrd_);
AssignKernelFn<Gehrd<double>>(dgehrd_);
AssignKernelFn<Gehrd<std::complex<float>>>(cgehrd_);
AssignKernelFn<Gehrd<std::complex<double>>>(zgehrd_);
AssignKernelFn<Sytrd<float>>(ssytrd_);
AssignKernelFn<Sytrd<double>>(dsytrd_);
AssignKernelFn<Sytrd<std::complex<float>>>(chetrd_);
AssignKernelFn<Sytrd<std::complex<double>>>(zhetrd_);
return 0;
}();