Merge pull request #11963 from hawkinsp:abi

PiperOrigin-RevId: 468207347
This commit is contained in:
jax authors 2022-08-17 08:42:58 -07:00
commit 083cea8b4e

View File

@ -13,6 +13,7 @@
# limitations under the License.
from setuptools import setup
from setuptools.dist import Distribution
import os
__version__ = None
@ -29,6 +30,12 @@ cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
if cuda_version and cudnn_version:
__version__ += f"+cuda{cuda_version.replace('.', '')}-cudnn{cudnn_version.replace('.', '')}"
class BinaryDistribution(Distribution):
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
def has_ext_modules(self):
return True
setup(
name=project_name,
version=__version__,
@ -67,4 +74,5 @@ setup(
'jaxlib.xla_extension': ['*.pyi'],
},
zip_safe=False,
distclass=BinaryDistribution,
)