From 5fcd50b7faadfd14433b642c82710493c7be37a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Fri, 7 Jun 2024 08:17:58 -0700 Subject: [PATCH] Refactor kernel function assignment PiperOrigin-RevId: 641255192 --- jaxlib/cpu/lapack.cc | 160 ++++++++-------------- jaxlib/cpu/lapack_kernels.h | 10 ++ jaxlib/cpu/lapack_kernels_using_lapack.cc | 98 +++++++------ 3 files changed, 118 insertions(+), 150 deletions(-) diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index ddf605fdd..5ef42cfc5 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -35,12 +35,11 @@ void GetLapackKernelsFromScipy() { auto blas_ptr = [&](const char* name) { return nb::cast(blas_capi[name]).data(); }; - Trsm::fn = reinterpret_cast::FnType*>(blas_ptr("strsm")); - Trsm::fn = reinterpret_cast::FnType*>(blas_ptr("dtrsm")); - Trsm>::fn = - reinterpret_cast>::FnType*>(blas_ptr("ctrsm")); - Trsm>::fn = - reinterpret_cast>::FnType*>(blas_ptr("ztrsm")); + + AssignKernelFn>(blas_ptr("strsm")); + AssignKernelFn>(blas_ptr("dtrsm")); + AssignKernelFn>>(blas_ptr("ctrsm")); + AssignKernelFn>>(blas_ptr("ztrsm")); nb::module_ cython_lapack = nb::module_::import_("scipy.linalg.cython_lapack"); @@ -48,106 +47,55 @@ void GetLapackKernelsFromScipy() { auto lapack_ptr = [&](const char* name) { return nb::cast(lapack_capi[name]).data(); }; - Getrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgetrf")); - Getrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgetrf")); - Getrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgetrf")); - Getrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgetrf")); - Geqrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgeqrf")); - Geqrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgeqrf")); - Geqrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgeqrf")); - Geqrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgeqrf")); - Orgqr::fn = - reinterpret_cast::FnType*>(lapack_ptr("sorgqr")); - Orgqr::fn = - 
reinterpret_cast::FnType*>(lapack_ptr("dorgqr")); - Orgqr>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cungqr")); - Orgqr>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zungqr")); - Potrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("spotrf")); - Potrf::fn = - reinterpret_cast::FnType*>(lapack_ptr("dpotrf")); - Potrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cpotrf")); - Potrf>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zpotrf")); - RealGesdd::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgesdd")); - RealGesdd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgesdd")); - ComplexGesdd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgesdd")); - ComplexGesdd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgesdd")); - RealSyevd::fn = - reinterpret_cast::FnType*>(lapack_ptr("ssyevd")); - RealSyevd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dsyevd")); - ComplexHeevd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cheevd")); - ComplexHeevd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zheevd")); - RealGeev::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgeev")); - RealGeev::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgeev")); - ComplexGeev>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgeev")); - ComplexGeev>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgeev")); - RealGees::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgees")); - RealGees::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgees")); - ComplexGees>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgees")); - ComplexGees>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgees")); - Gehrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("sgehrd")); - Gehrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("dgehrd")); - Gehrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("cgehrd")); - Gehrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zgehrd")); - Sytrd::fn = - reinterpret_cast::FnType*>(lapack_ptr("ssytrd")); - Sytrd::fn = - 
reinterpret_cast::FnType*>(lapack_ptr("dsytrd")); - Sytrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("chetrd")); - Sytrd>::fn = - reinterpret_cast>::FnType*>( - lapack_ptr("zhetrd")); + AssignKernelFn>(lapack_ptr("sgetrf")); + AssignKernelFn>(lapack_ptr("dgetrf")); + AssignKernelFn>>(lapack_ptr("cgetrf")); + AssignKernelFn>>(lapack_ptr("zgetrf")); + + AssignKernelFn>(lapack_ptr("sgeqrf")); + AssignKernelFn>(lapack_ptr("dgeqrf")); + AssignKernelFn>>(lapack_ptr("cgeqrf")); + AssignKernelFn>>(lapack_ptr("zgeqrf")); + + AssignKernelFn>(lapack_ptr("sorgqr")); + AssignKernelFn>(lapack_ptr("dorgqr")); + AssignKernelFn>>(lapack_ptr("cungqr")); + AssignKernelFn>>(lapack_ptr("zungqr")); + + AssignKernelFn>(lapack_ptr("spotrf")); + AssignKernelFn>(lapack_ptr("dpotrf")); + AssignKernelFn>>(lapack_ptr("cpotrf")); + AssignKernelFn>>(lapack_ptr("zpotrf")); + + AssignKernelFn>(lapack_ptr("sgesdd")); + AssignKernelFn>(lapack_ptr("dgesdd")); + AssignKernelFn>>(lapack_ptr("cgesdd")); + AssignKernelFn>>(lapack_ptr("zgesdd")); + + AssignKernelFn>(lapack_ptr("ssyevd")); + AssignKernelFn>(lapack_ptr("dsyevd")); + AssignKernelFn>>(lapack_ptr("cheevd")); + AssignKernelFn>>(lapack_ptr("zheevd")); + + AssignKernelFn>(lapack_ptr("sgeev")); + AssignKernelFn>(lapack_ptr("dgeev")); + AssignKernelFn>>(lapack_ptr("cgeev")); + AssignKernelFn>>(lapack_ptr("zgeev")); + + AssignKernelFn>(lapack_ptr("sgees")); + AssignKernelFn>(lapack_ptr("dgees")); + AssignKernelFn>>(lapack_ptr("cgees")); + AssignKernelFn>>(lapack_ptr("zgees")); + + AssignKernelFn>(lapack_ptr("sgehrd")); + AssignKernelFn>(lapack_ptr("dgehrd")); + AssignKernelFn>>(lapack_ptr("cgehrd")); + AssignKernelFn>>(lapack_ptr("zgehrd")); + + AssignKernelFn>(lapack_ptr("ssytrd")); + AssignKernelFn>(lapack_ptr("dsytrd")); + AssignKernelFn>>(lapack_ptr("chetrd")); + AssignKernelFn>>(lapack_ptr("zhetrd")); initialized = true; } diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 4641b772c..84d2251bb 100644 --- 
a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -29,6 +29,16 @@ limitations under the License. namespace jax { typedef int lapack_int; +template +void AssignKernelFn(void* func) { + KernelType::fn = reinterpret_cast(func); +} + +template +void AssignKernelFn(typename KernelType::FnType* func) { + KernelType::fn = func; +} + template struct Trsm { diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index bc67fc556..4360ec55a 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -81,50 +81,60 @@ jax::Sytrd>::FnType zhetrd_; namespace jax { static auto init = []() -> int { - Trsm::fn = strsm_; - Trsm::fn = dtrsm_; - Trsm>::fn = ctrsm_; - Trsm>::fn = ztrsm_; - Getrf::fn = sgetrf_; - Getrf::fn = dgetrf_; - Getrf>::fn = cgetrf_; - Getrf>::fn = zgetrf_; - Geqrf::fn = sgeqrf_; - Geqrf::fn = dgeqrf_; - Geqrf>::fn = cgeqrf_; - Geqrf>::fn = zgeqrf_; - Orgqr::fn = sorgqr_; - Orgqr::fn = dorgqr_; - Orgqr>::fn = cungqr_; - Orgqr>::fn = zungqr_; - Potrf::fn = spotrf_; - Potrf::fn = dpotrf_; - Potrf>::fn = cpotrf_; - Potrf>::fn = zpotrf_; - RealGesdd::fn = sgesdd_; - RealGesdd::fn = dgesdd_; - ComplexGesdd>::fn = cgesdd_; - ComplexGesdd>::fn = zgesdd_; - RealSyevd::fn = ssyevd_; - RealSyevd::fn = dsyevd_; - ComplexHeevd>::fn = cheevd_; - ComplexHeevd>::fn = zheevd_; - RealGeev::fn = sgeev_; - RealGeev::fn = dgeev_; - ComplexGeev>::fn = cgeev_; - ComplexGeev>::fn = zgeev_; - RealGees::fn = sgees_; - RealGees::fn = dgees_; - ComplexGees>::fn = cgees_; - ComplexGees>::fn = zgees_; - Gehrd::fn = sgehrd_; - Gehrd::fn = dgehrd_; - Gehrd>::fn = cgehrd_; - Gehrd>::fn = zgehrd_; - Sytrd::fn = ssytrd_; - Sytrd::fn = dsytrd_; - Sytrd>::fn = chetrd_; - Sytrd>::fn = zhetrd_; + AssignKernelFn>(strsm_); + AssignKernelFn>(dtrsm_); + AssignKernelFn>>(ctrsm_); + AssignKernelFn>>(ztrsm_); + + AssignKernelFn>(sgetrf_); + AssignKernelFn>(dgetrf_); + AssignKernelFn>>(cgetrf_); + 
AssignKernelFn>>(zgetrf_); + + AssignKernelFn>(sgeqrf_); + AssignKernelFn>(dgeqrf_); + AssignKernelFn>>(cgeqrf_); + AssignKernelFn>>(zgeqrf_); + + AssignKernelFn>(sorgqr_); + AssignKernelFn>(dorgqr_); + AssignKernelFn>>(cungqr_); + AssignKernelFn>>(zungqr_); + + AssignKernelFn>(spotrf_); + AssignKernelFn>(dpotrf_); + AssignKernelFn>>(cpotrf_); + AssignKernelFn>>(zpotrf_); + + AssignKernelFn>(sgesdd_); + AssignKernelFn>(dgesdd_); + AssignKernelFn>>(cgesdd_); + AssignKernelFn>>(zgesdd_); + + AssignKernelFn>(ssyevd_); + AssignKernelFn>(dsyevd_); + AssignKernelFn>>(cheevd_); + AssignKernelFn>>(zheevd_); + + AssignKernelFn>(sgeev_); + AssignKernelFn>(dgeev_); + AssignKernelFn>>(cgeev_); + AssignKernelFn>>(zgeev_); + + AssignKernelFn>(sgees_); + AssignKernelFn>(dgees_); + AssignKernelFn>>(cgees_); + AssignKernelFn>>(zgees_); + + AssignKernelFn>(sgehrd_); + AssignKernelFn>(dgehrd_); + AssignKernelFn>>(cgehrd_); + AssignKernelFn>>(zgehrd_); + + AssignKernelFn>(ssytrd_); + AssignKernelFn>(dsytrd_); + AssignKernelFn>>(chetrd_); + AssignKernelFn>>(zhetrd_); return 0; }();