From 350b7c56b8b03c489ee603f6a0ed6664f5525c61 Mon Sep 17 00:00:00 2001 From: Shashank Viswanadha Date: Tue, 28 Nov 2023 08:42:19 -0800 Subject: [PATCH] Add python stub files for jaxlib/cpu C++ Python extensions. PiperOrigin-RevId: 585990748 --- jaxlib/cpu/BUILD | 8 ++++++++ jaxlib/cpu/_ducc_fft.pyi | 16 +++++++++++++++ jaxlib/cpu/_lapack.pyi | 44 ++++++++++++++++++++++++++++++++++++++++ jaxlib/ducc_fft.py | 2 +- 4 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 jaxlib/cpu/_ducc_fft.pyi create mode 100644 jaxlib/cpu/_lapack.pyi diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 8db145364..10998d475 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -53,8 +53,12 @@ pybind_extension( "-fexceptions", "-fno-strict-aliasing", ], + enable_stub_generation = True, features = ["-use_header_modules"], module_name = "_lapack", + pytype_srcs = [ + "_lapack.pyi", + ], deps = [ ":lapack_kernels", "//jaxlib:kernel_nanobind_helpers", @@ -90,8 +94,12 @@ pybind_extension( "-fexceptions", "-fno-strict-aliasing", ], + enable_stub_generation = True, features = ["-use_header_modules"], module_name = "_ducc_fft", + pytype_srcs = [ + "_ducc_fft.pyi", + ], deps = [ ":ducc_fft_flatbuffers_cc", ":ducc_fft_kernels", diff --git a/jaxlib/cpu/_ducc_fft.pyi b/jaxlib/cpu/_ducc_fft.pyi new file mode 100644 index 000000000..7d5c3071a --- /dev/null +++ b/jaxlib/cpu/_ducc_fft.pyi @@ -0,0 +1,16 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def dynamic_ducc_fft_descriptor(ndims: int, is_double: bool, fft_type: int, axes: list[int], forward: bool) -> bytes: ... +def registrations() -> dict: ... diff --git a/jaxlib/cpu/_lapack.pyi b/jaxlib/cpu/_lapack.pyi new file mode 100644 index 000000000..416182c93 --- /dev/null +++ b/jaxlib/cpu/_lapack.pyi @@ -0,0 +1,44 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +def cgesdd_rwork_size(*args, **kwargs) -> Any: ... +def cgesdd_work_size(*args, **kwargs) -> Any: ... +def dgesdd_work_size(*args, **kwargs) -> Any: ... +def gesdd_iwork_size(*args, **kwargs) -> Any: ... +def heevd_rwork_size(*args, **kwargs) -> Any: ... +def heevd_work_size(*args, **kwargs) -> Any: ... +def initialize() -> None: ... +def lapack_cgehrd_workspace(*args, **kwargs) -> Any: ... +def lapack_cgeqrf_workspace(*args, **kwargs) -> Any: ... +def lapack_chetrd_workspace(*args, **kwargs) -> Any: ... +def lapack_cungqr_workspace(*args, **kwargs) -> Any: ... +def lapack_dgehrd_workspace(*args, **kwargs) -> Any: ... +def lapack_dgeqrf_workspace(*args, **kwargs) -> Any: ... +def lapack_dorgqr_workspace(*args, **kwargs) -> Any: ... +def lapack_dsytrd_workspace(*args, **kwargs) -> Any: ... +def lapack_sgehrd_workspace(*args, **kwargs) -> Any: ... +def lapack_sgeqrf_workspace(*args, **kwargs) -> Any: ... +def lapack_sorgqr_workspace(*args, **kwargs) -> Any: ... +def lapack_ssytrd_workspace(*args, **kwargs) -> Any: ... +def lapack_zgehrd_workspace(*args, **kwargs) -> Any: ... +def lapack_zgeqrf_workspace(*args, **kwargs) -> Any: ... +def lapack_zhetrd_workspace(*args, **kwargs) -> Any: ... +def lapack_zungqr_workspace(*args, **kwargs) -> Any: ... +def registrations() -> dict: ... +def sgesdd_work_size(*args, **kwargs) -> Any: ... +def syevd_iwork_size(*args, **kwargs) -> Any: ... +def syevd_work_size(*args, **kwargs) -> Any: ... +def zgesdd_work_size(*args, **kwargs) -> Any: ... diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py index 36d98762a..be9e3aff6 100644 --- a/jaxlib/ducc_fft.py +++ b/jaxlib/ducc_fft.py @@ -35,7 +35,7 @@ _R2C = 2 def _dynamic_ducc_fft_descriptor( dtype, ndims: int, fft_type: FftType, fft_lengths: list[int] -) -> tuple[bytes]: +) -> bytes: assert len(fft_lengths) >= 1 assert len(fft_lengths) <= ndims, (fft_lengths, ndims)