mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a jax_wheel Bazel rule to build jax pip packages
PiperOrigin-RevId: 689514531
This commit is contained in:
parent
9500bd451a
commit
af28595909
@ -62,6 +62,11 @@ def platform_tag(cpu: str) -> str:
|
||||
}[(platform.system(), cpu)]
|
||||
return f"{platform_name}_{cpu_name}"
|
||||
|
||||
def get_githash(jaxlib_git_hash):
|
||||
if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash):
|
||||
with open(jaxlib_git_hash, "r") as f:
|
||||
return f.readline().strip()
|
||||
return jaxlib_git_hash
|
||||
|
||||
def build_wheel(
|
||||
sources_path: str, output_path: str, package_name: str, git_hash: str = ""
|
||||
|
@ -305,6 +305,95 @@ def jax_generate_backend_suites(backends = []):
|
||||
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
|
||||
)
|
||||
|
||||
def _jax_wheel_impl(ctx):
|
||||
executable = ctx.executable.wheel_binary
|
||||
|
||||
output = ctx.actions.declare_directory(ctx.label.name)
|
||||
args = ctx.actions.args()
|
||||
args.add("--output_path", output.path) # required argument
|
||||
args.add("--cpu", ctx.attr.platform_tag) # required argument
|
||||
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
|
||||
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument
|
||||
|
||||
if ctx.attr.enable_cuda:
|
||||
args.add("--enable-cuda", "True")
|
||||
if ctx.attr.platform_version == "":
|
||||
fail("platform_version must be set to a valid cuda version for cuda wheels")
|
||||
args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels
|
||||
if ctx.attr.enable_rocm:
|
||||
args.add("--enable-rocm", "True")
|
||||
if ctx.attr.platform_version == "":
|
||||
fail("platform_version must be set to a valid rocm version for rocm wheels")
|
||||
args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels
|
||||
if ctx.attr.skip_gpu_kernels:
|
||||
args.add("--skip_gpu_kernels")
|
||||
|
||||
args.set_param_file_format("flag_per_line")
|
||||
args.use_param_file("@%s", use_always = False)
|
||||
ctx.actions.run(
|
||||
arguments = [args],
|
||||
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
|
||||
outputs = [output],
|
||||
executable = executable,
|
||||
)
|
||||
return [DefaultInfo(files = depset(direct = [output]))]
|
||||
|
||||
_jax_wheel = rule(
|
||||
attrs = {
|
||||
"wheel_binary": attr.label(
|
||||
default = Label("//jaxlib/tools:build_wheel"),
|
||||
executable = True,
|
||||
# b/365588895 Investigate cfg = "exec" for multi platform builds
|
||||
cfg = "target",
|
||||
),
|
||||
"platform_tag": attr.string(mandatory = True),
|
||||
"git_hash": attr.label(allow_single_file = True),
|
||||
"enable_cuda": attr.bool(default = False),
|
||||
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
|
||||
"platform_version": attr.string(mandatory = True, default = ""),
|
||||
"skip_gpu_kernels": attr.bool(default = False),
|
||||
"enable_rocm": attr.bool(default = False),
|
||||
},
|
||||
implementation = _jax_wheel_impl,
|
||||
executable = False,
|
||||
)
|
||||
|
||||
def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
|
||||
"""Create jax artifact wheels.
|
||||
|
||||
Common artifact attributes are grouped within a single macro.
|
||||
|
||||
Args:
|
||||
name: the name of the wheel
|
||||
wheel_binary: the binary to use to build the wheel
|
||||
enable_cuda: whether to build a cuda wheel
|
||||
platform_version: the cuda version to use for the wheel
|
||||
|
||||
Returns:
|
||||
A directory containing the wheel
|
||||
"""
|
||||
_jax_wheel(
|
||||
name = name,
|
||||
wheel_binary = wheel_binary,
|
||||
enable_cuda = enable_cuda,
|
||||
platform_version = platform_version,
|
||||
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
|
||||
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
|
||||
# the git hash file needs to be created first.
|
||||
git_hash = select({
|
||||
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
|
||||
"//conditions:default": None,
|
||||
}),
|
||||
# Following the convention in jax/tools/build_utils.py.
|
||||
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
|
||||
platform_tag = select({
|
||||
"//jaxlib/tools:macos_arm64": "arm64",
|
||||
"//jaxlib/tools:win_amd64": "AMD64",
|
||||
"//jaxlib/tools:arm64": "aarch64",
|
||||
"@platforms//cpu:x86_64": "x86_64",
|
||||
}),
|
||||
)
|
||||
|
||||
jax_test_file_visibility = []
|
||||
|
||||
def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable
|
||||
|
@ -14,9 +14,11 @@
|
||||
|
||||
# JAX is Autograd and XLA
|
||||
|
||||
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
||||
load("//jaxlib:jax.bzl", "if_windows", "jax_py_test")
|
||||
load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel")
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
|
||||
@ -30,11 +32,11 @@ py_binary(
|
||||
"//jaxlib",
|
||||
"//jaxlib:README.md",
|
||||
"//jaxlib:setup.py",
|
||||
"@xla//xla/ffi/api:api.h",
|
||||
"@xla//xla/ffi/api:c_api.h",
|
||||
"@xla//xla/ffi/api:ffi.h",
|
||||
"@xla//xla/python:xla_client.py",
|
||||
"@xla//xla/python:xla_extension",
|
||||
"@xla//xla/ffi/api:c_api.h",
|
||||
"@xla//xla/ffi/api:api.h",
|
||||
"@xla//xla/ffi/api:ffi.h",
|
||||
] + if_windows([
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]) + if_cuda([
|
||||
@ -44,11 +46,11 @@ py_binary(
|
||||
"//jaxlib/rocm:rocm_gpu_support",
|
||||
]),
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_wheel//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
"@pypi_wheel//:pkg",
|
||||
],
|
||||
)
|
||||
|
||||
@ -57,7 +59,7 @@ jax_py_test(
|
||||
srcs = ["build_wheel_test.py"],
|
||||
data = [":build_wheel"],
|
||||
deps = [
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
],
|
||||
)
|
||||
|
||||
@ -102,11 +104,11 @@ py_binary(
|
||||
"//jax_plugins/rocm:__init__.py",
|
||||
]),
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_wheel//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
"@pypi_wheel//:pkg",
|
||||
],
|
||||
)
|
||||
|
||||
@ -131,10 +133,75 @@ py_binary(
|
||||
"//jax_plugins/rocm:plugin_setup.py",
|
||||
]),
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_wheel//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
"//jax/tools:build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
"@pypi_wheel//:pkg",
|
||||
],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "macos",
|
||||
match_any = [
|
||||
"@platforms//os:osx",
|
||||
"@platforms//os:macos",
|
||||
],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "arm64",
|
||||
match_any = [
|
||||
"@platforms//cpu:aarch64",
|
||||
"@platforms//cpu:arm64",
|
||||
],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "macos_arm64",
|
||||
match_all = [
|
||||
":arm64",
|
||||
":macos",
|
||||
],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "win_amd64",
|
||||
match_all = [
|
||||
"@platforms//cpu:x86_64",
|
||||
"@platforms//os:windows",
|
||||
],
|
||||
)
|
||||
|
||||
string_flag(
|
||||
name = "jaxlib_git_hash",
|
||||
build_setting_default = "",
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "jaxlib_git_hash_nightly_or_release",
|
||||
flag_values = {
|
||||
":jaxlib_git_hash": "nightly",
|
||||
},
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jaxlib_wheel",
|
||||
wheel_binary = ":build_wheel",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_plugin_wheel",
|
||||
enable_cuda = True,
|
||||
# TODO(b/371217563) May use hermetic cuda version here.
|
||||
platform_version = "12",
|
||||
wheel_binary = ":build_gpu_kernels_wheel",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_pjrt_wheel",
|
||||
enable_cuda = True,
|
||||
# TODO(b/371217563) May use hermetic cuda version here.
|
||||
platform_version = "12",
|
||||
wheel_binary = ":build_gpu_plugin_wheel",
|
||||
)
|
||||
|
@ -171,11 +171,12 @@ try:
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
|
||||
build_utils.build_wheel(
|
||||
sources_path,
|
||||
args.output_path,
|
||||
package_name,
|
||||
git_hash=args.jaxlib_git_hash,
|
||||
git_hash=git_hash,
|
||||
)
|
||||
finally:
|
||||
tmpdir.cleanup()
|
||||
|
@ -167,11 +167,12 @@ try:
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
|
||||
build_utils.build_wheel(
|
||||
sources_path,
|
||||
args.output_path,
|
||||
package_name,
|
||||
git_hash=args.jaxlib_git_hash,
|
||||
git_hash=git_hash,
|
||||
)
|
||||
finally:
|
||||
if tmpdir:
|
||||
|
@ -410,7 +410,8 @@ try:
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=args.jaxlib_git_hash)
|
||||
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
|
||||
build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash)
|
||||
finally:
|
||||
if tmpdir:
|
||||
tmpdir.cleanup()
|
||||
|
Loading…
x
Reference in New Issue
Block a user