diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 84cc697d1..83d0b4b25 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -62,6 +62,11 @@ def platform_tag(cpu: str) -> str: }[(platform.system(), cpu)] return f"{platform_name}_{cpu_name}" +def get_githash(jaxlib_git_hash): + if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash): + with open(jaxlib_git_hash, "r") as f: + return f.readline().strip() + return jaxlib_git_hash def build_wheel( sources_path: str, output_path: str, package_name: str, git_hash: str = "" diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 3c812d62c..d6811bf66 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -305,6 +305,95 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) +def _jax_wheel_impl(ctx): + executable = ctx.executable.wheel_binary + + output = ctx.actions.declare_directory(ctx.label.name) + args = ctx.actions.args() + args.add("--output_path", output.path) # required argument + args.add("--cpu", ctx.attr.platform_tag) # required argument + jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path + args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument + + if ctx.attr.enable_cuda: + args.add("--enable-cuda", "True") + if ctx.attr.platform_version == "": + fail("platform_version must be set to a valid cuda version for cuda wheels") + args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels + if ctx.attr.enable_rocm: + args.add("--enable-rocm", "True") + if ctx.attr.platform_version == "": + fail("platform_version must be set to a valid rocm version for rocm wheels") + args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels + if ctx.attr.skip_gpu_kernels: + args.add("--skip_gpu_kernels") + + args.set_param_file_format("flag_per_line") + args.use_param_file("@%s", use_always = False) + ctx.actions.run( + arguments = [args], + inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [], + outputs = [output], + executable = executable, + ) + return [DefaultInfo(files = depset(direct = [output]))] + +_jax_wheel = rule( + attrs = { + "wheel_binary": attr.label( + default = Label("//jaxlib/tools:build_wheel"), + executable = True, + # b/365588895 Investigate cfg = "exec" for multi platform builds + cfg = "target", + ), + "platform_tag": attr.string(mandatory = True), + "git_hash": attr.label(allow_single_file = True), + "enable_cuda": attr.bool(default = False), + # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. + "platform_version": attr.string(mandatory = True, default = ""), + "skip_gpu_kernels": attr.bool(default = False), + "enable_rocm": attr.bool(default = False), + }, + implementation = _jax_wheel_impl, + executable = False, +) + +def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): + """Create jax artifact wheels. + + Common artifact attributes are grouped within a single macro. + + Args: + name: the name of the wheel + wheel_binary: the binary to use to build the wheel + enable_cuda: whether to build a cuda wheel + platform_version: the cuda version to use for the wheel + + Returns: + A directory containing the wheel + """ + _jax_wheel( + name = name, + wheel_binary = wheel_binary, + enable_cuda = enable_cuda, + platform_version = platform_version, + # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to + # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to + # the git hash file needs to be created first. + git_hash = select({ + "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink", + "//conditions:default": None, + }), + # Following the convention in jax/tools/build_utils.py. + # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. + platform_tag = select({ + "//jaxlib/tools:macos_arm64": "arm64", + "//jaxlib/tools:win_amd64": "AMD64", + "//jaxlib/tools:arm64": "aarch64", + "@platforms//cpu:x86_64": "x86_64", + }), + ) + jax_test_file_visibility = [] def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 4553dc1e3..48dc03cfb 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -14,9 +14,11 @@ # JAX is Autograd and XLA +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//jaxlib:jax.bzl", "if_windows", "jax_py_test") +load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel") licenses(["notice"]) # Apache 2 @@ -30,11 +32,11 @@ py_binary( "//jaxlib", "//jaxlib:README.md", "//jaxlib:setup.py", + "@xla//xla/ffi/api:api.h", + "@xla//xla/ffi/api:c_api.h", + "@xla//xla/ffi/api:ffi.h", "@xla//xla/python:xla_client.py", "@xla//xla/python:xla_extension", - "@xla//xla/ffi/api:c_api.h", - "@xla//xla/ffi/api:api.h", - "@xla//xla/ffi/api:ffi.h", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]) + if_cuda([ @@ -44,11 +46,11 @@ py_binary( "//jaxlib/rocm:rocm_gpu_support", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", ], ) @@ -57,7 +59,7 @@ jax_py_test( srcs = ["build_wheel_test.py"], data = [":build_wheel"], deps = [ - "@bazel_tools//tools/python/runfiles", + "@bazel_tools//tools/python/runfiles", ], ) @@ -102,11 +104,11 @@ py_binary( "//jax_plugins/rocm:__init__.py", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", ], ) @@ -131,10 +133,75 @@ py_binary( "//jax_plugins/rocm:plugin_setup.py", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", ], ) + +selects.config_setting_group( + name = "macos", + match_any = [ + "@platforms//os:osx", + "@platforms//os:macos", + ], +) + +selects.config_setting_group( + name = "arm64", + match_any = [ + "@platforms//cpu:aarch64", + "@platforms//cpu:arm64", + ], +) + +selects.config_setting_group( + name = "macos_arm64", + match_all = [ + ":arm64", + ":macos", + ], +) + +selects.config_setting_group( + name = "win_amd64", + match_all = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", + ], +) + +string_flag( + name = "jaxlib_git_hash", + build_setting_default = "", +) + +config_setting( + name = "jaxlib_git_hash_nightly_or_release", + flag_values = { + ":jaxlib_git_hash": "nightly", + }, +) + +jax_wheel( + name = "jaxlib_wheel", + wheel_binary = ":build_wheel", +) + +jax_wheel( + name = "jax_cuda_plugin_wheel", + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "12", + wheel_binary = ":build_gpu_kernels_wheel", +) + +jax_wheel( + name = "jax_cuda_pjrt_wheel", + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "12", + wheel_binary = ":build_gpu_plugin_wheel", +) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index ced0b76c3..5b3ac6363 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -171,11 +171,12 @@ try: if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: + git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=args.jaxlib_git_hash, + git_hash=git_hash, ) finally: tmpdir.cleanup() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 0e2bba0c7..08c2389c2 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -167,11 +167,12 @@ try: if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: + git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=args.jaxlib_git_hash, + git_hash=git_hash, ) finally: if tmpdir: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 5ebdf6e4c..438cebca2 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -410,7 +410,8 @@ try: if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=args.jaxlib_git_hash) + git_hash = build_utils.get_githash(args.jaxlib_git_hash) + build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash) finally: if tmpdir: tmpdir.cleanup()