mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11963 from hawkinsp:abi
PiperOrigin-RevId: 468207347
This commit is contained in:
commit
083cea8b4e
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from setuptools import setup
|
||||
from setuptools.dist import Distribution
|
||||
import os
|
||||
|
||||
__version__ = None
|
||||
@ -29,6 +30,12 @@ cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
|
||||
if cuda_version and cudnn_version:
|
||||
__version__ += f"+cuda{cuda_version.replace('.', '')}-cudnn{cudnn_version.replace('.', '')}"
|
||||
|
||||
class BinaryDistribution(Distribution):
|
||||
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
|
||||
|
||||
def has_ext_modules(self):
|
||||
return True
|
||||
|
||||
setup(
|
||||
name=project_name,
|
||||
version=__version__,
|
||||
@ -67,4 +74,5 @@ setup(
|
||||
'jaxlib.xla_extension': ['*.pyi'],
|
||||
},
|
||||
zip_safe=False,
|
||||
distclass=BinaryDistribution,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user