mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[PJRT C API] Set up jax xla cuda package.
Add a build wheel, pyproject.toml and setup.py. The directory structure in jax repo is: jax/ └── plugins/ └── cuda/ ├── __init__.py ├── pyproject.toml └── setup.py Installed package structure is: jax_plugins/ └── xla_cuda_cu12/ ├── __init__.py └── xla_cuda_plugin.so The major cuda version will be part of the package name. The plugin wheel can be built with command: python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla" PiperOrigin-RevId: 565187954
This commit is contained in:
parent
a38a152737
commit
91fbf9da26
@ -225,7 +225,7 @@ def write_bazelrc(*, python_bin_path, remote_build,
|
||||
cpu, cuda_compute_capabilities,
|
||||
rocm_amdgpu_targets, bazel_options, target_cpu_features,
|
||||
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
|
||||
enable_tpu, enable_rocm):
|
||||
enable_tpu, enable_rocm, build_gpu_plugin):
|
||||
tf_cuda_paths = []
|
||||
|
||||
with open("../.jax_configure.bazelrc", "w") as f:
|
||||
@ -292,6 +292,10 @@ def write_bazelrc(*, python_bin_path, remote_build,
|
||||
f.write("build --config=rocm\n")
|
||||
if not enable_nccl:
|
||||
f.write("build --config=nonccl\n")
|
||||
if build_gpu_plugin:
|
||||
f.write(textwrap.dedent("""\
|
||||
build --noincompatible_remove_legacy_whole_archive
|
||||
"""))
|
||||
|
||||
BANNER = r"""
|
||||
_ _ __ __
|
||||
@ -369,6 +373,20 @@ def main():
|
||||
parser,
|
||||
"enable_cuda",
|
||||
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"build_gpu_plugin",
|
||||
default=False,
|
||||
help_str=(
|
||||
"Are we building the gpu plugin in addition to jaxlib? The GPU "
|
||||
"plugin is still experimental and is not ready for use yet."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu_plugin_cuda_version",
|
||||
choices=["11", "12"],
|
||||
default="12",
|
||||
help="Which CUDA major version the gpu plugin is for.")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"enable_tpu",
|
||||
@ -535,6 +553,7 @@ def main():
|
||||
enable_nccl=args.enable_nccl,
|
||||
enable_tpu=args.enable_tpu,
|
||||
enable_rocm=args.enable_rocm,
|
||||
build_gpu_plugin=args.build_gpu_plugin,
|
||||
)
|
||||
|
||||
if args.configure_only:
|
||||
@ -542,7 +561,6 @@ def main():
|
||||
|
||||
print("\nBuilding XLA and installing it in the jaxlib source tree...")
|
||||
|
||||
|
||||
command = ([bazel_path] + args.bazel_startup_options +
|
||||
["run", "--verbose_failures=true"] +
|
||||
["//jaxlib/tools:build_wheel", "--",
|
||||
@ -552,6 +570,20 @@ def main():
|
||||
command += ["--editable"]
|
||||
print(" ".join(command))
|
||||
shell(command)
|
||||
|
||||
if args.build_gpu_plugin:
|
||||
build_plugin_command = (
|
||||
" ".join(command)
|
||||
.replace(
|
||||
"//jaxlib/tools:build_wheel",
|
||||
"//jaxlib/tools:build_gpu_plugin_wheel",
|
||||
)
|
||||
.split(" ")
|
||||
)
|
||||
build_plugin_command += [f"--cuda_version={args.gpu_plugin_cuda_version}"]
|
||||
print(" ".join(build_plugin_command))
|
||||
shell(build_plugin_command)
|
||||
|
||||
shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
|
||||
|
||||
|
||||
|
@ -12,11 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_deps",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
@ -45,3 +46,8 @@ py_library(
|
||||
"//jax/experimental/jax2tf",
|
||||
] + py_deps("tensorflow_core"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "build_utils",
|
||||
srcs = ["build_utils.py"],
|
||||
)
|
||||
|
78
jax/tools/build_utils.py
Normal file
78
jax/tools/build_utils.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for the building JAX related python packages."""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import subprocess
|
||||
import glob
|
||||
|
||||
|
||||
def is_windows() -> bool:
|
||||
return sys.platform.startswith("win32")
|
||||
|
||||
|
||||
def copy_file(
|
||||
src_file: str,
|
||||
dst_dir: str,
|
||||
dst_filename=None,
|
||||
from_runfiles=True,
|
||||
runfiles=None,
|
||||
) -> None:
|
||||
if from_runfiles:
|
||||
src_file = runfiles.Rlocation(src_file)
|
||||
src_filename = os.path.basename(src_file)
|
||||
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
|
||||
if is_windows():
|
||||
shutil.copyfile(src_file, dst_file)
|
||||
else:
|
||||
shutil.copy(src_file, dst_file)
|
||||
|
||||
|
||||
def platform_tag(cpu: str) -> str:
|
||||
platform_name, cpu_name = {
|
||||
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
|
||||
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
|
||||
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
|
||||
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
|
||||
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
|
||||
("Windows", "AMD64"): ("win", "amd64"),
|
||||
}[(platform.system(), cpu)]
|
||||
return f"{platform_name}_{cpu_name}"
|
||||
|
||||
|
||||
def build_wheel(sources_path: str, output_path: str, package_name: str) -> None:
|
||||
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
|
||||
subprocess.run([sys.executable, "-m", "build", "-n", "-w"],
|
||||
check=True, cwd=sources_path)
|
||||
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
|
||||
output_file = os.path.join(output_path, os.path.basename(wheel))
|
||||
sys.stderr.write(f"Output wheel: {output_file}\n\n")
|
||||
sys.stderr.write(f"To install the newly-built {package_name} wheel, run:\n")
|
||||
sys.stderr.write(f" pip install {output_file} --force-reinstall\n\n")
|
||||
shutil.copy(wheel, output_path)
|
||||
|
||||
|
||||
def build_editable(
|
||||
sources_path: str, output_path: str, package_name: str
|
||||
) -> None:
|
||||
sys.stderr.write(
|
||||
f"To install the editable {package_name} build, run:\n\n"
|
||||
f" pip install -e {output_path}\n\n"
|
||||
)
|
||||
shutil.rmtree(output_path, ignore_errors=True)
|
||||
shutil.copytree(sources_path, output_path)
|
@ -39,5 +39,27 @@ py_binary(
|
||||
]) + if_rocm([
|
||||
"//jaxlib/rocm:rocm_gpu_support",
|
||||
]),
|
||||
deps = ["@bazel_tools//tools/python/runfiles"],
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles"
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "build_gpu_plugin_wheel",
|
||||
srcs = ["build_gpu_plugin_wheel.py"],
|
||||
data = [
|
||||
"LICENSE.txt",
|
||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so",
|
||||
] + if_cuda([
|
||||
"//jaxlib/cuda:cuda_gpu_support",
|
||||
"//plugins/cuda:pyproject.toml",
|
||||
"//plugins/cuda:setup.py",
|
||||
"//plugins/cuda:__init__.py",
|
||||
"@local_config_cuda//cuda:cuda-nvvm",
|
||||
]),
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles"
|
||||
],
|
||||
)
|
||||
|
127
jaxlib/tools/build_gpu_plugin_wheel.py
Normal file
127
jaxlib/tools/build_gpu_plugin_wheel.py
Normal file
@ -0,0 +1,127 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Script that builds a jax cuda plugin wheel, intended to be run via bazel run
|
||||
# as part of the jax cuda plugin build process.
|
||||
|
||||
# Most users should not run this script directly; use build.py instead.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
from jax.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--sources_path",
|
||||
default=None,
|
||||
help="Path in which the wheel's sources should be prepared. Optional. If "
|
||||
"omitted, a temporary directory will be used.")
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.")
|
||||
parser.add_argument(
|
||||
"--cpu",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Target CPU architecture. Required.")
|
||||
parser.add_argument(
|
||||
"--cuda_version",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Target CUDA version. Required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--editable",
|
||||
action="store_true",
|
||||
help="Create an 'editable' jax cuda plugin build instead of a wheel.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
r = runfiles.Create()
|
||||
|
||||
|
||||
def write_setup_cfg(sources_path, cpu):
|
||||
tag = build_utils.platform_tag(cpu)
|
||||
with open(os.path.join(sources_path, "setup.cfg"), "w") as f:
|
||||
f.write(f"""[metadata]
|
||||
license_files = LICENSE.txt
|
||||
|
||||
[bdist_wheel]
|
||||
plat-name={tag}
|
||||
python-tag=py3
|
||||
""")
|
||||
|
||||
|
||||
def update_setup(file_dir, cuda_version):
|
||||
src_file = os.path.join(file_dir, "setup.py")
|
||||
with open(os.path.join(src_file), "r") as f:
|
||||
content = f.read()
|
||||
content = content.replace(
|
||||
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
|
||||
)
|
||||
with open(os.path.join(src_file), "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def prepare_cuda_plugin_wheel(sources_path, *, cpu, cuda_version):
|
||||
"""Assembles a source tree for the wheel in `sources_path`."""
|
||||
jax_plugins_dir = os.path.join(sources_path, "jax_plugins")
|
||||
os.makedirs(jax_plugins_dir)
|
||||
plugin_dir = os.path.join(jax_plugins_dir, f"xla_cuda_cu{cuda_version}")
|
||||
os.makedirs(plugin_dir)
|
||||
|
||||
build_utils.copy_file(
|
||||
"__main__/plugins/cuda/pyproject.toml", dst_dir=sources_path, runfiles=r
|
||||
)
|
||||
build_utils.copy_file(
|
||||
"__main__/plugins/cuda/setup.py", dst_dir=sources_path, runfiles=r
|
||||
)
|
||||
update_setup(sources_path, cuda_version)
|
||||
write_setup_cfg(sources_path, cpu)
|
||||
build_utils.copy_file(
|
||||
"__main__/plugins/cuda/__init__.py", dst_dir=plugin_dir, runfiles=r
|
||||
)
|
||||
plugin_so_path = r.Rlocation("xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so")
|
||||
build_utils.copy_file(
|
||||
plugin_so_path,
|
||||
dst_dir=plugin_dir,
|
||||
dst_filename="xla_cuda_plugin.so",
|
||||
runfiles=r,
|
||||
)
|
||||
|
||||
|
||||
tmpdir = None
|
||||
sources_path = args.sources_path
|
||||
if sources_path is None:
|
||||
tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudaplugin")
|
||||
sources_path = tmpdir.name
|
||||
|
||||
try:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
prepare_cuda_plugin_wheel(
|
||||
sources_path, cpu=args.cpu, cuda_version=args.cuda_version
|
||||
)
|
||||
package_name = "jax cuda plugin"
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
build_utils.build_wheel(sources_path, args.output_path, package_name)
|
||||
finally:
|
||||
if tmpdir:
|
||||
tmpdir.cleanup()
|
@ -19,16 +19,14 @@
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import glob
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
from jax.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@ -59,29 +57,13 @@ def _is_mac():
|
||||
return platform.system() == "Darwin"
|
||||
|
||||
|
||||
def _is_windows():
|
||||
return sys.platform.startswith("win32")
|
||||
|
||||
|
||||
pyext = "pyd" if _is_windows() else "so"
|
||||
pyext = "pyd" if build_utils.is_windows() else "so"
|
||||
|
||||
|
||||
def exists(src_file):
|
||||
return r.Rlocation(src_file) is not None
|
||||
|
||||
|
||||
def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True):
|
||||
if from_runfiles:
|
||||
src_file = r.Rlocation(src_file)
|
||||
|
||||
src_filename = os.path.basename(src_file)
|
||||
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
|
||||
if _is_windows():
|
||||
shutil.copyfile(src_file, dst_file)
|
||||
else:
|
||||
shutil.copy(src_file, dst_file)
|
||||
|
||||
|
||||
def patch_copy_mlir_import(src_file, dst_dir):
|
||||
src_file = r.Rlocation(src_file)
|
||||
src_filename = os.path.basename(src_file)
|
||||
@ -154,20 +136,8 @@ def verify_mac_libraries_dont_reference_chkstack():
|
||||
"means that it isn't compatible with older MacOS versions.")
|
||||
|
||||
|
||||
def platform_tag(cpu):
|
||||
platform_name, cpu_name = {
|
||||
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
|
||||
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
|
||||
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
|
||||
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
|
||||
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
|
||||
("Windows", "AMD64"): ("win", "amd64"),
|
||||
}[(platform.system(), cpu)]
|
||||
return f"{platform_name}_{cpu_name}"
|
||||
|
||||
|
||||
def write_setup_cfg(sources_path, cpu):
|
||||
tag = platform_tag(cpu)
|
||||
tag = build_utils.platform_tag(cpu)
|
||||
with open(os.path.join(sources_path, "setup.cfg"), "w") as f:
|
||||
f.write(f"""[metadata]
|
||||
license_files = LICENSE.txt
|
||||
@ -181,12 +151,12 @@ def prepare_wheel(sources_path, *, cpu):
|
||||
"""Assembles a source tree for the wheel in `sources_path`."""
|
||||
jaxlib_dir = os.path.join(sources_path, "jaxlib")
|
||||
os.makedirs(jaxlib_dir)
|
||||
copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)
|
||||
copy_to_jaxlib = functools.partial(build_utils.copy_file, dst_dir=jaxlib_dir, runfiles=r)
|
||||
|
||||
verify_mac_libraries_dont_reference_chkstack()
|
||||
copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path)
|
||||
copy_file("__main__/jaxlib/README.md", dst_dir=sources_path)
|
||||
copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path)
|
||||
build_utils.copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/README.md", dst_dir=sources_path, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path, runfiles=r)
|
||||
write_setup_cfg(sources_path, cpu)
|
||||
copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
|
||||
@ -207,31 +177,31 @@ def prepare_wheel(sources_path, *, cpu):
|
||||
copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}")
|
||||
cpu_dir = os.path.join(jaxlib_dir, "cpu")
|
||||
os.makedirs(cpu_dir)
|
||||
copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir)
|
||||
copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir, runfiles=r)
|
||||
|
||||
cuda_dir = os.path.join(jaxlib_dir, "cuda")
|
||||
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
|
||||
libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
|
||||
os.makedirs(libdevice_dir)
|
||||
copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_rnn.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_triton.{pyext}", dst_dir=cuda_dir)
|
||||
build_utils.copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_rnn.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_triton.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
rocm_dir = os.path.join(jaxlib_dir, "rocm")
|
||||
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
|
||||
os.makedirs(rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"):
|
||||
copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir)
|
||||
build_utils.copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir, runfiles=r)
|
||||
if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"):
|
||||
copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir)
|
||||
build_utils.copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir, runfiles=r)
|
||||
|
||||
mosaic_dir = os.path.join(jaxlib_dir, "mosaic")
|
||||
mosaic_python_dir = os.path.join(mosaic_dir, "python")
|
||||
@ -240,7 +210,7 @@ def prepare_wheel(sources_path, *, cpu):
|
||||
copy_to_jaxlib("__main__/jaxlib/mosaic/python/apply_vector_layout.py", dst_dir=mosaic_python_dir)
|
||||
copy_to_jaxlib("__main__/jaxlib/mosaic/python/infer_memref_layout.py", dst_dir=mosaic_python_dir)
|
||||
copy_to_jaxlib("__main__/jaxlib/mosaic/python/tpu.py", dst_dir=mosaic_python_dir)
|
||||
copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir)
|
||||
build_utils.copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir, runfiles=r)
|
||||
# TODO (sharadmv,skyewm): can we avoid patching this file?
|
||||
patch_copy_mlir_import("__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir)
|
||||
|
||||
@ -250,82 +220,61 @@ def prepare_wheel(sources_path, *, cpu):
|
||||
os.makedirs(mlir_dir)
|
||||
os.makedirs(mlir_dialects_dir)
|
||||
os.makedirs(mlir_libs_dir)
|
||||
copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir)
|
||||
copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir)
|
||||
copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir, runfiles=r)
|
||||
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir)
|
||||
copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir)
|
||||
if _is_windows():
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
if build_utils.is_windows():
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
elif _is_mac():
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
else:
|
||||
copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
|
||||
build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir, runfiles=r)
|
||||
patch_copy_xla_extension_stubs(jaxlib_dir)
|
||||
|
||||
|
||||
def build_wheel(sources_path, output_path):
|
||||
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
|
||||
subprocess.run([sys.executable, "-m", "build", "-n", "-w"],
|
||||
check=True, cwd=sources_path)
|
||||
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
|
||||
output_file = os.path.join(output_path, os.path.basename(wheel))
|
||||
sys.stderr.write(f"Output wheel: {output_file}\n\n")
|
||||
sys.stderr.write("To install the newly-built jaxlib wheel, run:\n")
|
||||
sys.stderr.write(f" pip install {output_file} --force-reinstall\n\n")
|
||||
shutil.copy(wheel, output_path)
|
||||
|
||||
|
||||
def build_editable(sources_path, output_path):
|
||||
sys.stderr.write(
|
||||
"To install the editable jaxlib build, run:\n\n"
|
||||
f" pip install -e {output_path}\n\n"
|
||||
)
|
||||
shutil.rmtree(output_path, ignore_errors=True)
|
||||
shutil.copytree(sources_path, output_path)
|
||||
|
||||
|
||||
tmpdir = None
|
||||
sources_path = args.sources_path
|
||||
if sources_path is None:
|
||||
@ -335,10 +284,11 @@ if sources_path is None:
|
||||
try:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
prepare_wheel(sources_path, cpu=args.cpu)
|
||||
package_name = "jaxlib"
|
||||
if args.editable:
|
||||
build_editable(sources_path, args.output_path)
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
build_wheel(sources_path, args.output_path)
|
||||
build_utils.build_wheel(sources_path, args.output_path, package_name)
|
||||
finally:
|
||||
if tmpdir:
|
||||
tmpdir.cleanup()
|
||||
|
26
plugins/cuda/BUILD.bazel
Normal file
26
plugins/cuda/BUILD.bazel
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"__init__.py",
|
||||
"pyproject.toml",
|
||||
"setup.py",
|
||||
])
|
36
plugins/cuda/__init__.py
Normal file
36
plugins/cuda/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import jax._src.xla_bridge as xb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize():
|
||||
path = pathlib.Path(__file__).resolve().parent / "xla_cuda_plugin.so"
|
||||
if not path.exists():
|
||||
logger.warning(
|
||||
"WARNING: Native library %s does not exist. This most likely indicates"
|
||||
" an issue with how %s was built or installed.",
|
||||
path,
|
||||
__package__,
|
||||
)
|
||||
|
||||
xb.register_plugin("cuda", priority=500, library_path=str(path))
|
3
plugins/cuda/pyproject.toml
Normal file
3
plugins/cuda/pyproject.toml
Normal file
@ -0,0 +1,3 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
54
plugins/cuda/setup.py
Normal file
54
plugins/cuda/setup.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from setuptools import setup, find_namespace_packages
|
||||
|
||||
__version__ = "0.0"
|
||||
cuda_version = 0 # placeholder
|
||||
project_name = f"jax-cuda-plugin-cu{cuda_version}"
|
||||
package_name = f"jax_plugins.xla_cuda_cu{cuda_version}"
|
||||
|
||||
packages = find_namespace_packages(
|
||||
include=[
|
||||
package_name,
|
||||
f"{package_name}.*",
|
||||
]
|
||||
)
|
||||
|
||||
setup(
|
||||
name=project_name,
|
||||
version=__version__,
|
||||
description="JAX XLA PJRT Plugin for NVIDIA GPUs",
|
||||
long_description="",
|
||||
long_description_content_type="text/markdown",
|
||||
author="JAX team",
|
||||
author_email="jax-dev@google.com",
|
||||
packages=packages,
|
||||
install_requires=[],
|
||||
url="https://github.com/google/jax",
|
||||
license="Apache-2.0",
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Programming Language :: Python :: 3",
|
||||
],
|
||||
package_data={
|
||||
package_name: ["xla_cuda_plugin.so"],
|
||||
},
|
||||
zip_safe=False,
|
||||
entry_points={
|
||||
"jax_plugins": [
|
||||
f"xla_cuda_cu{cuda_version} = {package_name}",
|
||||
],
|
||||
},
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user