[PJRT C API] Set up jax xla cuda package.

Add a build wheel, pyproject.toml and setup.py.

The directory structure in jax repo is:
jax/
└── plugins/
     └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py

Installed package structure is:
jax_plugins/
     └── xla_cuda_cu12/
           ├── __init__.py
           └── xla_cuda_plugin.so

The major cuda version will be part of the package name.

The plugin wheel can be built with command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
This commit is contained in:
Jieying Luo 2023-09-13 16:03:11 -07:00 committed by jax authors
parent a38a152737
commit 91fbf9da26
10 changed files with 463 additions and 129 deletions

View File

@ -225,7 +225,7 @@ def write_bazelrc(*, python_bin_path, remote_build,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, bazel_options, target_cpu_features,
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
enable_tpu, enable_rocm):
enable_tpu, enable_rocm, build_gpu_plugin):
tf_cuda_paths = []
with open("../.jax_configure.bazelrc", "w") as f:
@ -292,6 +292,10 @@ def write_bazelrc(*, python_bin_path, remote_build,
f.write("build --config=rocm\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
if build_gpu_plugin:
f.write(textwrap.dedent("""\
build --noincompatible_remove_legacy_whole_archive
"""))
BANNER = r"""
_ _ __ __
@ -369,6 +373,20 @@ def main():
parser,
"enable_cuda",
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
add_boolean_argument(
parser,
"build_gpu_plugin",
default=False,
help_str=(
"Are we building the gpu plugin in addition to jaxlib? The GPU "
"plugin is still experimental and is not ready for use yet."
),
)
parser.add_argument(
"--gpu_plugin_cuda_version",
choices=["11", "12"],
default="12",
help="Which CUDA major version the gpu plugin is for.")
add_boolean_argument(
parser,
"enable_tpu",
@ -535,6 +553,7 @@ def main():
enable_nccl=args.enable_nccl,
enable_tpu=args.enable_tpu,
enable_rocm=args.enable_rocm,
build_gpu_plugin=args.build_gpu_plugin,
)
if args.configure_only:
@ -542,7 +561,6 @@ def main():
print("\nBuilding XLA and installing it in the jaxlib source tree...")
command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_wheel", "--",
@ -552,6 +570,20 @@ def main():
command += ["--editable"]
print(" ".join(command))
shell(command)
if args.build_gpu_plugin:
build_plugin_command = (
" ".join(command)
.replace(
"//jaxlib/tools:build_wheel",
"//jaxlib/tools:build_gpu_plugin_wheel",
)
.split(" ")
)
build_plugin_command += [f"--cuda_version={args.gpu_plugin_cuda_version}"]
print(" ".join(build_plugin_command))
shell(build_plugin_command)
shell([bazel_path] + args.bazel_startup_options + ["shutdown"])

View File

@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"py_deps",
"pytype_strict_library",
)
load("@rules_python//python:defs.bzl", "py_library")
licenses(["notice"])
@ -45,3 +46,8 @@ py_library(
"//jax/experimental/jax2tf",
] + py_deps("tensorflow_core"),
)
pytype_strict_library(
name = "build_utils",
srcs = ["build_utils.py"],
)

78
jax/tools/build_utils.py Normal file
View File

@ -0,0 +1,78 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for the building JAX related python packages."""
import os
import platform
import shutil
import sys
import subprocess
import glob
def is_windows() -> bool:
return sys.platform.startswith("win32")
def copy_file(
src_file: str,
dst_dir: str,
dst_filename=None,
from_runfiles=True,
runfiles=None,
) -> None:
if from_runfiles:
src_file = runfiles.Rlocation(src_file)
src_filename = os.path.basename(src_file)
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
if is_windows():
shutil.copyfile(src_file, dst_file)
else:
shutil.copy(src_file, dst_file)
def platform_tag(cpu: str) -> str:
platform_name, cpu_name = {
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
("Windows", "AMD64"): ("win", "amd64"),
}[(platform.system(), cpu)]
return f"{platform_name}_{cpu_name}"
def build_wheel(sources_path: str, output_path: str, package_name: str) -> None:
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
subprocess.run([sys.executable, "-m", "build", "-n", "-w"],
check=True, cwd=sources_path)
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
output_file = os.path.join(output_path, os.path.basename(wheel))
sys.stderr.write(f"Output wheel: {output_file}\n\n")
sys.stderr.write(f"To install the newly-built {package_name} wheel, run:\n")
sys.stderr.write(f" pip install {output_file} --force-reinstall\n\n")
shutil.copy(wheel, output_path)
def build_editable(
sources_path: str, output_path: str, package_name: str
) -> None:
sys.stderr.write(
f"To install the editable {package_name} build, run:\n\n"
f" pip install -e {output_path}\n\n"
)
shutil.rmtree(output_path, ignore_errors=True)
shutil.copytree(sources_path, output_path)

View File

@ -39,5 +39,27 @@ py_binary(
]) + if_rocm([
"//jaxlib/rocm:rocm_gpu_support",
]),
deps = ["@bazel_tools//tools/python/runfiles"],
deps = [
"//jax/tools:build_utils",
"@bazel_tools//tools/python/runfiles"
],
)
py_binary(
name = "build_gpu_plugin_wheel",
srcs = ["build_gpu_plugin_wheel.py"],
data = [
"LICENSE.txt",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so",
] + if_cuda([
"//jaxlib/cuda:cuda_gpu_support",
"//plugins/cuda:pyproject.toml",
"//plugins/cuda:setup.py",
"//plugins/cuda:__init__.py",
"@local_config_cuda//cuda:cuda-nvvm",
]),
deps = [
"//jax/tools:build_utils",
"@bazel_tools//tools/python/runfiles"
],
)

View File

@ -0,0 +1,127 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Script that builds a jax cuda plugin wheel, intended to be run via bazel run
# as part of the jax cuda plugin build process.
# Most users should not run this script directly; use build.py instead.
import argparse
import os
import tempfile
from bazel_tools.tools.python.runfiles import runfiles
from jax.tools import build_utils
parser = argparse.ArgumentParser()
parser.add_argument(
"--sources_path",
default=None,
help="Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used.")
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.")
parser.add_argument(
"--cpu",
default=None,
required=True,
help="Target CPU architecture. Required.")
parser.add_argument(
"--cuda_version",
default=None,
required=True,
help="Target CUDA version. Required.",
)
parser.add_argument(
"--editable",
action="store_true",
help="Create an 'editable' jax cuda plugin build instead of a wheel.",
)
args = parser.parse_args()
r = runfiles.Create()
def write_setup_cfg(sources_path, cpu):
tag = build_utils.platform_tag(cpu)
with open(os.path.join(sources_path, "setup.cfg"), "w") as f:
f.write(f"""[metadata]
license_files = LICENSE.txt
[bdist_wheel]
plat-name={tag}
python-tag=py3
""")
def update_setup(file_dir, cuda_version):
src_file = os.path.join(file_dir, "setup.py")
with open(os.path.join(src_file), "r") as f:
content = f.read()
content = content.replace(
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
)
with open(os.path.join(src_file), "w") as f:
f.write(content)
def prepare_cuda_plugin_wheel(sources_path, *, cpu, cuda_version):
"""Assembles a source tree for the wheel in `sources_path`."""
jax_plugins_dir = os.path.join(sources_path, "jax_plugins")
os.makedirs(jax_plugins_dir)
plugin_dir = os.path.join(jax_plugins_dir, f"xla_cuda_cu{cuda_version}")
os.makedirs(plugin_dir)
build_utils.copy_file(
"__main__/plugins/cuda/pyproject.toml", dst_dir=sources_path, runfiles=r
)
build_utils.copy_file(
"__main__/plugins/cuda/setup.py", dst_dir=sources_path, runfiles=r
)
update_setup(sources_path, cuda_version)
write_setup_cfg(sources_path, cpu)
build_utils.copy_file(
"__main__/plugins/cuda/__init__.py", dst_dir=plugin_dir, runfiles=r
)
plugin_so_path = r.Rlocation("xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so")
build_utils.copy_file(
plugin_so_path,
dst_dir=plugin_dir,
dst_filename="xla_cuda_plugin.so",
runfiles=r,
)
tmpdir = None
sources_path = args.sources_path
if sources_path is None:
tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudaplugin")
sources_path = tmpdir.name
try:
os.makedirs(args.output_path, exist_ok=True)
prepare_cuda_plugin_wheel(
sources_path, cpu=args.cpu, cuda_version=args.cuda_version
)
package_name = "jax cuda plugin"
if args.editable:
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
build_utils.build_wheel(sources_path, args.output_path, package_name)
finally:
if tmpdir:
tmpdir.cleanup()

View File

@ -19,16 +19,14 @@
import argparse
import functools
import glob
import os
import platform
import re
import shutil
import subprocess
import sys
import tempfile
from bazel_tools.tools.python.runfiles import runfiles
from jax.tools import build_utils
parser = argparse.ArgumentParser()
parser.add_argument(
@ -59,29 +57,13 @@ def _is_mac():
return platform.system() == "Darwin"
def _is_windows():
return sys.platform.startswith("win32")
pyext = "pyd" if _is_windows() else "so"
pyext = "pyd" if build_utils.is_windows() else "so"
def exists(src_file):
return r.Rlocation(src_file) is not None
def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True):
if from_runfiles:
src_file = r.Rlocation(src_file)
src_filename = os.path.basename(src_file)
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
if _is_windows():
shutil.copyfile(src_file, dst_file)
else:
shutil.copy(src_file, dst_file)
def patch_copy_mlir_import(src_file, dst_dir):
src_file = r.Rlocation(src_file)
src_filename = os.path.basename(src_file)
@ -154,20 +136,8 @@ def verify_mac_libraries_dont_reference_chkstack():
"means that it isn't compatible with older MacOS versions.")
def platform_tag(cpu):
platform_name, cpu_name = {
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
("Windows", "AMD64"): ("win", "amd64"),
}[(platform.system(), cpu)]
return f"{platform_name}_{cpu_name}"
def write_setup_cfg(sources_path, cpu):
tag = platform_tag(cpu)
tag = build_utils.platform_tag(cpu)
with open(os.path.join(sources_path, "setup.cfg"), "w") as f:
f.write(f"""[metadata]
license_files = LICENSE.txt
@ -181,12 +151,12 @@ def prepare_wheel(sources_path, *, cpu):
"""Assembles a source tree for the wheel in `sources_path`."""
jaxlib_dir = os.path.join(sources_path, "jaxlib")
os.makedirs(jaxlib_dir)
copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)
copy_to_jaxlib = functools.partial(build_utils.copy_file, dst_dir=jaxlib_dir, runfiles=r)
verify_mac_libraries_dont_reference_chkstack()
copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path)
copy_file("__main__/jaxlib/README.md", dst_dir=sources_path)
copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path)
build_utils.copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path, runfiles=r)
build_utils.copy_file("__main__/jaxlib/README.md", dst_dir=sources_path, runfiles=r)
build_utils.copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path, runfiles=r)
write_setup_cfg(sources_path, cpu)
copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
@ -207,31 +177,31 @@ def prepare_wheel(sources_path, *, cpu):
copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}")
cpu_dir = os.path.join(jaxlib_dir, "cpu")
os.makedirs(cpu_dir)
copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir)
copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir)
build_utils.copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir, runfiles=r)
cuda_dir = os.path.join(jaxlib_dir, "cuda")
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
os.makedirs(libdevice_dir)
copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir)
copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_rnn.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_triton.{pyext}", dst_dir=cuda_dir)
build_utils.copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_rnn.{pyext}", dst_dir=cuda_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/cuda/_triton.{pyext}", dst_dir=cuda_dir, runfiles=r)
rocm_dir = os.path.join(jaxlib_dir, "rocm")
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
os.makedirs(rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir)
build_utils.copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir, runfiles=r)
if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"):
copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir)
build_utils.copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir, runfiles=r)
if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"):
copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir)
build_utils.copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir, runfiles=r)
mosaic_dir = os.path.join(jaxlib_dir, "mosaic")
mosaic_python_dir = os.path.join(mosaic_dir, "python")
@ -240,7 +210,7 @@ def prepare_wheel(sources_path, *, cpu):
copy_to_jaxlib("__main__/jaxlib/mosaic/python/apply_vector_layout.py", dst_dir=mosaic_python_dir)
copy_to_jaxlib("__main__/jaxlib/mosaic/python/infer_memref_layout.py", dst_dir=mosaic_python_dir)
copy_to_jaxlib("__main__/jaxlib/mosaic/python/tpu.py", dst_dir=mosaic_python_dir)
copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir)
build_utils.copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir, runfiles=r)
# TODO (sharadmv,skyewm): can we avoid patching this file?
patch_copy_mlir_import("__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir)
@ -250,82 +220,61 @@ def prepare_wheel(sources_path, *, cpu):
os.makedirs(mlir_dir)
os.makedirs(mlir_dialects_dir)
os.makedirs(mlir_libs_dir)
copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir)
copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir)
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir)
copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir)
build_utils.copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir, runfiles=r)
build_utils.copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir, runfiles=r)
copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir)
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir)
if _is_windows():
copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir)
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
if build_utils.is_windows():
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir, runfiles=r)
elif _is_mac():
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir)
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir, runfiles=r)
else:
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir, runfiles=r)
patch_copy_xla_extension_stubs(jaxlib_dir)
def build_wheel(sources_path, output_path):
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
subprocess.run([sys.executable, "-m", "build", "-n", "-w"],
check=True, cwd=sources_path)
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
output_file = os.path.join(output_path, os.path.basename(wheel))
sys.stderr.write(f"Output wheel: {output_file}\n\n")
sys.stderr.write("To install the newly-built jaxlib wheel, run:\n")
sys.stderr.write(f" pip install {output_file} --force-reinstall\n\n")
shutil.copy(wheel, output_path)
def build_editable(sources_path, output_path):
sys.stderr.write(
"To install the editable jaxlib build, run:\n\n"
f" pip install -e {output_path}\n\n"
)
shutil.rmtree(output_path, ignore_errors=True)
shutil.copytree(sources_path, output_path)
tmpdir = None
sources_path = args.sources_path
if sources_path is None:
@ -335,10 +284,11 @@ if sources_path is None:
try:
os.makedirs(args.output_path, exist_ok=True)
prepare_wheel(sources_path, cpu=args.cpu)
package_name = "jaxlib"
if args.editable:
build_editable(sources_path, args.output_path)
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
build_wheel(sources_path, args.output_path)
build_utils.build_wheel(sources_path, args.output_path, package_name)
finally:
if tmpdir:
tmpdir.cleanup()

26
plugins/cuda/BUILD.bazel Normal file
View File

@ -0,0 +1,26 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
)
exports_files([
"__init__.py",
"pyproject.toml",
"setup.py",
])

36
plugins/cuda/__init__.py Normal file
View File

@ -0,0 +1,36 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import pathlib
import platform
import sys
import jax._src.xla_bridge as xb
logger = logging.getLogger(__name__)
def initialize():
path = pathlib.Path(__file__).resolve().parent / "xla_cuda_plugin.so"
if not path.exists():
logger.warning(
"WARNING: Native library %s does not exist. This most likely indicates"
" an issue with how %s was built or installed.",
path,
__package__,
)
xb.register_plugin("cuda", priority=500, library_path=str(path))

View File

@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

54
plugins/cuda/setup.py Normal file
View File

@ -0,0 +1,54 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup, find_namespace_packages
__version__ = "0.0"
cuda_version = 0 # placeholder
project_name = f"jax-cuda-plugin-cu{cuda_version}"
package_name = f"jax_plugins.xla_cuda_cu{cuda_version}"
packages = find_namespace_packages(
include=[
package_name,
f"{package_name}.*",
]
)
setup(
name=project_name,
version=__version__,
description="JAX XLA PJRT Plugin for NVIDIA GPUs",
long_description="",
long_description_content_type="text/markdown",
author="JAX team",
author_email="jax-dev@google.com",
packages=packages,
install_requires=[],
url="https://github.com/google/jax",
license="Apache-2.0",
classifiers=[
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3",
],
package_data={
package_name: ["xla_cuda_plugin.so"],
},
zip_safe=False,
entry_points={
"jax_plugins": [
f"xla_cuda_cu{cuda_version} = {package_name}",
],
},
)