Include ABI tag in jaxlib wheels.

Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl

This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl

i.e., we include the CPython ABI tag. This simply reflects the status
quo in the wheel name, and does not change what jaxlib needs.
This commit is contained in:
Peter Hawkins 2022-08-17 15:02:12 +00:00
parent 5c558d8d24
commit 5b0686f9ea

View File

@ -13,6 +13,7 @@
# limitations under the License.
from setuptools import setup
from setuptools.dist import Distribution
import os
__version__ = None
@ -29,6 +30,12 @@ cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
if cuda_version and cudnn_version:
__version__ += f"+cuda{cuda_version.replace('.', '')}-cudnn{cudnn_version.replace('.', '')}"
class BinaryDistribution(Distribution):
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
def has_ext_modules(self):
return True
setup(
name=project_name,
version=__version__,
@ -67,4 +74,5 @@ setup(
'jaxlib.xla_extension': ['*.pyi'],
},
zip_safe=False,
distclass=BinaryDistribution,
)