mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Refactor kernel function assignment
PiperOrigin-RevId: 641255192
This commit is contained in:
parent
f51af87fc5
commit
5fcd50b7fa
@ -35,12 +35,11 @@ void GetLapackKernelsFromScipy() {
|
||||
auto blas_ptr = [&](const char* name) {
|
||||
return nb::cast<nb::capsule>(blas_capi[name]).data();
|
||||
};
|
||||
Trsm<float>::fn = reinterpret_cast<Trsm<float>::FnType*>(blas_ptr("strsm"));
|
||||
Trsm<double>::fn = reinterpret_cast<Trsm<double>::FnType*>(blas_ptr("dtrsm"));
|
||||
Trsm<std::complex<float>>::fn =
|
||||
reinterpret_cast<Trsm<std::complex<float>>::FnType*>(blas_ptr("ctrsm"));
|
||||
Trsm<std::complex<double>>::fn =
|
||||
reinterpret_cast<Trsm<std::complex<double>>::FnType*>(blas_ptr("ztrsm"));
|
||||
|
||||
AssignKernelFn<Trsm<float>>(blas_ptr("strsm"));
|
||||
AssignKernelFn<Trsm<double>>(blas_ptr("dtrsm"));
|
||||
AssignKernelFn<Trsm<std::complex<float>>>(blas_ptr("ctrsm"));
|
||||
AssignKernelFn<Trsm<std::complex<double>>>(blas_ptr("ztrsm"));
|
||||
|
||||
nb::module_ cython_lapack =
|
||||
nb::module_::import_("scipy.linalg.cython_lapack");
|
||||
@ -48,106 +47,55 @@ void GetLapackKernelsFromScipy() {
|
||||
auto lapack_ptr = [&](const char* name) {
|
||||
return nb::cast<nb::capsule>(lapack_capi[name]).data();
|
||||
};
|
||||
Getrf<float>::fn =
|
||||
reinterpret_cast<Getrf<float>::FnType*>(lapack_ptr("sgetrf"));
|
||||
Getrf<double>::fn =
|
||||
reinterpret_cast<Getrf<double>::FnType*>(lapack_ptr("dgetrf"));
|
||||
Getrf<std::complex<float>>::fn =
|
||||
reinterpret_cast<Getrf<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cgetrf"));
|
||||
Getrf<std::complex<double>>::fn =
|
||||
reinterpret_cast<Getrf<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zgetrf"));
|
||||
Geqrf<float>::fn =
|
||||
reinterpret_cast<Geqrf<float>::FnType*>(lapack_ptr("sgeqrf"));
|
||||
Geqrf<double>::fn =
|
||||
reinterpret_cast<Geqrf<double>::FnType*>(lapack_ptr("dgeqrf"));
|
||||
Geqrf<std::complex<float>>::fn =
|
||||
reinterpret_cast<Geqrf<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cgeqrf"));
|
||||
Geqrf<std::complex<double>>::fn =
|
||||
reinterpret_cast<Geqrf<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zgeqrf"));
|
||||
Orgqr<float>::fn =
|
||||
reinterpret_cast<Orgqr<float>::FnType*>(lapack_ptr("sorgqr"));
|
||||
Orgqr<double>::fn =
|
||||
reinterpret_cast<Orgqr<double>::FnType*>(lapack_ptr("dorgqr"));
|
||||
Orgqr<std::complex<float>>::fn =
|
||||
reinterpret_cast<Orgqr<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cungqr"));
|
||||
Orgqr<std::complex<double>>::fn =
|
||||
reinterpret_cast<Orgqr<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zungqr"));
|
||||
Potrf<float>::fn =
|
||||
reinterpret_cast<Potrf<float>::FnType*>(lapack_ptr("spotrf"));
|
||||
Potrf<double>::fn =
|
||||
reinterpret_cast<Potrf<double>::FnType*>(lapack_ptr("dpotrf"));
|
||||
Potrf<std::complex<float>>::fn =
|
||||
reinterpret_cast<Potrf<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cpotrf"));
|
||||
Potrf<std::complex<double>>::fn =
|
||||
reinterpret_cast<Potrf<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zpotrf"));
|
||||
RealGesdd<float>::fn =
|
||||
reinterpret_cast<RealGesdd<float>::FnType*>(lapack_ptr("sgesdd"));
|
||||
RealGesdd<double>::fn =
|
||||
reinterpret_cast<RealGesdd<double>::FnType*>(lapack_ptr("dgesdd"));
|
||||
ComplexGesdd<std::complex<float>>::fn =
|
||||
reinterpret_cast<ComplexGesdd<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cgesdd"));
|
||||
ComplexGesdd<std::complex<double>>::fn =
|
||||
reinterpret_cast<ComplexGesdd<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zgesdd"));
|
||||
RealSyevd<float>::fn =
|
||||
reinterpret_cast<RealSyevd<float>::FnType*>(lapack_ptr("ssyevd"));
|
||||
RealSyevd<double>::fn =
|
||||
reinterpret_cast<RealSyevd<double>::FnType*>(lapack_ptr("dsyevd"));
|
||||
ComplexHeevd<std::complex<float>>::fn =
|
||||
reinterpret_cast<ComplexHeevd<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cheevd"));
|
||||
ComplexHeevd<std::complex<double>>::fn =
|
||||
reinterpret_cast<ComplexHeevd<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zheevd"));
|
||||
RealGeev<float>::fn =
|
||||
reinterpret_cast<RealGeev<float>::FnType*>(lapack_ptr("sgeev"));
|
||||
RealGeev<double>::fn =
|
||||
reinterpret_cast<RealGeev<double>::FnType*>(lapack_ptr("dgeev"));
|
||||
ComplexGeev<std::complex<float>>::fn =
|
||||
reinterpret_cast<ComplexGeev<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cgeev"));
|
||||
ComplexGeev<std::complex<double>>::fn =
|
||||
reinterpret_cast<ComplexGeev<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zgeev"));
|
||||
RealGees<float>::fn =
|
||||
reinterpret_cast<RealGees<float>::FnType*>(lapack_ptr("sgees"));
|
||||
RealGees<double>::fn =
|
||||
reinterpret_cast<RealGees<double>::FnType*>(lapack_ptr("dgees"));
|
||||
ComplexGees<std::complex<float>>::fn =
|
||||
reinterpret_cast<ComplexGees<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cgees"));
|
||||
ComplexGees<std::complex<double>>::fn =
|
||||
reinterpret_cast<ComplexGees<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zgees"));
|
||||
Gehrd<float>::fn =
|
||||
reinterpret_cast<Gehrd<float>::FnType*>(lapack_ptr("sgehrd"));
|
||||
Gehrd<double>::fn =
|
||||
reinterpret_cast<Gehrd<double>::FnType*>(lapack_ptr("dgehrd"));
|
||||
Gehrd<std::complex<float>>::fn =
|
||||
reinterpret_cast<Gehrd<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cgehrd"));
|
||||
Gehrd<std::complex<double>>::fn =
|
||||
reinterpret_cast<Gehrd<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zgehrd"));
|
||||
Sytrd<float>::fn =
|
||||
reinterpret_cast<Sytrd<float>::FnType*>(lapack_ptr("ssytrd"));
|
||||
Sytrd<double>::fn =
|
||||
reinterpret_cast<Sytrd<double>::FnType*>(lapack_ptr("dsytrd"));
|
||||
Sytrd<std::complex<float>>::fn =
|
||||
reinterpret_cast<Sytrd<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("chetrd"));
|
||||
Sytrd<std::complex<double>>::fn =
|
||||
reinterpret_cast<Sytrd<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zhetrd"));
|
||||
AssignKernelFn<Getrf<float>>(lapack_ptr("sgetrf"));
|
||||
AssignKernelFn<Getrf<double>>(lapack_ptr("dgetrf"));
|
||||
AssignKernelFn<Getrf<std::complex<float>>>(lapack_ptr("cgetrf"));
|
||||
AssignKernelFn<Getrf<std::complex<double>>>(lapack_ptr("zgetrf"));
|
||||
|
||||
AssignKernelFn<Geqrf<float>>(lapack_ptr("sgeqrf"));
|
||||
AssignKernelFn<Geqrf<double>>(lapack_ptr("dgeqrf"));
|
||||
AssignKernelFn<Geqrf<std::complex<float>>>(lapack_ptr("cgeqrf"));
|
||||
AssignKernelFn<Geqrf<std::complex<double>>>(lapack_ptr("zgeqrf"));
|
||||
|
||||
AssignKernelFn<Orgqr<float>>(lapack_ptr("sorgqr"));
|
||||
AssignKernelFn<Orgqr<double>>(lapack_ptr("dorgqr"));
|
||||
AssignKernelFn<Orgqr<std::complex<float>>>(lapack_ptr("cungqr"));
|
||||
AssignKernelFn<Orgqr<std::complex<double>>>(lapack_ptr("zungqr"));
|
||||
|
||||
AssignKernelFn<Potrf<float>>(lapack_ptr("spotrf"));
|
||||
AssignKernelFn<Potrf<double>>(lapack_ptr("dpotrf"));
|
||||
AssignKernelFn<Potrf<std::complex<float>>>(lapack_ptr("cpotrf"));
|
||||
AssignKernelFn<Potrf<std::complex<double>>>(lapack_ptr("zpotrf"));
|
||||
|
||||
AssignKernelFn<RealGesdd<float>>(lapack_ptr("sgesdd"));
|
||||
AssignKernelFn<RealGesdd<double>>(lapack_ptr("dgesdd"));
|
||||
AssignKernelFn<ComplexGesdd<std::complex<float>>>(lapack_ptr("cgesdd"));
|
||||
AssignKernelFn<ComplexGesdd<std::complex<double>>>(lapack_ptr("zgesdd"));
|
||||
|
||||
AssignKernelFn<RealSyevd<float>>(lapack_ptr("ssyevd"));
|
||||
AssignKernelFn<RealSyevd<double>>(lapack_ptr("dsyevd"));
|
||||
AssignKernelFn<ComplexHeevd<std::complex<float>>>(lapack_ptr("cheevd"));
|
||||
AssignKernelFn<ComplexHeevd<std::complex<double>>>(lapack_ptr("zheevd"));
|
||||
|
||||
AssignKernelFn<RealGeev<float>>(lapack_ptr("sgeev"));
|
||||
AssignKernelFn<RealGeev<double>>(lapack_ptr("dgeev"));
|
||||
AssignKernelFn<ComplexGeev<std::complex<float>>>(lapack_ptr("cgeev"));
|
||||
AssignKernelFn<ComplexGeev<std::complex<double>>>(lapack_ptr("zgeev"));
|
||||
|
||||
AssignKernelFn<RealGees<float>>(lapack_ptr("sgees"));
|
||||
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
|
||||
AssignKernelFn<ComplexGees<std::complex<float>>>(lapack_ptr("cgees"));
|
||||
AssignKernelFn<ComplexGees<std::complex<double>>>(lapack_ptr("zgees"));
|
||||
|
||||
AssignKernelFn<Gehrd<float>>(lapack_ptr("sgehrd"));
|
||||
AssignKernelFn<Gehrd<double>>(lapack_ptr("dgehrd"));
|
||||
AssignKernelFn<Gehrd<std::complex<float>>>(lapack_ptr("cgehrd"));
|
||||
AssignKernelFn<Gehrd<std::complex<double>>>(lapack_ptr("zgehrd"));
|
||||
|
||||
AssignKernelFn<Sytrd<float>>(lapack_ptr("ssytrd"));
|
||||
AssignKernelFn<Sytrd<double>>(lapack_ptr("dsytrd"));
|
||||
AssignKernelFn<Sytrd<std::complex<float>>>(lapack_ptr("chetrd"));
|
||||
AssignKernelFn<Sytrd<std::complex<double>>>(lapack_ptr("zhetrd"));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
|
@ -29,6 +29,16 @@ limitations under the License.
|
||||
namespace jax {
|
||||
|
||||
typedef int lapack_int;
|
||||
// Installs an untyped function address as the implementation of a kernel.
//
// `KernelType` must expose a nested function type `FnType` and an assignable
// static member `fn`; `func` is reinterpret_cast to `FnType*` and stored in
// `KernelType::fn`.
//
// No check is made here that `func` is non-null or that it really refers to a
// function with the `FnType` signature; the caller is responsible for that.
template <typename KernelType>
|
||||
void AssignKernelFn(void* func) {
|
||||
KernelType::fn = reinterpret_cast<typename KernelType::FnType*>(func);
|
||||
}
|
||||
|
||||
// Installs an already-typed function pointer as the implementation of a
// kernel.
//
// `KernelType` must expose a nested function type `FnType` and an assignable
// static member `fn`; `func` is stored in `KernelType::fn` as-is, with no
// cast, so a signature mismatch is rejected at compile time.
template <typename KernelType>
|
||||
void AssignKernelFn(typename KernelType::FnType* func) {
|
||||
KernelType::fn = func;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct Trsm {
|
||||
|
@ -81,50 +81,60 @@ jax::Sytrd<std::complex<double>>::FnType zhetrd_;
|
||||
namespace jax {
|
||||
|
||||
static auto init = []() -> int {
|
||||
Trsm<float>::fn = strsm_;
|
||||
Trsm<double>::fn = dtrsm_;
|
||||
Trsm<std::complex<float>>::fn = ctrsm_;
|
||||
Trsm<std::complex<double>>::fn = ztrsm_;
|
||||
Getrf<float>::fn = sgetrf_;
|
||||
Getrf<double>::fn = dgetrf_;
|
||||
Getrf<std::complex<float>>::fn = cgetrf_;
|
||||
Getrf<std::complex<double>>::fn = zgetrf_;
|
||||
Geqrf<float>::fn = sgeqrf_;
|
||||
Geqrf<double>::fn = dgeqrf_;
|
||||
Geqrf<std::complex<float>>::fn = cgeqrf_;
|
||||
Geqrf<std::complex<double>>::fn = zgeqrf_;
|
||||
Orgqr<float>::fn = sorgqr_;
|
||||
Orgqr<double>::fn = dorgqr_;
|
||||
Orgqr<std::complex<float>>::fn = cungqr_;
|
||||
Orgqr<std::complex<double>>::fn = zungqr_;
|
||||
Potrf<float>::fn = spotrf_;
|
||||
Potrf<double>::fn = dpotrf_;
|
||||
Potrf<std::complex<float>>::fn = cpotrf_;
|
||||
Potrf<std::complex<double>>::fn = zpotrf_;
|
||||
RealGesdd<float>::fn = sgesdd_;
|
||||
RealGesdd<double>::fn = dgesdd_;
|
||||
ComplexGesdd<std::complex<float>>::fn = cgesdd_;
|
||||
ComplexGesdd<std::complex<double>>::fn = zgesdd_;
|
||||
RealSyevd<float>::fn = ssyevd_;
|
||||
RealSyevd<double>::fn = dsyevd_;
|
||||
ComplexHeevd<std::complex<float>>::fn = cheevd_;
|
||||
ComplexHeevd<std::complex<double>>::fn = zheevd_;
|
||||
RealGeev<float>::fn = sgeev_;
|
||||
RealGeev<double>::fn = dgeev_;
|
||||
ComplexGeev<std::complex<float>>::fn = cgeev_;
|
||||
ComplexGeev<std::complex<double>>::fn = zgeev_;
|
||||
RealGees<float>::fn = sgees_;
|
||||
RealGees<double>::fn = dgees_;
|
||||
ComplexGees<std::complex<float>>::fn = cgees_;
|
||||
ComplexGees<std::complex<double>>::fn = zgees_;
|
||||
Gehrd<float>::fn = sgehrd_;
|
||||
Gehrd<double>::fn = dgehrd_;
|
||||
Gehrd<std::complex<float>>::fn = cgehrd_;
|
||||
Gehrd<std::complex<double>>::fn = zgehrd_;
|
||||
Sytrd<float>::fn = ssytrd_;
|
||||
Sytrd<double>::fn = dsytrd_;
|
||||
Sytrd<std::complex<float>>::fn = chetrd_;
|
||||
Sytrd<std::complex<double>>::fn = zhetrd_;
|
||||
AssignKernelFn<Trsm<float>>(strsm_);
|
||||
AssignKernelFn<Trsm<double>>(dtrsm_);
|
||||
AssignKernelFn<Trsm<std::complex<float>>>(ctrsm_);
|
||||
AssignKernelFn<Trsm<std::complex<double>>>(ztrsm_);
|
||||
|
||||
AssignKernelFn<Getrf<float>>(sgetrf_);
|
||||
AssignKernelFn<Getrf<double>>(dgetrf_);
|
||||
AssignKernelFn<Getrf<std::complex<float>>>(cgetrf_);
|
||||
AssignKernelFn<Getrf<std::complex<double>>>(zgetrf_);
|
||||
|
||||
AssignKernelFn<Geqrf<float>>(sgeqrf_);
|
||||
AssignKernelFn<Geqrf<double>>(dgeqrf_);
|
||||
AssignKernelFn<Geqrf<std::complex<float>>>(cgeqrf_);
|
||||
AssignKernelFn<Geqrf<std::complex<double>>>(zgeqrf_);
|
||||
|
||||
AssignKernelFn<Orgqr<float>>(sorgqr_);
|
||||
AssignKernelFn<Orgqr<double>>(dorgqr_);
|
||||
AssignKernelFn<Orgqr<std::complex<float>>>(cungqr_);
|
||||
AssignKernelFn<Orgqr<std::complex<double>>>(zungqr_);
|
||||
|
||||
AssignKernelFn<Potrf<float>>(spotrf_);
|
||||
AssignKernelFn<Potrf<double>>(dpotrf_);
|
||||
AssignKernelFn<Potrf<std::complex<float>>>(cpotrf_);
|
||||
AssignKernelFn<Potrf<std::complex<double>>>(zpotrf_);
|
||||
|
||||
AssignKernelFn<RealGesdd<float>>(sgesdd_);
|
||||
AssignKernelFn<RealGesdd<double>>(dgesdd_);
|
||||
AssignKernelFn<ComplexGesdd<std::complex<float>>>(cgesdd_);
|
||||
AssignKernelFn<ComplexGesdd<std::complex<double>>>(zgesdd_);
|
||||
|
||||
AssignKernelFn<RealSyevd<float>>(ssyevd_);
|
||||
AssignKernelFn<RealSyevd<double>>(dsyevd_);
|
||||
AssignKernelFn<ComplexHeevd<std::complex<float>>>(cheevd_);
|
||||
AssignKernelFn<ComplexHeevd<std::complex<double>>>(zheevd_);
|
||||
|
||||
AssignKernelFn<RealGeev<float>>(sgeev_);
|
||||
AssignKernelFn<RealGeev<double>>(dgeev_);
|
||||
AssignKernelFn<ComplexGeev<std::complex<float>>>(cgeev_);
|
||||
AssignKernelFn<ComplexGeev<std::complex<double>>>(zgeev_);
|
||||
|
||||
AssignKernelFn<RealGees<float>>(sgees_);
|
||||
AssignKernelFn<RealGees<double>>(dgees_);
|
||||
AssignKernelFn<ComplexGees<std::complex<float>>>(cgees_);
|
||||
AssignKernelFn<ComplexGees<std::complex<double>>>(zgees_);
|
||||
|
||||
AssignKernelFn<Gehrd<float>>(sgehrd_);
|
||||
AssignKernelFn<Gehrd<double>>(dgehrd_);
|
||||
AssignKernelFn<Gehrd<std::complex<float>>>(cgehrd_);
|
||||
AssignKernelFn<Gehrd<std::complex<double>>>(zgehrd_);
|
||||
|
||||
AssignKernelFn<Sytrd<float>>(ssytrd_);
|
||||
AssignKernelFn<Sytrd<double>>(dsytrd_);
|
||||
AssignKernelFn<Sytrd<std::complex<float>>>(chetrd_);
|
||||
AssignKernelFn<Sytrd<std::complex<double>>>(zhetrd_);
|
||||
|
||||
return 0;
|
||||
}();
|
||||
|
Loading…
x
Reference in New Issue
Block a user