mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Re-exported tensorflow...xla_extension type stubs in jaxlib
The type stubs allow using precise types for XLA primitives instead of aliasing them to Any. This commit does not change any type annotations within JAX. That will be done in a followup. I have manually verified that type stubs are discoverable by mypy once the new jaxlib is installed by type "checking" from jaxlib import xla_extension as xe d: xe._Dtype
This commit is contained in:
parent
30cfd86b80
commit
225ffc30d8
@ -78,6 +78,38 @@ def copy_file(src_file, dst_dir, dst_filename=None):
|
||||
else:
|
||||
_copy_normal(src_file, dst_dir, dst_filename=dst_filename)
|
||||
|
||||
|
||||
_XLA_EXTENSION_STUBS = [
|
||||
"__init__.pyi",
|
||||
"jax_jit.pyi",
|
||||
"ops.pyi",
|
||||
"outfeed_receiver.pyi",
|
||||
"pmap_lib.pyi",
|
||||
"profiler.pyi",
|
||||
"pytree.pyi",
|
||||
]
|
||||
|
||||
|
||||
def patch_copy_xla_extension_stubs(dst_dir):
|
||||
# This file is required by PEP-561. It marks jaxlib as package containing
|
||||
# type stubs.
|
||||
with open(os.path.join(dst_dir, "py.typed"), "w") as f:
|
||||
pass
|
||||
# The -stubs suffix is required by PEP-561.
|
||||
xla_extension_dir = os.path.join(dst_dir, "xla_extension-stubs")
|
||||
os.makedirs(xla_extension_dir)
|
||||
for stub_name in _XLA_EXTENSION_STUBS:
|
||||
with open(r.Rlocation(
|
||||
"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)) as f:
|
||||
src = f.read()
|
||||
src = src.replace(
|
||||
"from tensorflow.compiler.xla.python import xla_extension",
|
||||
"from .. import xla_extension"
|
||||
)
|
||||
with open(os.path.join(xla_extension_dir, stub_name), "w") as f:
|
||||
f.write(src)
|
||||
|
||||
|
||||
def patch_copy_xla_client_py(dst_dir):
|
||||
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
|
||||
src = f.read()
|
||||
@ -160,6 +192,7 @@ def prepare_wheel(sources_path):
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
|
||||
else:
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
|
||||
patch_copy_xla_extension_stubs(jaxlib_dir)
|
||||
patch_copy_xla_client_py(jaxlib_dir)
|
||||
|
||||
if not _is_windows():
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from setuptools import setup
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
__version__ = None
|
||||
@ -25,19 +24,20 @@ cuda_version = os.environ.get("JAX_CUDA_VERSION")
|
||||
if cuda_version:
|
||||
__version__ += "+cuda" + cuda_version.replace(".", "")
|
||||
|
||||
binary_libs = [os.path.basename(f) for f in glob('jaxlib/*.so*')]
|
||||
binary_libs += [os.path.basename(f) for f in glob('jaxlib/*.pyd*')]
|
||||
|
||||
setup(
|
||||
name='jaxlib',
|
||||
version=__version__,
|
||||
description='XLA library for JAX',
|
||||
author='JAX team',
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib'],
|
||||
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
|
||||
python_requires='>=3.6',
|
||||
install_requires=['scipy', 'numpy>=1.16', 'absl-py', 'flatbuffers'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
package_data={'jaxlib': binary_libs},
|
||||
package_data={
|
||||
'jaxlib': ['*.so', '*.pyd*', 'py.typed'],
|
||||
'jaxlib.xla_extension-stubs': ['*.pyi'],
|
||||
},
|
||||
zip_safe=False,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user