mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add nb::arg
to nanobind definitions to generate better python annotations.
PiperOrigin-RevId: 586721759
This commit is contained in:
parent
11d7a2b860
commit
bd46e5c960
@ -12,33 +12,31 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
def cgesdd_rwork_size(*args, **kwargs) -> Any: ...
|
||||
def cgesdd_work_size(*args, **kwargs) -> Any: ...
|
||||
def dgesdd_work_size(*args, **kwargs) -> Any: ...
|
||||
def gesdd_iwork_size(*args, **kwargs) -> Any: ...
|
||||
def heevd_rwork_size(*args, **kwargs) -> Any: ...
|
||||
def heevd_work_size(*args, **kwargs) -> Any: ...
|
||||
def cgesdd_rwork_size(m: int, n: int, compute_uv: int) -> int: ...
|
||||
def cgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ...
|
||||
def dgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ...
|
||||
def gesdd_iwork_size(m: int, n: int) -> int: ...
|
||||
def heevd_rwork_size(n: int) -> int: ...
|
||||
def heevd_work_size(n: int) -> int: ...
|
||||
def initialize() -> None: ...
|
||||
def lapack_cgehrd_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_cgeqrf_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_chetrd_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_cungqr_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_dgehrd_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_dgeqrf_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_dorgqr_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_dsytrd_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_sgehrd_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_sgeqrf_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_sorgqr_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_ssytrd_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_zgehrd_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_zgeqrf_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_zhetrd_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_zungqr_workspace(*args, **kwargs) -> Any: ...
|
||||
def lapack_cgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ...
|
||||
def lapack_cgeqrf_workspace(m: int, n: int) -> int: ...
|
||||
def lapack_chetrd_workspace(lda: int, n: int) -> int: ...
|
||||
def lapack_cungqr_workspace(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_dgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ...
|
||||
def lapack_dgeqrf_workspace(m: int, n: int) -> int: ...
|
||||
def lapack_dorgqr_workspace(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_dsytrd_workspace(lda: int, n: int) -> int: ...
|
||||
def lapack_sgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ...
|
||||
def lapack_sgeqrf_workspace(m: int, n: int) -> int: ...
|
||||
def lapack_sorgqr_workspace(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_ssytrd_workspace(lda: int, n: int) -> int: ...
|
||||
def lapack_zgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ...
|
||||
def lapack_zgeqrf_workspace(m: int, n: int) -> int: ...
|
||||
def lapack_zhetrd_workspace(lda: int, n: int) -> int: ...
|
||||
def lapack_zungqr_workspace(m: int, n: int, k: int) -> int: ...
|
||||
def registrations() -> dict: ...
|
||||
def sgesdd_work_size(*args, **kwargs) -> Any: ...
|
||||
def syevd_iwork_size(*args, **kwargs) -> Any: ...
|
||||
def syevd_work_size(*args, **kwargs) -> Any: ...
|
||||
def zgesdd_work_size(*args, **kwargs) -> Any: ...
|
||||
def sgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ...
|
||||
def syevd_iwork_size(n: int) -> int: ...
|
||||
def syevd_work_size(n: int) -> int: ...
|
||||
def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ...
|
||||
|
@ -230,32 +230,58 @@ NB_MODULE(_lapack, m) {
|
||||
m.def("initialize", GetLapackKernelsFromScipy);
|
||||
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace);
|
||||
m.def("lapack_dgeqrf_workspace", &Geqrf<double>::Workspace);
|
||||
m.def("lapack_cgeqrf_workspace", &Geqrf<std::complex<float>>::Workspace);
|
||||
m.def("lapack_zgeqrf_workspace", &Geqrf<std::complex<double>>::Workspace);
|
||||
m.def("lapack_sorgqr_workspace", &Orgqr<float>::Workspace);
|
||||
m.def("lapack_dorgqr_workspace", &Orgqr<double>::Workspace);
|
||||
m.def("lapack_cungqr_workspace", &Orgqr<std::complex<float>>::Workspace);
|
||||
m.def("lapack_zungqr_workspace", &Orgqr<std::complex<double>>::Workspace);
|
||||
m.def("gesdd_iwork_size", &GesddIworkSize);
|
||||
m.def("sgesdd_work_size", &RealGesdd<float>::Workspace);
|
||||
m.def("dgesdd_work_size", &RealGesdd<double>::Workspace);
|
||||
m.def("cgesdd_rwork_size", &ComplexGesddRworkSize);
|
||||
m.def("cgesdd_work_size", &ComplexGesdd<std::complex<float>>::Workspace);
|
||||
m.def("zgesdd_work_size", &ComplexGesdd<std::complex<double>>::Workspace);
|
||||
m.def("syevd_work_size", &SyevdWorkSize);
|
||||
m.def("syevd_iwork_size", &SyevdIworkSize);
|
||||
m.def("heevd_work_size", &HeevdWorkSize);
|
||||
m.def("heevd_rwork_size", &HeevdRworkSize);
|
||||
m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace);
|
||||
m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace);
|
||||
m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace);
|
||||
m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace);
|
||||
m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace);
|
||||
m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace);
|
||||
m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace);
|
||||
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace);
|
||||
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_dgeqrf_workspace", &Geqrf<double>::Workspace, nb::arg("m"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_cgeqrf_workspace", &Geqrf<std::complex<float>>::Workspace,
|
||||
nb::arg("m"), nb::arg("n"));
|
||||
m.def("lapack_zgeqrf_workspace", &Geqrf<std::complex<double>>::Workspace,
|
||||
nb::arg("m"), nb::arg("n"));
|
||||
m.def("lapack_sorgqr_workspace", &Orgqr<float>::Workspace, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("k"));
|
||||
m.def("lapack_dorgqr_workspace", &Orgqr<double>::Workspace, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("k"));
|
||||
m.def("lapack_cungqr_workspace", &Orgqr<std::complex<float>>::Workspace,
|
||||
nb::arg("m"), nb::arg("n"), nb::arg("k"));
|
||||
m.def("lapack_zungqr_workspace", &Orgqr<std::complex<double>>::Workspace,
|
||||
nb::arg("m"), nb::arg("n"), nb::arg("k"));
|
||||
m.def("gesdd_iwork_size", &GesddIworkSize, nb::arg("m"), nb::arg("n"));
|
||||
m.def("sgesdd_work_size", &RealGesdd<float>::Workspace, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("job_opt_compute_uv"),
|
||||
nb::arg("job_opt_full_matrices"));
|
||||
m.def("dgesdd_work_size", &RealGesdd<double>::Workspace, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("job_opt_compute_uv"),
|
||||
nb::arg("job_opt_full_matrices"));
|
||||
m.def("cgesdd_rwork_size", &ComplexGesddRworkSize, nb::arg("m"), nb::arg("n"),
|
||||
nb::arg("compute_uv"));
|
||||
m.def("cgesdd_work_size", &ComplexGesdd<std::complex<float>>::Workspace,
|
||||
nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"),
|
||||
nb::arg("job_opt_full_matrices"));
|
||||
m.def("zgesdd_work_size", &ComplexGesdd<std::complex<double>>::Workspace,
|
||||
nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"),
|
||||
nb::arg("job_opt_full_matrices"));
|
||||
m.def("syevd_work_size", &SyevdWorkSize, nb::arg("n"));
|
||||
m.def("syevd_iwork_size", &SyevdIworkSize, nb::arg("n"));
|
||||
m.def("heevd_work_size", &HeevdWorkSize, nb::arg("n"));
|
||||
m.def("heevd_rwork_size", &HeevdRworkSize, nb::arg("n"));
|
||||
|
||||
m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace, nb::arg("lda"),
|
||||
nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
|
||||
m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace, nb::arg("lda"),
|
||||
nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
|
||||
m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace,
|
||||
nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
|
||||
m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace,
|
||||
nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi"));
|
||||
m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace, nb::arg("lda"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace, nb::arg("lda"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace,
|
||||
nb::arg("lda"), nb::arg("n"));
|
||||
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace,
|
||||
nb::arg("lda"), nb::arg("n"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user