2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2018-12-06 21:35:03 -05:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-09-07 08:45:48 -07:00
|
|
|
import importlib
|
|
|
|
import os
|
2018-12-06 21:35:03 -05:00
|
|
|
from setuptools import setup
|
2022-08-17 15:02:12 +00:00
|
|
|
from setuptools.dist import Distribution
|
2018-12-06 21:35:03 -05:00
|
|
|
|
2019-08-04 12:12:53 -04:00
|
|
|
# Distribution name; also used as the directory that holds version.py.
project_name = "jaxlib"
# Placeholder only: the real version string is assigned once version.py
# has been loaded.
__version__ = None
|
2019-08-04 12:12:53 -04:00
|
|
|
|
2023-09-07 08:45:48 -07:00
|
|
|
def load_version_module(pkg_path):
  """Load ``version.py`` from the directory ``pkg_path`` by file path.

  The file is executed directly via an import spec built from its path, so the
  enclosing package itself is not imported.

  Args:
    pkg_path: Directory containing ``version.py``.

  Returns:
    The executed module object; its ``__name__`` is ``'version'``.

  Raises:
    ImportError: If no import spec or loader can be created for the file.
    OSError: If ``version.py`` cannot be read.
  """
  # `import importlib` alone does not guarantee that the `importlib.util`
  # submodule is loaded, so import it explicitly.
  import importlib.util

  version_path = os.path.join(pkg_path, 'version.py')
  spec = importlib.util.spec_from_file_location('version', version_path)
  if spec is None or spec.loader is None:
    raise ImportError(f'Cannot load version module from {version_path!r}')
  module = importlib.util.module_from_spec(spec)
  spec.loader.exec_module(module)
  return module
|
2023-09-07 10:06:32 -07:00
|
|
|
|
2023-09-07 08:45:48 -07:00
|
|
|
# Load version.py from the package directory and use it to compute the build
# version and the custom setuptools command classes.
_version_module = load_version_module(project_name)
__version__ = _version_module._get_version_for_build()
_cmdclass = _version_module._get_cmdclass(project_name)

# README.md becomes the long description. Read it as UTF-8 explicitly so the
# build does not depend on the platform's default (locale) encoding.
with open('README.md', encoding='utf-8') as f:
  _long_description = f.read()

# Accelerator builds append a local version label naming the toolkit versions,
# with dots removed, e.g. "+cuda12-cudnn91" or "+rocm60".
_local_version_segments = []

cuda_version = os.environ.get("JAX_CUDA_VERSION")
cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
if cuda_version and cudnn_version:
  _local_version_segments.append(
      f"cuda{cuda_version.replace('.', '')}-cudnn{cudnn_version.replace('.', '')}")

rocm_version = os.environ.get("JAX_ROCM_VERSION")
if rocm_version:
  _local_version_segments.append(f"rocm{rocm_version.replace('.', '')}")

if _local_version_segments:
  # PEP 440 permits a single "+" per version; if several labels apply they
  # are joined with "." rather than producing "+...+..." (an invalid version).
  __version__ += "+" + ".".join(_local_version_segments)
|
|
|
|
|
2022-08-17 15:02:12 +00:00
|
|
|
class BinaryDistribution(Distribution):
  """Distribution that always claims to contain extension modules.

  Used as ``distclass`` so that 'bdist_wheel' includes an ABI tag on the
  wheel instead of building a pure-Python wheel.
  """

  def has_ext_modules(self):
    """Report that compiled extensions are present.

    Returns:
      Always ``True``, independent of what setuptools detects.
    """
    contains_binaries = True
    return contains_binaries
|
|
|
|
|
2018-12-06 21:35:03 -05:00
|
|
|
setup(
    # --- Identity and metadata -------------------------------------------
    name=project_name,
    version=__version__,
    cmdclass=_cmdclass,
    description="XLA library for JAX",
    long_description=_long_description,
    long_description_content_type="text/markdown",
    author="JAX team",
    author_email="jax-dev@google.com",
    # --- Packages and runtime requirements ------------------------------
    packages=["jaxlib", "jaxlib.xla_extension"],
    python_requires=">=3.10",
    install_requires=[
        "scipy>=1.10",
        "scipy>=1.11.1; python_version>='3.12'",
        "numpy>=1.25",
        "ml_dtypes>=0.2.0",
    ],
    url="https://github.com/jax-ml/jax",
    license="Apache-2.0",
    classifiers=[
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
    ],
    # --- Non-Python files shipped inside the packages -------------------
    package_data={
        "jaxlib": [
            # Top-level native binaries and typing marker.
            "*.so",
            "*.pyd*",
            "py.typed",
            # Backend subdirectories.
            "cpu/*",
            "cuda/*",
            "cuda/nvvm/libdevice/libdevice*",
            # Mosaic.
            "mosaic/*.py",
            "mosaic/gpu/*.so",
            "mosaic/python/*.py",
            "mosaic/python/*.so",
            # MLIR Python bindings, dialects, and native libraries.
            "mlir/*.py",
            "mlir/*.pyi",
            "mlir/dialects/*.py",
            "mlir/dialects/gpu/*.py",
            "mlir/dialects/gpu/passes/*.py",
            "mlir/extras/*.py",
            "mlir/_mlir_libs/*.dll",
            "mlir/_mlir_libs/*.dylib",
            "mlir/_mlir_libs/*.so",
            "mlir/_mlir_libs/*.pyd",
            "mlir/_mlir_libs/*.py",
            "mlir/_mlir_libs/*.pyi",
            # ROCm backend directory.
            "rocm/*",
            # Triton bindings.
            "triton/*.py",
            "triton/*.pyi",
            "triton/*.pyd",
            "triton/*.so",
            # XLA FFI API headers.
            "include/xla/ffi/api/*.h",
        ],
        "jaxlib.xla_extension": ["*.pyi"],
    },
    # Install as a regular directory rather than a zipped archive.
    zip_safe=False,
    # Forces an ABI-tagged (platform-specific) wheel; see BinaryDistribution.
    distclass=BinaryDistribution,
)
|