rocm_jax/jax_plugins/cuda/plugin_setup.py
Peter Hawkins d1c0d993fc Bump the minimum CUDNN version to v9.1.
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00

92 lines
3.2 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from setuptools import setup
from setuptools.dist import Distribution
__version__ = None
cuda_version = 0 # placeholder
project_name = f"jax-cuda{cuda_version}-plugin"
package_name = f"jax_cuda{cuda_version}_plugin"
def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
'version', os.path.join(pkg_path, 'version.py'))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
_version_module = load_version_module(package_name)
__version__ = _version_module._get_version_for_build()
_cmdclass = _version_module._get_cmdclass(package_name)
class BinaryDistribution(Distribution):
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
def has_ext_modules(self):
return True
setup(
name=project_name,
version=__version__,
cmdclass=_cmdclass,
description="JAX Plugin for NVIDIA GPUs",
long_description="",
long_description_content_type="text/markdown",
author="JAX team",
author_email="jax-dev@google.com",
packages=[package_name],
python_requires=">=3.10",
install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"],
extras_require={
'with_cuda': [
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12>=9.1,<10.0",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
# package doesn't get upgraded even though not doing that can cause
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add a version constraint
# here.
"nvidia-nvjitlink-cu12>=12.1.105",
],
},
url="https://github.com/google/jax",
license="Apache-2.0",
classifiers=[
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
package_data={
package_name: [
"*",
"nvvm/libdevice/libdevice*",
],
},
zip_safe=False,
distclass=BinaryDistribution,
)