From 91fbf9da26870da1ecba826d9ddb604ae5c08c1b Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Wed, 13 Sep 2023 16:03:11 -0700 Subject: [PATCH] [PJRT C API] Set up jax xla cuda package. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a build wheel, pyproject.toml and setup.py. The directory structure in jax repo is: jax/ └── plugins/ └── cuda/ ├── __init__.py ├── pyproject.toml └── setup.py Installed package structure is: jax_plugins/ └── xla_cuda_cu12/ ├── __init__.py └── xla_cuda_plugin.so The major cuda version will be part of the package name. The plugin wheel can be built with command: python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla" PiperOrigin-RevId: 565187954 --- build/build.py | 36 ++++- jax/tools/BUILD | 8 +- jax/tools/build_utils.py | 78 ++++++++++ jaxlib/tools/BUILD.bazel | 24 ++- jaxlib/tools/build_gpu_plugin_wheel.py | 127 ++++++++++++++++ jaxlib/tools/build_wheel.py | 200 ++++++++++--------------- plugins/cuda/BUILD.bazel | 26 ++++ plugins/cuda/__init__.py | 36 +++++ plugins/cuda/pyproject.toml | 3 + plugins/cuda/setup.py | 54 +++++++ 10 files changed, 463 insertions(+), 129 deletions(-) create mode 100644 jax/tools/build_utils.py create mode 100644 jaxlib/tools/build_gpu_plugin_wheel.py create mode 100644 plugins/cuda/BUILD.bazel create mode 100644 plugins/cuda/__init__.py create mode 100644 plugins/cuda/pyproject.toml create mode 100644 plugins/cuda/setup.py diff --git a/build/build.py b/build/build.py index 8ad82787c..98d92b463 100755 --- a/build/build.py +++ b/build/build.py @@ -225,7 +225,7 @@ def write_bazelrc(*, python_bin_path, remote_build, cpu, cuda_compute_capabilities, rocm_amdgpu_targets, bazel_options, target_cpu_features, wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl, - enable_tpu, enable_rocm): + enable_tpu, enable_rocm, build_gpu_plugin): tf_cuda_paths = [] with open("../.jax_configure.bazelrc", "w") as f: @@ -292,6 +292,10 @@ def write_bazelrc(*, python_bin_path, remote_build, f.write("build --config=rocm\n") if not enable_nccl: f.write("build --config=nonccl\n") + if build_gpu_plugin: + f.write(textwrap.dedent("""\ + build --noincompatible_remove_legacy_whole_archive + """)) BANNER = r""" _ _ __ __ @@ -369,6 +373,20 @@ def main(): parser, "enable_cuda", help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.") + add_boolean_argument( + parser, + "build_gpu_plugin", + default=False, + help_str=( + "Are we building the gpu plugin in addition to jaxlib? The GPU " + "plugin is still experimental and is not ready for use yet." 
+ ), + ) + parser.add_argument( + "--gpu_plugin_cuda_version", + choices=["11", "12"], + default="12", + help="Which CUDA major version the gpu plugin is for.") add_boolean_argument( parser, "enable_tpu", @@ -535,6 +553,7 @@ def main(): enable_nccl=args.enable_nccl, enable_tpu=args.enable_tpu, enable_rocm=args.enable_rocm, + build_gpu_plugin=args.build_gpu_plugin, ) if args.configure_only: @@ -542,7 +561,6 @@ def main(): print("\nBuilding XLA and installing it in the jaxlib source tree...") - command = ([bazel_path] + args.bazel_startup_options + ["run", "--verbose_failures=true"] + ["//jaxlib/tools:build_wheel", "--", @@ -552,6 +570,20 @@ def main(): command += ["--editable"] print(" ".join(command)) shell(command) + + if args.build_gpu_plugin: + build_plugin_command = ( + " ".join(command) + .replace( + "//jaxlib/tools:build_wheel", + "//jaxlib/tools:build_gpu_plugin_wheel", + ) + .split(" ") + ) + build_plugin_command += [f"--cuda_version={args.gpu_plugin_cuda_version}"] + print(" ".join(build_plugin_command)) + shell(build_plugin_command) + shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) diff --git a/jax/tools/BUILD b/jax/tools/BUILD index 3e0a95029..f4058c0c6 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "py_deps", + "pytype_strict_library", ) +load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) @@ -45,3 +46,8 @@ py_library( "//jax/experimental/jax2tf", ] + py_deps("tensorflow_core"), ) + +pytype_strict_library( + name = "build_utils", + srcs = ["build_utils.py"], +) diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py new file mode 100644 index 000000000..1d5568fe1 --- /dev/null +++ b/jax/tools/build_utils.py @@ -0,0 +1,78 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for the building JAX related python packages.""" + +import os +import platform +import shutil +import sys +import subprocess +import glob + + +def is_windows() -> bool: + return sys.platform.startswith("win32") + + +def copy_file( + src_file: str, + dst_dir: str, + dst_filename=None, + from_runfiles=True, + runfiles=None, +) -> None: + if from_runfiles: + src_file = runfiles.Rlocation(src_file) + src_filename = os.path.basename(src_file) + dst_file = os.path.join(dst_dir, dst_filename or src_filename) + if is_windows(): + shutil.copyfile(src_file, dst_file) + else: + shutil.copy(src_file, dst_file) + + +def platform_tag(cpu: str) -> str: + platform_name, cpu_name = { + ("Linux", "x86_64"): ("manylinux2014", "x86_64"), + ("Linux", "aarch64"): ("manylinux2014", "aarch64"), + ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"), + ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"), + ("Darwin", "arm64"): ("macosx_11_0", "arm64"), + ("Windows", "AMD64"): ("win", "amd64"), + }[(platform.system(), cpu)] + return f"{platform_name}_{cpu_name}" + + +def build_wheel(sources_path: str, output_path: str, package_name: str) -> None: + """Builds a wheel in `output_path` using the source tree in `sources_path`.""" + subprocess.run([sys.executable, "-m", "build", "-n", "-w"], + check=True, cwd=sources_path) + for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")): + output_file = os.path.join(output_path, os.path.basename(wheel)) + sys.stderr.write(f"Output wheel: {output_file}\n\n") + sys.stderr.write(f"To install the newly-built {package_name} wheel, run:\n") + sys.stderr.write(f" pip install {output_file} --force-reinstall\n\n") + shutil.copy(wheel, output_path) + + +def build_editable( + sources_path: str, output_path: str, package_name: str +) -> None: + sys.stderr.write( + f"To install the editable {package_name} build, run:\n\n" + f" pip install -e {output_path}\n\n" + ) + shutil.rmtree(output_path, ignore_errors=True) + shutil.copytree(sources_path, output_path) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 7e9eaa714..fc9e17f6e 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -39,5 +39,27 @@ py_binary( ]) + if_rocm([ "//jaxlib/rocm:rocm_gpu_support", ]), - deps = ["@bazel_tools//tools/python/runfiles"], + deps = [ + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles" + ], +) + +py_binary( + name = "build_gpu_plugin_wheel", + srcs = ["build_gpu_plugin_wheel.py"], + data = [ + "LICENSE.txt", + "@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so", + ] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + "//plugins/cuda:pyproject.toml", + "//plugins/cuda:setup.py", + "//plugins/cuda:__init__.py", + "@local_config_cuda//cuda:cuda-nvvm", + ]), + deps = [ + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles" + ], ) diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py new file mode 100644 index 000000000..a6da77c0a --- /dev/null +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -0,0 +1,127 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Script that builds a jax cuda plugin wheel, intended to be run via bazel run +# as part of the jax cuda plugin build process. + +# Most users should not run this script directly; use build.py instead. + +import argparse +import os +import tempfile + +from bazel_tools.tools.python.runfiles import runfiles +from jax.tools import build_utils + +parser = argparse.ArgumentParser() +parser.add_argument( + "--sources_path", + default=None, + help="Path in which the wheel's sources should be prepared. Optional. If " + "omitted, a temporary directory will be used.") +parser.add_argument( + "--output_path", + default=None, + required=True, + help="Path to which the output wheel should be written. Required.") +parser.add_argument( + "--cpu", + default=None, + required=True, + help="Target CPU architecture. Required.") +parser.add_argument( + "--cuda_version", + default=None, + required=True, + help="Target CUDA version. Required.", +) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' jax cuda plugin build instead of a wheel.", +) +args = parser.parse_args() + +r = runfiles.Create() + + +def write_setup_cfg(sources_path, cpu): + tag = build_utils.platform_tag(cpu) + with open(os.path.join(sources_path, "setup.cfg"), "w") as f: + f.write(f"""[metadata] +license_files = LICENSE.txt + +[bdist_wheel] +plat-name={tag} +python-tag=py3 +""") + + +def update_setup(file_dir, cuda_version): + src_file = os.path.join(file_dir, "setup.py") + with open(os.path.join(src_file), "r") as f: + content = f.read() + content = content.replace( + "cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}" + ) + with open(os.path.join(src_file), "w") as f: + f.write(content) + + +def prepare_cuda_plugin_wheel(sources_path, *, cpu, cuda_version): + """Assembles a source tree for the wheel in `sources_path`.""" + jax_plugins_dir = os.path.join(sources_path, "jax_plugins") + os.makedirs(jax_plugins_dir) + plugin_dir = os.path.join(jax_plugins_dir, f"xla_cuda_cu{cuda_version}") + os.makedirs(plugin_dir) + + build_utils.copy_file( + "__main__/plugins/cuda/pyproject.toml", dst_dir=sources_path, runfiles=r + ) + build_utils.copy_file( + "__main__/plugins/cuda/setup.py", dst_dir=sources_path, runfiles=r + ) + update_setup(sources_path, cuda_version) + write_setup_cfg(sources_path, cpu) + build_utils.copy_file( + "__main__/plugins/cuda/__init__.py", dst_dir=plugin_dir, runfiles=r + ) + plugin_so_path = r.Rlocation("xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so") + build_utils.copy_file( + plugin_so_path, + dst_dir=plugin_dir, + dst_filename="xla_cuda_plugin.so", + runfiles=r, + ) + + +tmpdir = None +sources_path = args.sources_path +if sources_path is None: + tmpdir = tempfile.TemporaryDirectory(prefix="jaxcudaplugin") + sources_path = tmpdir.name + +try: + os.makedirs(args.output_path, exist_ok=True) + prepare_cuda_plugin_wheel( + sources_path, cpu=args.cpu, cuda_version=args.cuda_version + ) + package_name = "jax cuda plugin" + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel(sources_path, args.output_path, package_name) +finally: + if tmpdir: + tmpdir.cleanup() diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 7e6622190..5b12a7d7a 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -19,16 +19,14 @@ import argparse import functools -import glob import os 
import platform import re -import shutil import subprocess -import sys import tempfile from bazel_tools.tools.python.runfiles import runfiles +from jax.tools import build_utils parser = argparse.ArgumentParser() parser.add_argument( @@ -59,29 +57,13 @@ def _is_mac(): return platform.system() == "Darwin" -def _is_windows(): - return sys.platform.startswith("win32") - - -pyext = "pyd" if _is_windows() else "so" +pyext = "pyd" if build_utils.is_windows() else "so" def exists(src_file): return r.Rlocation(src_file) is not None -def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True): - if from_runfiles: - src_file = r.Rlocation(src_file) - - src_filename = os.path.basename(src_file) - dst_file = os.path.join(dst_dir, dst_filename or src_filename) - if _is_windows(): - shutil.copyfile(src_file, dst_file) - else: - shutil.copy(src_file, dst_file) - - def patch_copy_mlir_import(src_file, dst_dir): src_file = r.Rlocation(src_file) src_filename = os.path.basename(src_file) @@ -154,20 +136,8 @@ def verify_mac_libraries_dont_reference_chkstack(): "means that it isn't compatible with older MacOS versions.") -def platform_tag(cpu): - platform_name, cpu_name = { - ("Linux", "x86_64"): ("manylinux2014", "x86_64"), - ("Linux", "aarch64"): ("manylinux2014", "aarch64"), - ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"), - ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"), - ("Darwin", "arm64"): ("macosx_11_0", "arm64"), - ("Windows", "AMD64"): ("win", "amd64"), - }[(platform.system(), cpu)] - return f"{platform_name}_{cpu_name}" - - def write_setup_cfg(sources_path, cpu): - tag = platform_tag(cpu) + tag = build_utils.platform_tag(cpu) with open(os.path.join(sources_path, "setup.cfg"), "w") as f: f.write(f"""[metadata] license_files = LICENSE.txt @@ -181,12 +151,12 @@ def prepare_wheel(sources_path, *, cpu): """Assembles a source tree for the wheel in `sources_path`.""" jaxlib_dir = os.path.join(sources_path, "jaxlib") os.makedirs(jaxlib_dir) - copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir) + copy_to_jaxlib = functools.partial(build_utils.copy_file, dst_dir=jaxlib_dir, runfiles=r) verify_mac_libraries_dont_reference_chkstack() - copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path) - copy_file("__main__/jaxlib/README.md", dst_dir=sources_path) - copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path) + build_utils.copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path, runfiles=r) + build_utils.copy_file("__main__/jaxlib/README.md", dst_dir=sources_path, runfiles=r) + build_utils.copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path, runfiles=r) write_setup_cfg(sources_path, cpu) copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py") copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}") @@ -207,31 +177,31 @@ def prepare_wheel(sources_path, *, cpu): copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}") cpu_dir = os.path.join(jaxlib_dir, "cpu") os.makedirs(cpu_dir) - copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir) - copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir) + build_utils.copy_file(f"__main__/jaxlib/cpu/_lapack.{pyext}", dst_dir=cpu_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir, runfiles=r) cuda_dir = os.path.join(jaxlib_dir, "cuda") if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"): libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice") os.makedirs(libdevice_dir) - 
copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir) - copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir) - copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir) - copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir) - copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir) - copy_file(f"__main__/jaxlib/cuda/_rnn.{pyext}", dst_dir=cuda_dir) - copy_file(f"__main__/jaxlib/cuda/_triton.{pyext}", dst_dir=cuda_dir) + build_utils.copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/cuda/_rnn.{pyext}", dst_dir=cuda_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/cuda/_triton.{pyext}", dst_dir=cuda_dir, runfiles=r) rocm_dir = os.path.join(jaxlib_dir, "rocm") if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"): os.makedirs(rocm_dir) - copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir) - copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir) - copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir) - copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir) + build_utils.copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir, runfiles=r) if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"): - copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir) + build_utils.copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir, runfiles=r) if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"): - copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir) + build_utils.copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir, runfiles=r) mosaic_dir = os.path.join(jaxlib_dir, "mosaic") mosaic_python_dir = os.path.join(mosaic_dir, "python") @@ -240,7 +210,7 @@ def prepare_wheel(sources_path, *, cpu): copy_to_jaxlib("__main__/jaxlib/mosaic/python/apply_vector_layout.py", dst_dir=mosaic_python_dir) copy_to_jaxlib("__main__/jaxlib/mosaic/python/infer_memref_layout.py", dst_dir=mosaic_python_dir) copy_to_jaxlib("__main__/jaxlib/mosaic/python/tpu.py", dst_dir=mosaic_python_dir) - copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir) + build_utils.copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir, runfiles=r) # TODO (sharadmv,skyewm): can we avoid patching this file? 
patch_copy_mlir_import("__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir) @@ -250,82 +220,61 @@ def prepare_wheel(sources_path, *, cpu): os.makedirs(mlir_dir) os.makedirs(mlir_dialects_dir) os.makedirs(mlir_libs_dir) - copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir) - copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir) - copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir) - copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir) + build_utils.copy_file("__main__/jaxlib/mlir/ir.py", dst_dir=mlir_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/passmanager.py", dst_dir=mlir_dir, runfiles=r) + 
build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ods_common.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_func_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir, runfiles=r) + 
build_utils.copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir, runfiles=r) + build_utils.copy_file("__main__/jaxlib/mlir/dialects/ml_program.py", dst_dir=mlir_dialects_dir, runfiles=r) - copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir) - copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir) - copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir) - copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir) - copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir) - copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir) - copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", dst_dir=mlir_libs_dir) - copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir) - copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir) - if _is_windows(): - copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir) + build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/__init__.py", dst_dir=mlir_libs_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", dst_dir=mlir_libs_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", dst_dir=mlir_libs_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", dst_dir=mlir_libs_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", dst_dir=mlir_libs_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", dst_dir=mlir_libs_dir, runfiles=r) + build_utils.copy_file(f"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}", dst_dir=mlir_libs_dir, runfiles=r) + if build_utils.is_windows(): + build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll", dst_dir=mlir_libs_dir, runfiles=r) elif _is_mac(): - copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir) + build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib", dst_dir=mlir_libs_dir, runfiles=r) else: - copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir) + build_utils.copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir, runfiles=r) patch_copy_xla_extension_stubs(jaxlib_dir) -def build_wheel(sources_path, output_path): - """Builds a wheel in `output_path` using the source tree in `sources_path`.""" - subprocess.run([sys.executable, "-m", "build", "-n", "-w"], - check=True, cwd=sources_path) - for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")): - output_file = os.path.join(output_path, os.path.basename(wheel)) - sys.stderr.write(f"Output wheel: {output_file}\n\n") - sys.stderr.write("To install the newly-built jaxlib wheel, run:\n") - sys.stderr.write(f" pip install {output_file} --force-reinstall\n\n") - shutil.copy(wheel, output_path) - - -def 
build_editable(sources_path, output_path): - sys.stderr.write( - "To install the editable jaxlib build, run:\n\n" - f" pip install -e {output_path}\n\n" - ) - shutil.rmtree(output_path, ignore_errors=True) - shutil.copytree(sources_path, output_path) - - tmpdir = None sources_path = args.sources_path if sources_path is None: @@ -335,10 +284,11 @@ if sources_path is None: try: os.makedirs(args.output_path, exist_ok=True) prepare_wheel(sources_path, cpu=args.cpu) + package_name = "jaxlib" if args.editable: - build_editable(sources_path, args.output_path) + build_utils.build_editable(sources_path, args.output_path, package_name) else: - build_wheel(sources_path, args.output_path) + build_utils.build_wheel(sources_path, args.output_path, package_name) finally: if tmpdir: tmpdir.cleanup() diff --git a/plugins/cuda/BUILD.bazel b/plugins/cuda/BUILD.bazel new file mode 100644 index 000000000..0d7863738 --- /dev/null +++ b/plugins/cuda/BUILD.bazel @@ -0,0 +1,26 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) + +exports_files([ + "__init__.py", + "pyproject.toml", + "setup.py", +]) diff --git a/plugins/cuda/__init__.py b/plugins/cuda/__init__.py new file mode 100644 index 000000000..6d3086d24 --- /dev/null +++ b/plugins/cuda/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import pathlib +import platform +import sys + +import jax._src.xla_bridge as xb + +logger = logging.getLogger(__name__) + + +def initialize(): + path = pathlib.Path(__file__).resolve().parent / "xla_cuda_plugin.so" + if not path.exists(): + logger.warning( + "WARNING: Native library %s does not exist. This most likely indicates" + " an issue with how %s was built or installed.", + path, + __package__, + ) + + xb.register_plugin("cuda", priority=500, library_path=str(path)) diff --git a/plugins/cuda/pyproject.toml b/plugins/cuda/pyproject.toml new file mode 100644 index 000000000..8fe2f47af --- /dev/null +++ b/plugins/cuda/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py new file mode 100644 index 000000000..563985950 --- /dev/null +++ b/plugins/cuda/setup.py @@ -0,0 +1,54 @@ +# Copyright 2023 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from setuptools import setup, find_namespace_packages + +__version__ = "0.0" +cuda_version = 0 # placeholder +project_name = f"jax-cuda-plugin-cu{cuda_version}" +package_name = f"jax_plugins.xla_cuda_cu{cuda_version}" + +packages = find_namespace_packages( + include=[ + package_name, + f"{package_name}.*", + ] +) + +setup( + name=project_name, + version=__version__, + description="JAX XLA PJRT Plugin for NVIDIA GPUs", + long_description="", + long_description_content_type="text/markdown", + author="JAX team", + author_email="jax-dev@google.com", + packages=packages, + install_requires=[], + url="https://github.com/google/jax", + license="Apache-2.0", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + ], + package_data={ + package_name: ["xla_cuda_plugin.so"], + }, + zip_safe=False, + entry_points={ + "jax_plugins": [ + f"xla_cuda_cu{cuda_version} = {package_name}", + ], + }, +)
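
For context, a minimal, hypothetical sketch of how an installed plugin package such as jax_plugins.xla_cuda_cu12 can be discovered and initialized at runtime. It is not part of this change; it only illustrates the "jax_plugins" entry-point group registered in plugins/cuda/setup.py and the initialize() hook defined in plugins/cuda/__init__.py. The real discovery logic lives in jax._src.xla_bridge and may differ in detail.

    # Hypothetical illustration only -- not part of this patch.
    # Assumes Python 3.10+ for the group= keyword of entry_points().
    import importlib.metadata

    def discover_jax_plugins():
        """Yield (name, module) for packages in the "jax_plugins" entry-point group."""
        for ep in importlib.metadata.entry_points(group="jax_plugins"):
            # Each entry point names a plugin module, e.g.
            # "xla_cuda_cu12 = jax_plugins.xla_cuda_cu12"; load() imports it.
            yield ep.name, ep.load()

    for name, module in discover_jax_plugins():
        # initialize() registers the bundled PJRT plugin with the XLA bridge,
        # as plugins/cuda/__init__.py does via xb.register_plugin("cuda", ...).
        module.initialize()

Packaging the plugin as a jax_plugins.* namespace package plus an entry point lets JAX locate it by group name alone, even though the CUDA major version is baked into the distribution and module names.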