Re-exported tensorflow...xla_extension type stubs in jaxlib

The type stubs allow using precise types for XLA primitives instead
of aliasing them to Any.

This commit does not change any type annotations within JAX. That will
be done in a followup. I have manually verified that type stubs are
discoverable by mypy once the new jaxlib is installed by type "checking"

    from jaxlib import xla_extension as xe
    d: xe._Dtype
This commit is contained in:
Sergei Lebedev 2021-03-29 13:07:19 +01:00
parent 30cfd86b80
commit 225ffc30d8
2 changed files with 39 additions and 6 deletions

View File

@ -78,6 +78,38 @@ def copy_file(src_file, dst_dir, dst_filename=None):
else:
_copy_normal(src_file, dst_dir, dst_filename=dst_filename)
_XLA_EXTENSION_STUBS = [
"__init__.pyi",
"jax_jit.pyi",
"ops.pyi",
"outfeed_receiver.pyi",
"pmap_lib.pyi",
"profiler.pyi",
"pytree.pyi",
]
def patch_copy_xla_extension_stubs(dst_dir):
# This file is required by PEP-561. It marks jaxlib as package containing
# type stubs.
with open(os.path.join(dst_dir, "py.typed"), "w") as f:
pass
# The -stubs suffix is required by PEP-561.
xla_extension_dir = os.path.join(dst_dir, "xla_extension-stubs")
os.makedirs(xla_extension_dir)
for stub_name in _XLA_EXTENSION_STUBS:
with open(r.Rlocation(
"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)) as f:
src = f.read()
src = src.replace(
"from tensorflow.compiler.xla.python import xla_extension",
"from .. import xla_extension"
)
with open(os.path.join(xla_extension_dir, stub_name), "w") as f:
f.write(src)
def patch_copy_xla_client_py(dst_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
src = f.read()
@ -160,6 +192,7 @@ def prepare_wheel(sources_path):
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
else:
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
patch_copy_xla_extension_stubs(jaxlib_dir)
patch_copy_xla_client_py(jaxlib_dir)
if not _is_windows():

View File

@ -13,7 +13,6 @@
# limitations under the License.
from setuptools import setup
from glob import glob
import os
__version__ = None
@ -25,19 +24,20 @@ cuda_version = os.environ.get("JAX_CUDA_VERSION")
if cuda_version:
__version__ += "+cuda" + cuda_version.replace(".", "")
binary_libs = [os.path.basename(f) for f in glob('jaxlib/*.so*')]
binary_libs += [os.path.basename(f) for f in glob('jaxlib/*.pyd*')]
setup(
name='jaxlib',
version=__version__,
description='XLA library for JAX',
author='JAX team',
author_email='jax-dev@google.com',
packages=['jaxlib'],
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
python_requires='>=3.6',
install_requires=['scipy', 'numpy>=1.16', 'absl-py', 'flatbuffers'],
url='https://github.com/google/jax',
license='Apache-2.0',
package_data={'jaxlib': binary_libs},
package_data={
'jaxlib': ['*.so', '*.pyd*', 'py.typed'],
'jaxlib.xla_extension-stubs': ['*.pyi'],
},
zip_safe=False,
)