[PJRT C API] Change build wheel script to build a separate package for cuda kernels.

With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

|wheel                  |size|wheel name                                                                |
|-----------------------|----|--------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl            |
|cuda pjrt              |73M |jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl      |
|cuda kernels           |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl |

For comparison, the size of a jaxlib wheel with the cuda kernels and pjrt bundled in is 119M.

The cuda kernel wheel contains all the cuda kernels. A `plugin_setup.py` and `plugin_pyproject.toml` are added for this new package.
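At import time, the jaxlib GPU modules now look for each kernel extension first inside jaxlib itself and then in the new plugin packages, as the per-module diffs below show. A condensed sketch of that fallback pattern, using `_linalg` as the example (the other modules follow the same shape):

```python
import importlib

from jaxlib import xla_client

# Prefer the copy bundled in jaxlib (".cuda" resolves relative to jaxlib),
# then fall back to the separately installed kernel wheels.
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
  try:
    _cuda_linalg = importlib.import_module(
        f"{cuda_module_name}._linalg", package="jaxlib"
    )
  except ImportError:
    _cuda_linalg = None
  else:
    break

if _cuda_linalg:
  for _name, _value in _cuda_linalg.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
```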

PiperOrigin-RevId: 579861480
Jieying Luo 2023-11-06 09:05:08 -08:00 committed by jax authors
parent 79ca40ea05
commit 462ef165c4
16 changed files with 347 additions and 51 deletions

File: build/build.py

@@ -582,7 +582,18 @@ def main():
shell(command)
if args.build_gpu_plugin:
build_plugin_command = ([bazel_path] + args.bazel_startup_options +
build_cuda_kernels_command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_cuda_kernels_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"])
if args.editable:
command.append("--editable")
print(" ".join(build_cuda_kernels_command))
shell(build_cuda_kernels_command)
build_pjrt_plugin_command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_gpu_plugin_wheel", "--",
f"--output_path={output_path}",
@@ -590,8 +601,8 @@ def main():
f"--cuda_version={args.gpu_plugin_cuda_version}"])
if args.editable:
command.append("--editable")
print(" ".join(build_plugin_command))
shell(build_plugin_command)
print(" ".join(build_pjrt_plugin_command))
shell(build_pjrt_plugin_command)
shell([bazel_path] + args.bazel_startup_options + ["shutdown"])

File: jax/tools/build_utils.py

@@ -84,3 +84,14 @@ def build_editable(
)
shutil.rmtree(output_path, ignore_errors=True)
shutil.copytree(sources_path, output_path)
def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str):
  src_file = file_dir / "setup.py"
  with open(src_file, "r") as f:
    content = f.read()
  content = content.replace(
      "cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
  )
  with open(src_file, "w") as f:
    f.write(content)

File: jaxlib/gpu_linalg.py

@@ -14,6 +14,7 @@
import functools
from functools import partial
import importlib
import operator
import jaxlib.mlir.ir as ir
@@ -23,12 +24,19 @@ from .gpu_common_utils import GpuLibNotLinkedError
from jaxlib import xla_client
try:
from .cuda import _linalg as _cuda_linalg # pytype: disable=import-error
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
try:
_cuda_linalg = importlib.import_module(
f"{cuda_module_name}._linalg", package="jaxlib"
)
except ImportError:
_cuda_linalg = None
else:
break
if _cuda_linalg:
for _name, _value in _cuda_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cuda_linalg = None
try:
from .rocm import _linalg as _hip_linalg # pytype: disable=import-error

File: jaxlib/gpu_prng.py

@@ -15,6 +15,7 @@
import functools
from functools import partial
import importlib
import itertools
import operator
from typing import Optional, Union
@@ -26,12 +27,19 @@ from jaxlib import xla_client
from .hlo_helpers import custom_call
from .gpu_common_utils import GpuLibNotLinkedError
try:
from .cuda import _prng as _cuda_prng # pytype: disable=import-error
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
try:
_cuda_prng = importlib.import_module(
f"{cuda_module_name}._prng", package="jaxlib"
)
except ImportError:
_cuda_prng = None
else:
break
if _cuda_prng:
for _name, _value in _cuda_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cuda_prng = None
try:
from .rocm import _prng as _hip_prng # pytype: disable=import-error

File: jaxlib/gpu_rnn.py

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
@@ -20,14 +22,17 @@ import numpy as np
from jaxlib import xla_client
from .gpu_common_utils import GpuLibNotLinkedError
try:
from .cuda import _rnn # pytype: disable=import-error
for _name, _value in _rnn.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
except ImportError:
_rnn = None
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
try:
_rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib")
except ImportError:
_rnn = None
else:
break
if _rnn:
for _name, _value in _rnn.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
compute_rnn_workspace_reserve_space_sizes = _rnn.compute_rnn_workspace_reserve_space_sizes

File: jaxlib/gpu_solver.py

@@ -14,6 +14,7 @@
from collections.abc import Sequence
from functools import partial
import importlib
import math
import jaxlib.mlir.ir as ir
@@ -31,17 +32,32 @@ from .hlo_helpers import (
try:
from .cuda import _blas as _cublas # pytype: disable=import-error
except ImportError:
for cuda_module_name in ["jax_cuda12_plugin", "jax_cuda11_plugin"]:
try:
_cublas = importlib.import_module(f"{cuda_module_name}._blas")
except ImportError:
_cublas = None
else:
break
if _cublas:
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cublas = None
try:
from .cuda import _solver as _cusolver # pytype: disable=import-error
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
try:
_cusolver = importlib.import_module(
f"{cuda_module_name}._solver", package="jaxlib"
)
except ImportError:
_cusolver = None
else:
break
if _cusolver:
for _name, _value in _cusolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cusolver = None
try:

File: jaxlib/gpu_sparse.py

@@ -17,6 +17,7 @@ cusparse wrappers for performing sparse matrix computations in JAX
import math
from functools import partial
import importlib
import jaxlib.mlir.ir as ir
@@ -26,11 +27,17 @@ from jaxlib import xla_client
from .hlo_helpers import custom_call, mk_result_types_and_shapes
try:
from .cuda import _sparse as _cusparse # pytype: disable=import-error
except ImportError:
_cusparse = None
else:
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
try:
_cusparse = importlib.import_module(
f"{cuda_module_name}._sparse", package="jaxlib"
)
except ImportError:
_cusparse = None
else:
break
if _cusparse:
for _name, _value in _cusparse.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")

File: jaxlib/gpu_triton.py

@@ -11,11 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from jaxlib import xla_client
try:
from .cuda import _triton as _cuda_triton # pytype: disable=import-error
for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
try:
_cuda_triton = importlib.import_module(
f"{cuda_module_name}._triton", package="jaxlib"
)
except ImportError:
_cuda_triton = None
else:
break
if _cuda_triton:
xla_client.register_custom_call_target(
"triton_kernel_call", _cuda_triton.get_custom_call(),
platform='CUDA')
@@ -27,8 +37,6 @@ try:
get_compute_capability = _cuda_triton.get_compute_capability
get_custom_call = _cuda_triton.get_custom_call
get_serialized_metadata = _cuda_triton.get_serialized_metadata
except ImportError:
_cuda_triton = None
try:
from .rocm import _triton as _hip_triton # pytype: disable=import-error

File: jaxlib/tools/BUILD.bazel

@@ -35,6 +35,7 @@ py_binary(
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + if_cuda([
"//jaxlib/cuda:cuda_gpu_support",
# TODO(jieying): move it out from jaxlib
"//jaxlib:cuda_plugin_extension",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
@@ -53,6 +54,7 @@ py_binary(
"LICENSE.txt",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so",
] + if_cuda([
"//jaxlib:version",
"//jaxlib/cuda:cuda_gpu_support",
"//plugins/cuda:pyproject.toml",
"//plugins/cuda:setup.py",
@@ -64,3 +66,21 @@ py_binary(
"@bazel_tools//tools/python/runfiles"
],
)
py_binary(
    name = "build_cuda_kernels_wheel",
    srcs = ["build_cuda_kernels_wheel.py"],
    data = [
        "LICENSE.txt",
    ] + if_cuda([
        "//jaxlib:version",
        "//jaxlib/cuda:cuda_gpu_support",
        "//plugins/cuda:plugin_pyproject.toml",
        "//plugins/cuda:plugin_setup.py",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]),
    deps = [
        "//jax/tools:build_utils",
        "@bazel_tools//tools/python/runfiles"
    ],
)

File: jaxlib/tools/build_cuda_kernels_wheel.py (new file)

@@ -0,0 +1,120 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Script that builds a jax-cuda12-plugin wheel for cuda kernels, intended to be
# run via bazel run as part of the jax cuda plugin build process.
# Most users should not run this script directly; use build.py instead.
import argparse
import functools
import os
import pathlib
import tempfile

from bazel_tools.tools.python.runfiles import runfiles
from jax.tools import build_utils

parser = argparse.ArgumentParser()
parser.add_argument(
    "--output_path",
    default=None,
    required=True,
    help="Path to which the output wheel should be written. Required.",
)
parser.add_argument(
    "--cpu", default=None, required=True, help="Target CPU architecture. Required."
)
parser.add_argument(
    "--cuda_version",
    default=None,
    required=True,
    help="Target CUDA version. Required.",
)
parser.add_argument(
    "--editable",
    action="store_true",
    help="Create an 'editable' jax cuda plugin build instead of a wheel.",
)
args = parser.parse_args()

r = runfiles.Create()
pyext = "pyd" if build_utils.is_windows() else "so"


def write_setup_cfg(sources_path, cpu):
  tag = build_utils.platform_tag(cpu)
  with open(sources_path / "setup.cfg", "w") as f:
    f.write(f"""[metadata]
license_files = LICENSE.txt

[bdist_wheel]
plat-name={tag}
""")


def prepare_wheel(
    sources_path: pathlib.Path, *, cpu, cuda_version
):
  """Assembles a source tree for the cuda kernel wheel in `sources_path`."""
  copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)

  copy_runfiles(
      "__main__/plugins/cuda/plugin_pyproject.toml",
      dst_dir=sources_path,
      dst_filename="pyproject.toml",
  )
  copy_runfiles(
      "__main__/plugins/cuda/plugin_setup.py",
      dst_dir=sources_path,
      dst_filename="setup.py",
  )
  build_utils.update_setup_with_cuda_version(sources_path, cuda_version)
  write_setup_cfg(sources_path, cpu)

  plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin"
  copy_runfiles(
      dst_dir=plugin_dir / "nvvm" / "libdevice",
      src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
  )
  copy_runfiles(
      dst_dir=plugin_dir,
      src_files=[
          f"__main__/jaxlib/cuda/_solver.{pyext}",
          f"__main__/jaxlib/cuda/_blas.{pyext}",
          f"__main__/jaxlib/cuda/_linalg.{pyext}",
          f"__main__/jaxlib/cuda/_prng.{pyext}",
          f"__main__/jaxlib/cuda/_rnn.{pyext}",
          f"__main__/jaxlib/cuda/_sparse.{pyext}",
          f"__main__/jaxlib/cuda/_triton.{pyext}",
          f"__main__/jaxlib/cuda/_versions.{pyext}",
          "__main__/jaxlib/version.py",
      ],
  )


# Build wheel for cuda kernels
tmpdir = tempfile.TemporaryDirectory(prefix="jax_cuda_plugin")
sources_path = tmpdir.name
try:
  os.makedirs(args.output_path, exist_ok=True)
  prepare_wheel(
      pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.cuda_version
  )
  package_name = f"jax cuda{args.cuda_version} plugin"
  if args.editable:
    build_utils.build_editable(sources_path, args.output_path, package_name)
  else:
    build_utils.build_wheel(sources_path, args.output_path, package_name)
finally:
  tmpdir.cleanup()

File: jaxlib/tools/build_gpu_plugin_wheel.py

@@ -72,22 +72,11 @@ python-tag=py3
)
def update_setup(file_dir, cuda_version):
src_file = file_dir / "setup.py"
with open(src_file, "r") as f:
content = f.read()
content = content.replace(
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
)
with open(src_file, "w") as f:
f.write(content)
def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version):
"""Assembles a source tree for the wheel in `sources_path`."""
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
plugin_dir = sources_path / "jax_plugins" / f"xla_cuda_cu{cuda_version}"
plugin_dir = sources_path / "jax_plugins" / f"xla_cuda{cuda_version}"
copy_runfiles(
dst_dir=sources_path,
src_files=[
@@ -95,12 +84,13 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version):
"__main__/plugins/cuda/setup.py",
],
)
update_setup(sources_path, cuda_version)
build_utils.update_setup_with_cuda_version(sources_path, cuda_version)
write_setup_cfg(sources_path, cpu)
copy_runfiles(
dst_dir=plugin_dir,
src_files=[
"__main__/plugins/cuda/__init__.py",
"__main__/jaxlib/version.py",
],
)
copy_runfiles(
@@ -113,7 +103,7 @@ def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version):
tmpdir = None
sources_path = args.sources_path
if sources_path is None:
tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudaplugin")
tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudapjrt")
sources_path = tmpdir.name
try:

File: jaxlib/tools/build_wheel.py

@@ -216,7 +216,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
],
)
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not include_gpu_plugin_extension:
copy_runfiles(
dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],

File: plugins/cuda/BUILD.bazel

@@ -21,6 +21,8 @@ package(
exports_files([
"__init__.py",
"plugin_pyproject.toml",
"plugin_setup.py",
"pyproject.toml",
"setup.py",
])

File: plugins/cuda/plugin_pyproject.toml (new file)

@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

File: plugins/cuda/plugin_setup.py (new file)

@@ -0,0 +1,75 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from setuptools import setup
from setuptools.dist import Distribution

__version__ = None
cuda_version = 0 # placeholder
project_name = f"jax-cuda{cuda_version}-plugin"
package_name = f"jax_cuda{cuda_version}_plugin"


def load_version_module(pkg_path):
  spec = importlib.util.spec_from_file_location(
      'version', os.path.join(pkg_path, 'version.py'))
  module = importlib.util.module_from_spec(spec)
  spec.loader.exec_module(module)
  return module


_version_module = load_version_module(package_name)
__version__ = _version_module._get_version_for_build()
_cmdclass = _version_module._get_cmdclass(package_name)

cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
if cudnn_version:
  __version__ += f"+cudnn{cudnn_version.replace('.', '')}"


class BinaryDistribution(Distribution):
  """This class makes 'bdist_wheel' include an ABI tag on the wheel."""

  def has_ext_modules(self):
    return True


setup(
    name=project_name,
    version=__version__,
    cmdclass=_cmdclass,
    description="JAX Plugin for NVIDIA GPUs",
    long_description="",
    long_description_content_type="text/markdown",
    author="JAX team",
    author_email="jax-dev@google.com",
    packages=[package_name],
    python_requires=">=3.9",
    install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"],
    url="https://github.com/google/jax",
    license="Apache-2.0",
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
    ],
    package_data={
        package_name: [
            "*",
            "nvvm/libdevice/libdevice*",
        ],
    },
    zip_safe=False,
    distclass=BinaryDistribution,
)

File: plugins/cuda/setup.py

@@ -12,12 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from setuptools import setup, find_namespace_packages
__version__ = "0.0"
__version__ = None
cuda_version = 0 # placeholder
project_name = f"jax-cuda-plugin-cu{cuda_version}"
package_name = f"jax_plugins.xla_cuda_cu{cuda_version}"
project_name = f"jax-cuda{cuda_version}-pjrt"
package_name = f"jax_plugins.xla_cuda{cuda_version}"
def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
'version', os.path.join(pkg_path, 'version.py'))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
_version_module = load_version_module(f"jax_plugins/xla_cuda{cuda_version}")
__version__ = _version_module._get_version_for_build()
packages = find_namespace_packages(
include=[
@@ -48,7 +60,7 @@ setup(
zip_safe=False,
entry_points={
"jax_plugins": [
f"xla_cuda_cu{cuda_version} = {package_name}",
f"xla_cuda{cuda_version} = {package_name}",
],
},
)
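
Not part of this commit, but a quick way to check the renamed entry point after installing the pjrt wheel is to list the `jax_plugins` entry-point group; a minimal standard-library sketch (the printed name is an assumption based on the `setup.py` change above):

```python
import importlib.metadata

eps = importlib.metadata.entry_points()
# entry_points() grew select() in Python 3.10; Python 3.9 returns a dict.
group = eps.select(group="jax_plugins") if hasattr(eps, "select") else eps.get("jax_plugins", [])
for ep in group:
  print(ep.name, "->", ep.value)  # expected: xla_cuda12 -> jax_plugins.xla_cuda12
```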