mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Include ABI tag in jaxlib wheels.
Currently JAX wheels end up with names like: jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl This PR changes the wheel names to: jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl i.e., we include the CPython ABI tag. This simply reflects the status quo in the wheel name, and does not change what jaxlib needs.
This commit is contained in:
parent
5c558d8d24
commit
5b0686f9ea
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from setuptools import setup
|
||||
from setuptools.dist import Distribution
|
||||
import os
|
||||
|
||||
__version__ = None
|
||||
@ -29,6 +30,12 @@ cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
|
||||
if cuda_version and cudnn_version:
|
||||
__version__ += f"+cuda{cuda_version.replace('.', '')}-cudnn{cudnn_version.replace('.', '')}"
|
||||
|
||||
class BinaryDistribution(Distribution):
|
||||
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
|
||||
|
||||
def has_ext_modules(self):
|
||||
return True
|
||||
|
||||
setup(
|
||||
name=project_name,
|
||||
version=__version__,
|
||||
@ -67,4 +74,5 @@ setup(
|
||||
'jaxlib.xla_extension': ['*.pyi'],
|
||||
},
|
||||
zip_safe=False,
|
||||
distclass=BinaryDistribution,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user