Add a jax_wheel Bazel rule to build jax pip packages

PiperOrigin-RevId: 689514531
This commit is contained in:
Kanglan Tang 2024-10-24 14:20:03 -07:00 committed by jax authors
parent 9500bd451a
commit af28595909
6 changed files with 187 additions and 23 deletions

View File

@ -62,6 +62,11 @@ def platform_tag(cpu: str) -> str:
}[(platform.system(), cpu)]
return f"{platform_name}_{cpu_name}"
def get_githash(jaxlib_git_hash):
if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash):
with open(jaxlib_git_hash, "r") as f:
return f.readline().strip()
return jaxlib_git_hash
def build_wheel(
sources_path: str, output_path: str, package_name: str, git_hash: str = ""

View File

@ -305,6 +305,95 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)
def _jax_wheel_impl(ctx):
executable = ctx.executable.wheel_binary
output = ctx.actions.declare_directory(ctx.label.name)
args = ctx.actions.args()
args.add("--output_path", output.path) # required argument
args.add("--cpu", ctx.attr.platform_tag) # required argument
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument
if ctx.attr.enable_cuda:
args.add("--enable-cuda", "True")
if ctx.attr.platform_version == "":
fail("platform_version must be set to a valid cuda version for cuda wheels")
args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels
if ctx.attr.enable_rocm:
args.add("--enable-rocm", "True")
if ctx.attr.platform_version == "":
fail("platform_version must be set to a valid rocm version for rocm wheels")
args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels
if ctx.attr.skip_gpu_kernels:
args.add("--skip_gpu_kernels")
args.set_param_file_format("flag_per_line")
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
outputs = [output],
executable = executable,
)
return [DefaultInfo(files = depset(direct = [output]))]
_jax_wheel = rule(
attrs = {
"wheel_binary": attr.label(
default = Label("//jaxlib/tools:build_wheel"),
executable = True,
# b/365588895 Investigate cfg = "exec" for multi platform builds
cfg = "target",
),
"platform_tag": attr.string(mandatory = True),
"git_hash": attr.label(allow_single_file = True),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
"platform_version": attr.string(mandatory = True, default = ""),
"skip_gpu_kernels": attr.bool(default = False),
"enable_rocm": attr.bool(default = False),
},
implementation = _jax_wheel_impl,
executable = False,
)
def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
"""Create jax artifact wheels.
Common artifact attributes are grouped within a single macro.
Args:
name: the name of the wheel
wheel_binary: the binary to use to build the wheel
enable_cuda: whether to build a cuda wheel
platform_version: the cuda version to use for the wheel
Returns:
A directory containing the wheel
"""
_jax_wheel(
name = name,
wheel_binary = wheel_binary,
enable_cuda = enable_cuda,
platform_version = platform_version,
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
# the git hash file needs to be created first.
git_hash = select({
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
"//conditions:default": None,
}),
# Following the convention in jax/tools/build_utils.py.
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
platform_tag = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:arm64": "aarch64",
"@platforms//cpu:x86_64": "x86_64",
}),
)
jax_test_file_visibility = []
def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable

View File

@ -14,9 +14,11 @@
# JAX is Autograd and XLA
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("//jaxlib:jax.bzl", "if_windows", "jax_py_test")
load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel")
licenses(["notice"]) # Apache 2
@ -30,11 +32,11 @@ py_binary(
"//jaxlib",
"//jaxlib:README.md",
"//jaxlib:setup.py",
"@xla//xla/ffi/api:api.h",
"@xla//xla/ffi/api:c_api.h",
"@xla//xla/ffi/api:ffi.h",
"@xla//xla/python:xla_client.py",
"@xla//xla/python:xla_extension",
"@xla//xla/ffi/api:c_api.h",
"@xla//xla/ffi/api:api.h",
"@xla//xla/ffi/api:ffi.h",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + if_cuda([
@ -47,8 +49,8 @@ py_binary(
"//jax/tools:build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_wheel//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)
@ -105,8 +107,8 @@ py_binary(
"//jax/tools:build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_wheel//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)
@ -134,7 +136,72 @@ py_binary(
"//jax/tools:build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_wheel//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)
selects.config_setting_group(
name = "macos",
match_any = [
"@platforms//os:osx",
"@platforms//os:macos",
],
)
selects.config_setting_group(
name = "arm64",
match_any = [
"@platforms//cpu:aarch64",
"@platforms//cpu:arm64",
],
)
selects.config_setting_group(
name = "macos_arm64",
match_all = [
":arm64",
":macos",
],
)
selects.config_setting_group(
name = "win_amd64",
match_all = [
"@platforms//cpu:x86_64",
"@platforms//os:windows",
],
)
string_flag(
name = "jaxlib_git_hash",
build_setting_default = "",
)
config_setting(
name = "jaxlib_git_hash_nightly_or_release",
flag_values = {
":jaxlib_git_hash": "nightly",
},
)
jax_wheel(
name = "jaxlib_wheel",
wheel_binary = ":build_wheel",
)
jax_wheel(
name = "jax_cuda_plugin_wheel",
enable_cuda = True,
# TODO(b/371217563) May use hermetic cuda version here.
platform_version = "12",
wheel_binary = ":build_gpu_kernels_wheel",
)
jax_wheel(
name = "jax_cuda_pjrt_wheel",
enable_cuda = True,
# TODO(b/371217563) May use hermetic cuda version here.
platform_version = "12",
wheel_binary = ":build_gpu_plugin_wheel",
)

View File

@ -171,11 +171,12 @@ try:
if args.editable:
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
build_utils.build_wheel(
sources_path,
args.output_path,
package_name,
git_hash=args.jaxlib_git_hash,
git_hash=git_hash,
)
finally:
tmpdir.cleanup()

View File

@ -167,11 +167,12 @@ try:
if args.editable:
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
build_utils.build_wheel(
sources_path,
args.output_path,
package_name,
git_hash=args.jaxlib_git_hash,
git_hash=git_hash,
)
finally:
if tmpdir:

View File

@ -410,7 +410,8 @@ try:
if args.editable:
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=args.jaxlib_git_hash)
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash)
finally:
if tmpdir:
tmpdir.cleanup()