mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[PJRT C API] Change build wheel script to build a separate package for cuda kernels.
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels: | |size|wheel name | |----------------------|----|-------------------------------------------------------------------------| |jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl | |cuda pjrt |73M|jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl | |cuda kernels |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl| The size of jaxlib with cuda kernels and pjrt is 119M. The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new pacakge. PiperOrigin-RevId: 579861480
This commit is contained in:
parent
79ca40ea05
commit
462ef165c4
@ -582,7 +582,18 @@ def main():
|
||||
shell(command)
|
||||
|
||||
if args.build_gpu_plugin:
|
||||
build_plugin_command = ([bazel_path] + args.bazel_startup_options +
|
||||
build_cuda_kernels_command = ([bazel_path] + args.bazel_startup_options +
|
||||
["run", "--verbose_failures=true"] +
|
||||
["//jaxlib/tools:build_cuda_kernels_wheel", "--",
|
||||
f"--output_path={output_path}",
|
||||
f"--cpu={wheel_cpu}",
|
||||
f"--cuda_version={args.gpu_plugin_cuda_version}"])
|
||||
if args.editable:
|
||||
command.append("--editable")
|
||||
print(" ".join(build_cuda_kernels_command))
|
||||
shell(build_cuda_kernels_command)
|
||||
|
||||
build_pjrt_plugin_command = ([bazel_path] + args.bazel_startup_options +
|
||||
["run", "--verbose_failures=true"] +
|
||||
["//jaxlib/tools:build_gpu_plugin_wheel", "--",
|
||||
f"--output_path={output_path}",
|
||||
@ -590,8 +601,8 @@ def main():
|
||||
f"--cuda_version={args.gpu_plugin_cuda_version}"])
|
||||
if args.editable:
|
||||
command.append("--editable")
|
||||
print(" ".join(build_plugin_command))
|
||||
shell(build_plugin_command)
|
||||
print(" ".join(build_pjrt_plugin_command))
|
||||
shell(build_pjrt_plugin_command)
|
||||
|
||||
shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
|
||||
|
||||
|
@ -84,3 +84,14 @@ def build_editable(
|
||||
)
|
||||
shutil.rmtree(output_path, ignore_errors=True)
|
||||
shutil.copytree(sources_path, output_path)
|
||||
|
||||
|
||||
def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str):
|
||||
src_file = file_dir / "setup.py"
|
||||
with open(src_file, "r") as f:
|
||||
content = f.read()
|
||||
content = content.replace(
|
||||
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
|
||||
)
|
||||
with open(src_file, "w") as f:
|
||||
f.write(content)
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
import importlib
|
||||
import operator
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
@ -23,12 +24,19 @@ from .gpu_common_utils import GpuLibNotLinkedError
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
try:
|
||||
from .cuda import _linalg as _cuda_linalg # pytype: disable=import-error
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
try:
|
||||
_cuda_linalg = importlib.import_module(
|
||||
f"{cuda_module_name}._linalg", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cuda_linalg = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _cuda_linalg:
|
||||
for _name, _value in _cuda_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cuda_linalg = None
|
||||
|
||||
try:
|
||||
from .rocm import _linalg as _hip_linalg # pytype: disable=import-error
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
import functools
|
||||
from functools import partial
|
||||
import importlib
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Optional, Union
|
||||
@ -26,12 +27,19 @@ from jaxlib import xla_client
|
||||
from .hlo_helpers import custom_call
|
||||
from .gpu_common_utils import GpuLibNotLinkedError
|
||||
|
||||
try:
|
||||
from .cuda import _prng as _cuda_prng # pytype: disable=import-error
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
try:
|
||||
_cuda_prng = importlib.import_module(
|
||||
f"{cuda_module_name}._prng", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cuda_prng = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _cuda_prng:
|
||||
for _name, _value in _cuda_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cuda_prng = None
|
||||
|
||||
try:
|
||||
from .rocm import _prng as _hip_prng # pytype: disable=import-error
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
|
||||
@ -20,14 +22,17 @@ import numpy as np
|
||||
from jaxlib import xla_client
|
||||
from .gpu_common_utils import GpuLibNotLinkedError
|
||||
|
||||
try:
|
||||
from .cuda import _rnn # pytype: disable=import-error
|
||||
for _name, _value in _rnn.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
|
||||
except ImportError:
|
||||
_rnn = None
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
try:
|
||||
_rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib")
|
||||
except ImportError:
|
||||
_rnn = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _rnn:
|
||||
for _name, _value in _rnn.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
|
||||
compute_rnn_workspace_reserve_space_sizes = _rnn.compute_rnn_workspace_reserve_space_sizes
|
||||
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
import importlib
|
||||
import math
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
@ -31,17 +32,32 @@ from .hlo_helpers import (
|
||||
|
||||
try:
|
||||
from .cuda import _blas as _cublas # pytype: disable=import-error
|
||||
except ImportError:
|
||||
for cuda_module_name in ["jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
try:
|
||||
_cublas = importlib.import_module(f"{cuda_module_name}._blas")
|
||||
except ImportError:
|
||||
_cublas = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _cublas:
|
||||
for _name, _value in _cublas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cublas = None
|
||||
|
||||
try:
|
||||
from .cuda import _solver as _cusolver # pytype: disable=import-error
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
try:
|
||||
_cusolver = importlib.import_module(
|
||||
f"{cuda_module_name}._solver", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cusolver = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _cusolver:
|
||||
for _name, _value in _cusolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cusolver = None
|
||||
|
||||
|
||||
try:
|
||||
|
@ -17,6 +17,7 @@ cusparse wrappers for performing sparse matrix computations in JAX
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
import importlib
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
|
||||
@ -26,11 +27,17 @@ from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import custom_call, mk_result_types_and_shapes
|
||||
|
||||
try:
|
||||
from .cuda import _sparse as _cusparse # pytype: disable=import-error
|
||||
except ImportError:
|
||||
_cusparse = None
|
||||
else:
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
try:
|
||||
_cusparse = importlib.import_module(
|
||||
f"{cuda_module_name}._sparse", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cusparse = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _cusparse:
|
||||
for _name, _value in _cusparse.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
|
||||
|
@ -11,11 +11,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
try:
|
||||
from .cuda import _triton as _cuda_triton # pytype: disable=import-error
|
||||
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
|
||||
try:
|
||||
_cuda_triton = importlib.import_module(
|
||||
f"{cuda_module_name}._triton", package="jaxlib"
|
||||
)
|
||||
except ImportError:
|
||||
_cuda_triton = None
|
||||
else:
|
||||
break
|
||||
|
||||
if _cuda_triton:
|
||||
xla_client.register_custom_call_target(
|
||||
"triton_kernel_call", _cuda_triton.get_custom_call(),
|
||||
platform='CUDA')
|
||||
@ -27,8 +37,6 @@ try:
|
||||
get_compute_capability = _cuda_triton.get_compute_capability
|
||||
get_custom_call = _cuda_triton.get_custom_call
|
||||
get_serialized_metadata = _cuda_triton.get_serialized_metadata
|
||||
except ImportError:
|
||||
_cuda_triton = None
|
||||
|
||||
try:
|
||||
from .rocm import _triton as _hip_triton # pytype: disable=import-error
|
||||
|
@ -35,6 +35,7 @@ py_binary(
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]) + if_cuda([
|
||||
"//jaxlib/cuda:cuda_gpu_support",
|
||||
# TODO(jieying): move it out from jaxlib
|
||||
"//jaxlib:cuda_plugin_extension",
|
||||
"@local_config_cuda//cuda:cuda-nvvm",
|
||||
]) + if_rocm([
|
||||
@ -53,6 +54,7 @@ py_binary(
|
||||
"LICENSE.txt",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so",
|
||||
] + if_cuda([
|
||||
"//jaxlib:version",
|
||||
"//jaxlib/cuda:cuda_gpu_support",
|
||||
"//plugins/cuda:pyproject.toml",
|
||||
"//plugins/cuda:setup.py",
|
||||
@ -64,3 +66,21 @@ py_binary(
|
||||
"@bazel_tools//tools/python/runfiles"
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "build_cuda_kernels_wheel",
|
||||
srcs = ["build_cuda_kernels_wheel.py"],
|
||||
data = [
|
||||
"LICENSE.txt",
|
||||
] + if_cuda([
|
||||
"//jaxlib:version",
|
||||
"//jaxlib/cuda:cuda_gpu_support",
|
||||
"//plugins/cuda:plugin_pyproject.toml",
|
||||
"//plugins/cuda:plugin_setup.py",
|
||||
"@local_config_cuda//cuda:cuda-nvvm",
|
||||
]),
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles"
|
||||
],
|
||||
)
|
||||
|
120
jaxlib/tools/build_cuda_kernels_wheel.py
Normal file
120
jaxlib/tools/build_cuda_kernels_wheel.py
Normal file
@ -0,0 +1,120 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Script that builds a jax-cuda12-plugin wheel for cuda kernels, intended to be
|
||||
# run via bazel run as part of the jax cuda plugin build process.
|
||||
|
||||
# Most users should not run this script directly; use build.py instead.
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
from jax.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu", default=None, required=True, help="Target CPU architecture. Required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cuda_version",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Target CUDA version. Required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--editable",
|
||||
action="store_true",
|
||||
help="Create an 'editable' jax cuda plugin build instead of a wheel.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
r = runfiles.Create()
|
||||
pyext = "pyd" if build_utils.is_windows() else "so"
|
||||
|
||||
|
||||
def write_setup_cfg(sources_path, cpu):
|
||||
tag = build_utils.platform_tag(cpu)
|
||||
with open(sources_path / "setup.cfg", "w") as f:
|
||||
f.write(f"""[metadata]
|
||||
license_files = LICENSE.txt
|
||||
|
||||
[bdist_wheel]
|
||||
plat-name={tag}
|
||||
""")
|
||||
|
||||
|
||||
def prepare_wheel(
|
||||
sources_path: pathlib.Path, *, cpu, cuda_version
|
||||
):
|
||||
"""Assembles a source tree for the cuda kernel wheel in `sources_path`."""
|
||||
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
|
||||
|
||||
copy_runfiles(
|
||||
"__main__/plugins/cuda/plugin_pyproject.toml",
|
||||
dst_dir=sources_path,
|
||||
dst_filename="pyproject.toml",
|
||||
)
|
||||
copy_runfiles(
|
||||
"__main__/plugins/cuda/plugin_setup.py",
|
||||
dst_dir=sources_path,
|
||||
dst_filename="setup.py",
|
||||
)
|
||||
build_utils.update_setup_with_cuda_version(sources_path, cuda_version)
|
||||
write_setup_cfg(sources_path, cpu)
|
||||
|
||||
plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin"
|
||||
copy_runfiles(
|
||||
dst_dir=plugin_dir / "nvvm" / "libdevice",
|
||||
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
|
||||
)
|
||||
copy_runfiles(
|
||||
dst_dir=plugin_dir,
|
||||
src_files=[
|
||||
f"__main__/jaxlib/cuda/_solver.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_blas.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_linalg.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_prng.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_rnn.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_sparse.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_triton.{pyext}",
|
||||
f"__main__/jaxlib/cuda/_versions.{pyext}",
|
||||
"__main__/jaxlib/version.py",
|
||||
],
|
||||
)
|
||||
|
||||
# Build wheel for cuda kernels
|
||||
tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin")
|
||||
sources_path = tmpdir.name
|
||||
try:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
prepare_wheel(
|
||||
pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version
|
||||
)
|
||||
package_name = f"jax cuda{args.cuda_version} plugin"
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
build_utils.build_wheel(sources_path, args.output_path, package_name)
|
||||
finally:
|
||||
tmpdir.cleanup()
|
@ -72,22 +72,11 @@ python-tag=py3
|
||||
)
|
||||
|
||||
|
||||
def update_setup(file_dir, cuda_version):
|
||||
src_file = file_dir / "setup.py"
|
||||
with open(src_file, "r") as f:
|
||||
content = f.read()
|
||||
content = content.replace(
|
||||
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
|
||||
)
|
||||
with open(src_file, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version):
|
||||
"""Assembles a source tree for the wheel in `sources_path`."""
|
||||
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
|
||||
|
||||
plugin_dir = sources_path / "jax_plugins" / f"xla_cuda_cu{cuda_version}"
|
||||
plugin_dir = sources_path / "jax_plugins" / f"xla_cuda{cuda_version}"
|
||||
copy_runfiles(
|
||||
dst_dir=sources_path,
|
||||
src_files=[
|
||||
@ -95,12 +84,13 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version):
|
||||
"__main__/plugins/cuda/setup.py",
|
||||
],
|
||||
)
|
||||
update_setup(sources_path, cuda_version)
|
||||
build_utils.update_setup_with_cuda_version(sources_path, cuda_version)
|
||||
write_setup_cfg(sources_path, cpu)
|
||||
copy_runfiles(
|
||||
dst_dir=plugin_dir,
|
||||
src_files=[
|
||||
"__main__/plugins/cuda/__init__.py",
|
||||
"__main__/jaxlib/version.py",
|
||||
],
|
||||
)
|
||||
copy_runfiles(
|
||||
@ -113,7 +103,7 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version):
|
||||
tmpdir = None
|
||||
sources_path = args.sources_path
|
||||
if sources_path is None:
|
||||
tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudaplugin")
|
||||
tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudapjrt")
|
||||
sources_path = tmpdir.name
|
||||
|
||||
try:
|
||||
|
@ -216,7 +216,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
|
||||
],
|
||||
)
|
||||
|
||||
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
|
||||
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not include_gpu_plugin_extension:
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
|
||||
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
|
||||
|
@ -21,6 +21,8 @@ package(
|
||||
|
||||
exports_files([
|
||||
"__init__.py",
|
||||
"plugin_pyproject.toml",
|
||||
"plugin_setup.py",
|
||||
"pyproject.toml",
|
||||
"setup.py",
|
||||
])
|
||||
|
3
plugins/cuda/plugin_pyproject.toml
Normal file
3
plugins/cuda/plugin_pyproject.toml
Normal file
@ -0,0 +1,3 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
75
plugins/cuda/plugin_setup.py
Normal file
75
plugins/cuda/plugin_setup.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from setuptools import setup
|
||||
from setuptools.dist import Distribution
|
||||
|
||||
__version__ = None
|
||||
cuda_version = 0 # placeholder
|
||||
project_name = f"jax-cuda{cuda_version}-plugin"
|
||||
package_name = f"jax_cuda{cuda_version}_plugin"
|
||||
|
||||
def load_version_module(pkg_path):
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'version', os.path.join(pkg_path, 'version.py'))
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
_version_module = load_version_module(package_name)
|
||||
__version__ = _version_module._get_version_for_build()
|
||||
_cmdclass = _version_module._get_cmdclass(package_name)
|
||||
|
||||
cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
|
||||
if cudnn_version:
|
||||
__version__ += f"+cudnn{cudnn_version.replace('.', '')}"
|
||||
|
||||
class BinaryDistribution(Distribution):
|
||||
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
|
||||
|
||||
def has_ext_modules(self):
|
||||
return True
|
||||
|
||||
setup(
|
||||
name=project_name,
|
||||
version=__version__,
|
||||
cmdclass=_cmdclass,
|
||||
description="JAX Plugin for NVIDIA GPUs",
|
||||
long_description="",
|
||||
long_description_content_type="text/markdown",
|
||||
author="JAX team",
|
||||
author_email="jax-dev@google.com",
|
||||
packages=[package_name],
|
||||
python_requires=">=3.9",
|
||||
install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"],
|
||||
url="https://github.com/google/jax",
|
||||
license="Apache-2.0",
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
],
|
||||
package_data={
|
||||
package_name: [
|
||||
"*",
|
||||
"nvvm/libdevice/libdevice*",
|
||||
],
|
||||
},
|
||||
zip_safe=False,
|
||||
distclass=BinaryDistribution,
|
||||
)
|
@ -12,12 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from setuptools import setup, find_namespace_packages
|
||||
|
||||
__version__ = "0.0"
|
||||
__version__ = None
|
||||
cuda_version = 0 # placeholder
|
||||
project_name = f"jax-cuda-plugin-cu{cuda_version}"
|
||||
package_name = f"jax_plugins.xla_cuda_cu{cuda_version}"
|
||||
project_name = f"jax-cuda{cuda_version}-pjrt"
|
||||
package_name = f"jax_plugins.xla_cuda{cuda_version}"
|
||||
|
||||
def load_version_module(pkg_path):
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'version', os.path.join(pkg_path, 'version.py'))
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
_version_module = load_version_module(f"jax_plugins/xla_cuda{cuda_version}")
|
||||
__version__ = _version_module._get_version_for_build()
|
||||
|
||||
packages = find_namespace_packages(
|
||||
include=[
|
||||
@ -48,7 +60,7 @@ setup(
|
||||
zip_safe=False,
|
||||
entry_points={
|
||||
"jax_plugins": [
|
||||
f"xla_cuda_cu{cuda_version} = {package_name}",
|
||||
f"xla_cuda{cuda_version} = {package_name}",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user