From bd46e5c96072733eea32e99f3522e6c8e9a82695 Mon Sep 17 00:00:00 2001 From: Shashank Viswanadha Date: Thu, 30 Nov 2023 10:38:38 -0800 Subject: [PATCH] Add `nb::arg` to nanobind definitions to generate better python annotations. PiperOrigin-RevId: 586721759 --- jaxlib/cpu/_lapack.pyi | 54 ++++++++++++++--------------- jaxlib/cpu/lapack.cc | 78 ++++++++++++++++++++++++++++-------------- 2 files changed, 78 insertions(+), 54 deletions(-) diff --git a/jaxlib/cpu/_lapack.pyi b/jaxlib/cpu/_lapack.pyi index 416182c93..460578892 100644 --- a/jaxlib/cpu/_lapack.pyi +++ b/jaxlib/cpu/_lapack.pyi @@ -12,33 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any - -def cgesdd_rwork_size(*args, **kwargs) -> Any: ... -def cgesdd_work_size(*args, **kwargs) -> Any: ... -def dgesdd_work_size(*args, **kwargs) -> Any: ... -def gesdd_iwork_size(*args, **kwargs) -> Any: ... -def heevd_rwork_size(*args, **kwargs) -> Any: ... -def heevd_work_size(*args, **kwargs) -> Any: ... +def cgesdd_rwork_size(m: int, n: int, compute_uv: int) -> int: ... +def cgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... +def dgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... +def gesdd_iwork_size(m: int, n: int) -> int: ... +def heevd_rwork_size(n: int) -> int: ... +def heevd_work_size(n: int) -> int: ... def initialize() -> None: ... -def lapack_cgehrd_workspace(*args, **kwargs) -> Any: ... -def lapack_cgeqrf_workspace(*args, **kwargs) -> Any: ... -def lapack_chetrd_workspace(*args, **kwargs) -> Any: ... -def lapack_cungqr_workspace(*args, **kwargs) -> Any: ... -def lapack_dgehrd_workspace(*args, **kwargs) -> Any: ... -def lapack_dgeqrf_workspace(*args, **kwargs) -> Any: ... -def lapack_dorgqr_workspace(*args, **kwargs) -> Any: ... -def lapack_dsytrd_workspace(*args, **kwargs) -> Any: ... -def lapack_sgehrd_workspace(*args, **kwargs) -> Any: ... -def lapack_sgeqrf_workspace(*args, **kwargs) -> Any: ... -def lapack_sorgqr_workspace(*args, **kwargs) -> Any: ... -def lapack_ssytrd_workspace(*args, **kwargs) -> Any: ... -def lapack_zgehrd_workspace(*args, **kwargs) -> Any: ... -def lapack_zgeqrf_workspace(*args, **kwargs) -> Any: ... -def lapack_zhetrd_workspace(*args, **kwargs) -> Any: ... -def lapack_zungqr_workspace(*args, **kwargs) -> Any: ... +def lapack_cgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... +def lapack_cgeqrf_workspace(m: int, n: int) -> int: ... +def lapack_chetrd_workspace(lda: int, n: int) -> int: ... +def lapack_cungqr_workspace(m: int, n: int, k: int) -> int: ... +def lapack_dgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... +def lapack_dgeqrf_workspace(m: int, n: int) -> int: ... +def lapack_dorgqr_workspace(m: int, n: int, k: int) -> int: ... +def lapack_dsytrd_workspace(lda: int, n: int) -> int: ... +def lapack_sgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... +def lapack_sgeqrf_workspace(m: int, n: int) -> int: ... +def lapack_sorgqr_workspace(m: int, n: int, k: int) -> int: ... +def lapack_ssytrd_workspace(lda: int, n: int) -> int: ... +def lapack_zgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... +def lapack_zgeqrf_workspace(m: int, n: int) -> int: ... +def lapack_zhetrd_workspace(lda: int, n: int) -> int: ... +def lapack_zungqr_workspace(m: int, n: int, k: int) -> int: ... def registrations() -> dict: ... -def sgesdd_work_size(*args, **kwargs) -> Any: ... -def syevd_iwork_size(*args, **kwargs) -> Any: ... -def syevd_work_size(*args, **kwargs) -> Any: ... -def zgesdd_work_size(*args, **kwargs) -> Any: ... +def sgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... +def syevd_iwork_size(n: int) -> int: ... +def syevd_work_size(n: int) -> int: ... +def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index c2879783a..ddf605fdd 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -230,32 +230,58 @@ NB_MODULE(_lapack, m) { m.def("initialize", GetLapackKernelsFromScipy); m.def("registrations", &Registrations); - m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace); - m.def("lapack_dgeqrf_workspace", &Geqrf::Workspace); - m.def("lapack_cgeqrf_workspace", &Geqrf>::Workspace); - m.def("lapack_zgeqrf_workspace", &Geqrf>::Workspace); - m.def("lapack_sorgqr_workspace", &Orgqr::Workspace); - m.def("lapack_dorgqr_workspace", &Orgqr::Workspace); - m.def("lapack_cungqr_workspace", &Orgqr>::Workspace); - m.def("lapack_zungqr_workspace", &Orgqr>::Workspace); - m.def("gesdd_iwork_size", &GesddIworkSize); - m.def("sgesdd_work_size", &RealGesdd::Workspace); - m.def("dgesdd_work_size", &RealGesdd::Workspace); - m.def("cgesdd_rwork_size", &ComplexGesddRworkSize); - m.def("cgesdd_work_size", &ComplexGesdd>::Workspace); - m.def("zgesdd_work_size", &ComplexGesdd>::Workspace); - m.def("syevd_work_size", &SyevdWorkSize); - m.def("syevd_iwork_size", &SyevdIworkSize); - m.def("heevd_work_size", &HeevdWorkSize); - m.def("heevd_rwork_size", &HeevdRworkSize); - m.def("lapack_sgehrd_workspace", &Gehrd::Workspace); - m.def("lapack_dgehrd_workspace", &Gehrd::Workspace); - m.def("lapack_cgehrd_workspace", &Gehrd>::Workspace); - m.def("lapack_zgehrd_workspace", &Gehrd>::Workspace); - m.def("lapack_ssytrd_workspace", &Sytrd::Workspace); - m.def("lapack_dsytrd_workspace", &Sytrd::Workspace); - m.def("lapack_chetrd_workspace", &Sytrd>::Workspace); - m.def("lapack_zhetrd_workspace", &Sytrd>::Workspace); + m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), + nb::arg("n")); + m.def("lapack_dgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), + nb::arg("n")); + m.def("lapack_cgeqrf_workspace", &Geqrf>::Workspace, + nb::arg("m"), nb::arg("n")); + m.def("lapack_zgeqrf_workspace", &Geqrf>::Workspace, + nb::arg("m"), nb::arg("n")); + m.def("lapack_sorgqr_workspace", &Orgqr::Workspace, nb::arg("m"), + nb::arg("n"), nb::arg("k")); + m.def("lapack_dorgqr_workspace", &Orgqr::Workspace, nb::arg("m"), + nb::arg("n"), nb::arg("k")); + m.def("lapack_cungqr_workspace", &Orgqr>::Workspace, + nb::arg("m"), nb::arg("n"), nb::arg("k")); + m.def("lapack_zungqr_workspace", &Orgqr>::Workspace, + nb::arg("m"), nb::arg("n"), nb::arg("k")); + m.def("gesdd_iwork_size", &GesddIworkSize, nb::arg("m"), nb::arg("n")); + m.def("sgesdd_work_size", &RealGesdd::Workspace, nb::arg("m"), + nb::arg("n"), nb::arg("job_opt_compute_uv"), + nb::arg("job_opt_full_matrices")); + m.def("dgesdd_work_size", &RealGesdd::Workspace, nb::arg("m"), + nb::arg("n"), nb::arg("job_opt_compute_uv"), + nb::arg("job_opt_full_matrices")); + m.def("cgesdd_rwork_size", &ComplexGesddRworkSize, nb::arg("m"), nb::arg("n"), + nb::arg("compute_uv")); + m.def("cgesdd_work_size", &ComplexGesdd>::Workspace, + nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"), + nb::arg("job_opt_full_matrices")); + m.def("zgesdd_work_size", &ComplexGesdd>::Workspace, + nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"), + nb::arg("job_opt_full_matrices")); + m.def("syevd_work_size", &SyevdWorkSize, nb::arg("n")); + m.def("syevd_iwork_size", &SyevdIworkSize, nb::arg("n")); + m.def("heevd_work_size", &HeevdWorkSize, nb::arg("n")); + m.def("heevd_rwork_size", &HeevdRworkSize, nb::arg("n")); + + m.def("lapack_sgehrd_workspace", &Gehrd::Workspace, nb::arg("lda"), + nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); + m.def("lapack_dgehrd_workspace", &Gehrd::Workspace, nb::arg("lda"), + nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); + m.def("lapack_cgehrd_workspace", &Gehrd>::Workspace, + nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); + m.def("lapack_zgehrd_workspace", &Gehrd>::Workspace, + nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); + m.def("lapack_ssytrd_workspace", &Sytrd::Workspace, nb::arg("lda"), + nb::arg("n")); + m.def("lapack_dsytrd_workspace", &Sytrd::Workspace, nb::arg("lda"), + nb::arg("n")); + m.def("lapack_chetrd_workspace", &Sytrd>::Workspace, + nb::arg("lda"), nb::arg("n")); + m.def("lapack_zhetrd_workspace", &Sytrd>::Workspace, + nb::arg("lda"), nb::arg("n")); } } // namespace