rocm_jax/jaxlib/jax.bzl
Daniel Suo 39e8ee93b0 Add experimental/serialize_executable.py to BUILD.
PiperOrigin-RevId: 736975882
2025-03-14 13:54:39 -07:00

572 lines
22 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bazel macros used by the JAX build."""
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION")
load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFFIX")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library")
load("@rules_python//python:defs.bzl", "py_test")
load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource")
load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties")
# Explicitly re-exports names to avoid "unused variable" warnings from .bzl
# lint tools.
cc_proto_library = _cc_proto_library
cuda_library = _cuda_library
rocm_library = _rocm_library
# Not a re-export of a loaded symbol: pytype_test is simply the native py_test.
pytype_test = native.py_test
# Backed by TSL's `tsl_pybind_extension_opensource` (see the load() above).
nanobind_extension = _pybind_extension
if_cuda_is_configured = _if_cuda_is_configured
if_rocm_is_configured = _if_rocm_is_configured
if_windows = _if_windows
flatbuffer_cc_library = _flatbuffer_cc_library
tf_exec_properties = _tf_exec_properties
tf_cuda_tests_tags = _tf_cuda_tests_tags

# Package / user / visibility lists. All are empty in this file; presumably
# they are populated only when JAX is built inside a larger repository
# (unverified from here).
jax_internal_packages = []
jax_extend_internal_users = []
mosaic_gpu_internal_users = []
mosaic_internal_users = []
pallas_gpu_internal_users = []
pallas_tpu_internal_users = []
pallas_fuser_users = []
mosaic_extension_deps = []
serialize_executable_internal_users = []
jax_internal_export_back_compat_test_util_visibility = []
jax_internal_test_harnesses_visibility = []
jax_test_util_visibility = []
loops_visibility = []
# Maps (platform_name, cpu) -- the strings jax_wheel selects per platform -- to
# the (platform tag, architecture) pair that _get_full_wheel_name joins with "_"
# to build the wheel's platform tag, e.g. "manylinux2014_x86_64".
# NOTE(review): jax_wheel's `cpu` select has no ppc64le entry, so the
# ("Linux", "ppc64le") key is not reachable through it yet.
PLATFORM_TAGS_DICT = {
    ("Linux", "x86_64"): ("manylinux2014", "x86_64"),
    ("Linux", "aarch64"): ("manylinux2014", "aarch64"),
    ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
    ("Darwin", "x86_64"): ("macosx_11_0", "x86_64"),
    ("Darwin", "arm64"): ("macosx_11_0", "arm64"),
    ("Windows", "AMD64"): ("win", "amd64"),
}
# TODO(vam): remove this once zstandard builds against Python 3.13
def get_zstandard():
    """Returns the zstandard pip target as a dep list, or [] on Python 3.13.

    Both the regular ("3.13") and free-threaded ("3.13-ft") hermetic Python
    versions get no zstandard dependency.
    """
    versions_without_zstandard = ("3.13", "3.13-ft")
    return [] if HERMETIC_PYTHON_VERSION in versions_without_zstandard else ["@pypi_zstandard//:pkg"]
# Maps the short package names accepted by py_deps()/all_py_deps() to the Bazel
# targets providing them. An empty list means the name contributes no deps in
# this build.
_py_deps = {
    "absl/logging": ["@pypi_absl_py//:pkg"],
    "absl/testing": ["@pypi_absl_py//:pkg"],
    "absl/flags": ["@pypi_absl_py//:pkg"],
    "cloudpickle": ["@pypi_cloudpickle//:pkg"],
    "colorama": ["@pypi_colorama//:pkg"],
    "epath": ["@pypi_etils//:pkg"],  # etils.epath
    "filelock": ["@pypi_filelock//:pkg"],
    "flatbuffers": ["@pypi_flatbuffers//:pkg"],
    "hypothesis": ["@pypi_hypothesis//:pkg"],
    "magma": [],
    "matplotlib": ["@pypi_matplotlib//:pkg"],
    "mpmath": [],
    "opt_einsum": ["@pypi_opt_einsum//:pkg"],
    "pil": ["@pypi_pillow//:pkg"],
    "portpicker": ["@pypi_portpicker//:pkg"],
    "ml_dtypes": ["@pypi_ml_dtypes//:pkg"],
    "numpy": ["@pypi_numpy//:pkg"],
    "scipy": ["@pypi_scipy//:pkg"],
    "tensorflow_core": [],
    "torch": [],
    # Evaluated at load time; empty on Python 3.13 (see get_zstandard).
    "zstandard": get_zstandard(),
}
def all_py_deps(excluded = []):
    """Returns the de-duplicated deps of every known package except `excluded`.

    Args:
      excluded: package names (keys of `_py_deps`) to leave out. Each name must
        be a known package: dict.pop() without a default fails otherwise.
    """
    remaining = dict(_py_deps)
    for dropped in excluded:
        remaining.pop(dropped)
    return py_deps(remaining.keys())
def py_deps(_package):
    """Returns the Bazel deps for Python package(s) `_package`.

    Args:
      _package: either a single key of `_py_deps`, or a list/tuple of keys.

    Returns:
      For a single key, that key's dep list. For a list/tuple, the union of the
      keys' deps with duplicates removed, in first-seen order.
    """
    if type(_package) not in (type([]), type(())):
        return _py_deps[_package]

    # Dict keys keep insertion order, so this de-duplicates in first-seen order.
    unique_deps = {
        dep: None
        for pkg in _package
        for dep in _py_deps[pkg]
    }
    return unique_deps.keys()
def jax_visibility(_target):
    """Returns extra Bazel visibilities for `_target`; always none in this file.

    Only meaningful when JAX is embedded in a larger Bazel repository.
    """
    extra_visibilities = []
    return extra_visibilities
# Additional dependency lists; all empty in this file.
jax_extra_deps = []
jax_gpu_support_deps = []
jax2tf_deps = []
def pytype_library(name, pytype_srcs = None, **kwargs):
    """Defines a plain `py_library`; `pytype_srcs` is accepted but ignored."""
    _ = pytype_srcs  # @unused
    native.py_library(
        name = name,
        **kwargs
    )
def pytype_strict_library(name, pytype_srcs = [], **kwargs):
    """Defines a `py_library` whose `data` is `pytype_srcs` followed by any caller `data`."""
    caller_data = kwargs.pop("data", [])
    native.py_library(
        name = name,
        data = pytype_srcs + caller_data,
        **kwargs
    )
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
    """Instantiates `lib_rule` with `pytype_srcs` prepended to the caller's `data`.

    Args:
      name: target name.
      lib_rule: the rule to instantiate (defaults to the native py_library).
      pytype_srcs: files placed at the front of the target's `data`.
      **kwargs: forwarded to `lib_rule` unchanged, except `data`.
    """
    caller_data = kwargs.pop("data", [])
    lib_rule(
        name = name,
        data = pytype_srcs + caller_data,
        **kwargs
    )
def py_extension(name, srcs, copts, deps, linkopts = []):
    """Defines a nanobind extension whose Python module name equals `name`."""
    nanobind_extension(
        name,
        module_name = name,
        srcs = srcs,
        deps = deps,
        copts = copts,
        linkopts = linkopts,
    )
def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []):
    """Workaround DLL building issue.

    Builds a Windows DLL named `out` whose .def file keeps only the lines that
    begin (after leading whitespace) with one of `exported_symbol_prefixes`.
    The DLL is wrapped in a `cc_import` named `name` so other cc_* targets can
    consume it. This works around two problems:

    1. cc_binary with linkshared enabled cannot produce DLL with symbol
       correctly exported.
    2. Even if the DLL is correctly built, the resulting target cannot be
       correctly consumed by other targets.

    Intermediate targets: `<name>.dummy.dll`, `<name>.full.def`,
    `<name>.filtered.def` and `<out>.if.lib`. Every generated target is
    restricted to `@platforms//os:windows`.

    Args:
      name: the name of the output target (the final cc_import).
      out: the name of the output DLL filename; also the prefix of the
        generated `<out>.def` and `<out>.if.lib` names.
      deps: deps linked into both the dummy DLL and the final DLL.
      srcs: srcs. NOTE(review): not referenced anywhere in this macro body,
        so it is currently ignored. Confirm whether it should be forwarded.
      exported_symbol_prefixes: symbol-name prefixes to keep; they are joined
        into a `grep -E` alternation. NOTE(review): with an empty list the
        pattern is `^\\s*()`, which matches every line of the full .def file.
        Confirm that callers always pass prefixes.
    """

    # create a dummy library to get the *.def file
    dummy_library_name = name + ".dummy.dll"
    native.cc_binary(
        name = dummy_library_name,
        linkshared = 1,
        linkstatic = 1,
        deps = deps,
        target_compatible_with = ["@platforms//os:windows"],
    )

    # .def file with all symbols, not usable
    full_def_name = name + ".full.def"
    native.filegroup(
        name = full_def_name,
        srcs = [dummy_library_name],
        output_group = "def_file",
        target_compatible_with = ["@platforms//os:windows"],
    )

    # say exported_symbol_prefixes == ["mlir", "chlo"], then construct the regex
    # pattern as "^\\s*(mlir|chlo)" to use grep
    pattern = "^\\s*(" + "|".join(exported_symbol_prefixes) + ")"

    # filtered def_file, only the needed symbols are included: a fresh
    # LIBRARY/EXPORTS header followed by the grep-matched lines of the full .def.
    filtered_def_name = name + ".filtered.def"
    filtered_def_file = out + ".def"
    native.genrule(
        name = filtered_def_name,
        srcs = [full_def_name],
        outs = [filtered_def_file],
        cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep -E '{}' $(location :{}) >> $@""".format(out, pattern, full_def_name),
        target_compatible_with = ["@platforms//os:windows"],
    )

    # create the desired library, exporting only the filtered symbols
    native.cc_binary(
        name = out,  # this name must be correct, it will be the filename
        linkshared = 1,
        deps = deps,
        win_def_file = filtered_def_file,
        target_compatible_with = ["@platforms//os:windows"],
    )

    # however, the shared library produced by the cc_binary above cannot be
    # correctly consumed by other cc_*..., so expose its interface (import) library
    interface_library_file = out + ".if.lib"
    native.filegroup(
        name = interface_library_file,
        srcs = [out],
        output_group = "interface_library",
        target_compatible_with = ["@platforms//os:windows"],
    )

    # but this one can be correctly consumed, this is our final product
    native.cc_import(
        name = name,
        interface_library = interface_library_file,
        shared_library = out,
        target_compatible_with = ["@platforms//os:windows"],
    )
# Backends for which jax_multiplatform_test creates one test target each; also
# the default set for jax_generate_backend_suites.
ALL_BACKENDS = ["cpu", "gpu", "tpu"]
def if_building_jaxlib(
        if_building,
        if_not_building = [
            "@pypi_jaxlib//:pkg",
            "@pypi_jax_cuda12_plugin//:pkg",
            "@pypi_jax_cuda12_pjrt//:pkg",
        ],
        if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"],
        if_py_import = [
            "//jaxlib/tools:jaxlib_py_import",
            "//jaxlib/tools:jax_cuda_plugin_py_import",
            "//jaxlib/tools:jax_cuda_pjrt_py_import",
        ],
        if_py_import_for_cpu = [
            "//jaxlib/tools:jaxlib_py_import",
        ]):
    """Chooses jaxlib source targets or prebuilt jaxlib (plugin) wheels via select().

    This lets prebuilt jaxlib wheels be tested against the rest of the JAX
    codebase instead of building jaxlib from source.

    Args:
      if_building: source targets, used under //jax:enable_jaxlib_build.
      if_not_building: jaxlib wheels plus CUDA 12 plugin wheels, used for
        CUDA 12 builds that disable the jaxlib build.
      if_not_building_for_cpu: jaxlib wheels for cpu-only builds that disable
        the jaxlib build.
      if_py_import: py_import targets for CUDA 12 builds.
      if_py_import_for_cpu: py_import targets for cpu-only builds.

    Returns:
      A select() over the five config settings below. There is no
      //conditions:default branch, so analysis fails if none of them matches.
    """
    branches = {
        "//jax:enable_jaxlib_build": if_building,
        "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu,
        "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building,
        "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu,
        "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import,
    }
    return select(branches)
# buildifier: disable=function-docstring
def jax_multiplatform_test(
        name,
        srcs,
        args = [],
        env = {},
        shard_count = None,
        deps = [],
        data = [],
        enable_backends = None,
        backend_variant_args = {},  # buildifier: disable=unused-variable
        backend_tags = {},
        disable_configs = None,  # buildifier: disable=unused-variable
        enable_configs = [],
        config_tags_overrides = None,  # buildifier: disable=unused-variable
        tags = [],
        main = None,
        pjrt_c_api_bypass = False):  # buildifier: disable=unused-variable
    # Defines one py_test per entry of ALL_BACKENDS, named `<name>_<backend>`.
    #
    # `enable_configs` and `disable_configs` only matter in Google's CI. Here,
    # `enable_configs` is used only to keep a backend enabled: a config whose
    # name starts with the backend's name does that. Upstream precedence is:
    #   1. `enable_backends` enables all test configs for the given backends.
    #   2. `disable_configs` then disables the named test configs.
    #   3. `enable_configs` finally enables the named test configs.
    #
    # `shard_count` may be None, an int (shared by all backends), or a dict
    # keyed by backend (missing backends get 1 shard).
    if main == None:
        if len(srcs) != 1:
            fail("Must set a main file to test multiple source files.")
        main = srcs[0]

    # Same deps for every backend's test target.
    shared_deps = [
        "//jax",
        "//jax:test_util",
    ] + deps + if_building_jaxlib([
        "//jaxlib/cuda:gpu_only_test_deps",
        "//jaxlib/rocm:gpu_only_test_deps",
        "//jax_plugins:gpu_plugin_only_test_deps",
    ])

    for backend in ALL_BACKENDS:
        if shard_count != None and type(shard_count) != type(0):
            backend_shards = shard_count.get(backend, 1)
        else:
            backend_shards = shard_count

        backend_test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])

        # Backends outside `enable_backends` (and not re-enabled by a matching
        # config) are still defined, but tagged manual so wildcards skip them.
        if enable_backends != None and backend not in enable_backends and not any([
            config.startswith(backend)
            for config in enable_configs
        ]):
            backend_test_tags.append("manual")
        if backend == "gpu":
            backend_test_tags += tf_cuda_tests_tags()

        native.py_test(
            name = name + "_" + backend,
            srcs = srcs,
            args = list(args) + [
                "--jax_test_dut=" + backend,
                "--jax_platform_name=" + backend,
            ],
            env = env,
            deps = shared_deps,
            data = data,
            shard_count = backend_shards,
            tags = backend_test_tags,
            main = main,
            exec_properties = tf_exec_properties({"tags": backend_test_tags}),
        )
def jax_generate_backend_suites(backends = []):
    """Defines `<backend>_tests` suites plus a `backend_independent_tests` suite.

    Each `<backend>_tests` suite selects tests tagged `jax_test_<backend>`. The
    `backend_independent_tests` suite excludes every such tag. All suites
    exclude `manual` tests (a leading "-" in test_suite tags excludes a tag).

    Args:
      backends: the set of backends for which rules should be generated.
        Empty or unset means all of ALL_BACKENDS.
    """
    suite_backends = backends or ALL_BACKENDS

    for backend in suite_backends:
        native.test_suite(
            name = "%s_tests" % backend,
            tags = ["jax_test_%s" % backend, "-manual"],
        )

    excluded_backend_tags = ["-jax_test_%s" % b for b in suite_backends]
    native.test_suite(
        name = "backend_independent_tests",
        tags = excluded_backend_tags + ["-manual"],
    )
def _get_full_wheel_name(
        package_name,
        no_abi,
        platform_independent,
        platform_name,
        cpu_name,
        wheel_version,
        py_freethreaded):
    """Returns the file name of the wheel the wheel binary is expected to emit.

    Args:
      package_name: distribution name; first component of the file name.
      no_abi: if True, use the `py<major>-none` python/ABI tags.
      platform_independent: if True, use `py<major>-none` and the `any`
        platform tag.
      platform_name: OS name; first half of the PLATFORM_TAGS_DICT key.
      cpu_name: CPU name; second half of the PLATFORM_TAGS_DICT key.
      wheel_version: full version string embedded in the file name.
      py_freethreaded: value of the rules_python `py_freethreaded` flag.
        "yes" (case-insensitive) appends the `t` free-threaded ABI suffix.

    Returns:
      A name such as
      "<package_name>-<wheel_version>-cp313-cp313-manylinux2014_x86_64.whl".
    """
    if no_abi or platform_independent:
        wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
    else:
        wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}{free_threaded_suffix}-{wheel_platform_tag}.whl"

    # HERMETIC_PYTHON_VERSION may carry a build-kind suffix such as "-ft"
    # (get_zstandard checks for "3.13-ft"). Only the numeric part belongs in the
    # cpXY tag; free-threading is expressed via `free_threaded_suffix`. A plain
    # replace(".", "") would otherwise yield a malformed tag like
    # "cp313-ft-cp313-ftt".
    python_version = HERMETIC_PYTHON_VERSION.split("-")[0].replace(".", "")
    return wheel_name_template.format(
        package_name = package_name,
        python_version = python_version,
        major_python_version = python_version[0],
        wheel_version = wheel_version,
        wheel_platform_tag = "any" if platform_independent else "_".join(
            PLATFORM_TAGS_DICT[platform_name, cpu_name],
        ),
        free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "",
    )
def _get_source_distribution_name(package_name, wheel_version):
    """Returns the sdist file name: `<package_name>-<wheel_version>.tar.gz`."""
    return "%s-%s.tar.gz" % (package_name, wheel_version)
def _jax_wheel_impl(ctx):
    """Runs `wheel_binary` to produce a JAX wheel and, optionally, an sdist.

    Reads the build-setting flags (CUDA-libs guard, output path, git hash,
    free-threading), declares the output file(s) and hands the wheel binary a
    command line describing the requested artifact.

    Args:
      ctx: rule context of a `_jax_wheel` target.

    Returns:
      A single DefaultInfo whose files are the declared outputs.
    """
    attrs = ctx.attr

    # Values of the flag targets referenced by the rule's label attributes.
    include_cuda_libs = attrs.include_cuda_libs[BuildSettingInfo].value
    override_include_cuda_libs = attrs.override_include_cuda_libs[BuildSettingInfo].value
    output_path = attrs.output_path[BuildSettingInfo].value
    git_hash = attrs.git_hash[BuildSettingInfo].value
    py_freethreaded = attrs.py_freethreaded[BuildSettingInfo].value
    wheel_tool = ctx.executable.wheel_binary

    if include_cuda_libs and not override_include_cuda_libs:
        fail("JAX wheel shouldn't be built directly against the CUDA libraries." +
             " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." +
             " If you absolutely need to build links directly against the CUDA libraries, provide" +
             " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")

    version = WHEEL_VERSION + WHEEL_VERSION_SUFFIX
    action_env = {"WHEEL_VERSION_SUFFIX": WHEEL_VERSION_SUFFIX}
    if not WHEEL_VERSION_SUFFIX:
        # No version suffix: tell the wheel binary this is a release build.
        action_env["JAX_RELEASE"] = "1"

    cmd = ctx.actions.args()
    if attrs.editable:
        # Editable builds produce a directory named after the package.
        wheel_out = ctx.actions.declare_directory(output_path + "/" + attrs.wheel_name)
        outputs = [wheel_out]
        wheel_dir = wheel_out.path
        cmd.add("--editable")
    else:
        wheel_out = ctx.actions.declare_file(output_path + "/" + _get_full_wheel_name(
            package_name = attrs.wheel_name,
            no_abi = attrs.no_abi,
            platform_independent = attrs.platform_independent,
            platform_name = attrs.platform_name,
            cpu_name = attrs.cpu,
            wheel_version = version,
            py_freethreaded = py_freethreaded,
        ))
        outputs = [wheel_out]
        wheel_dir = wheel_out.path[:wheel_out.path.rfind("/")]
        if not attrs.build_wheel_only:
            sdist_name = _get_source_distribution_name(
                package_name = attrs.wheel_name,
                wheel_version = version,
            )
            outputs.append(ctx.actions.declare_file(output_path + "/" + sdist_name))

    cmd.add("--output_path", wheel_dir)  # required argument
    if not attrs.platform_independent:
        cmd.add("--cpu", attrs.cpu)
    cmd.add("--jaxlib_git_hash", git_hash)  # required argument

    # GPU wheels must name the cuda/rocm version they target.
    if attrs.enable_cuda:
        cmd.add("--enable-cuda", "True")
        if not attrs.platform_version:
            fail("platform_version must be set to a valid cuda version for cuda wheels")
        cmd.add("--platform_version", attrs.platform_version)  # required for gpu wheels
    if attrs.enable_rocm:
        cmd.add("--enable-rocm", "True")
        if not attrs.platform_version:
            fail("platform_version must be set to a valid rocm version for rocm wheels")
        cmd.add("--platform_version", attrs.platform_version)  # required for gpu wheels
    if attrs.skip_gpu_kernels:
        cmd.add("--skip_gpu_kernels")

    # ctx.files.source_files is the flattened list of every source_files
    # target's files; each becomes one `--srcs=<path>` flag (directories are
    # passed as-is, not expanded).
    sources = ctx.files.source_files
    cmd.add_all(sources, format_each = "--srcs=%s", expand_directories = False)

    # Let Bazel spill the command line into an `@file` (flag_per_line format)
    # only when it decides the command line is too long.
    cmd.set_param_file_format("flag_per_line")
    cmd.use_param_file("@%s", use_always = False)

    ctx.actions.run(
        executable = wheel_tool,
        arguments = [cmd],
        inputs = sources,
        outputs = outputs,
        env = action_env,
        mnemonic = "BuildJaxWheel",
    )
    return [DefaultInfo(files = depset(direct = outputs))]
# Private rule behind the jax_wheel macro; see _jax_wheel_impl for the action.
_jax_wheel = rule(
    attrs = {
        # Tool that assembles the wheel.
        "wheel_binary": attr.label(
            default = Label("//jaxlib/tools:build_wheel"),
            executable = True,
            # b/365588895 Investigate cfg = "exec" for multi platform builds
            cfg = "target",
        ),
        # Package name used in the output file/directory names.
        "wheel_name": attr.string(mandatory = True),
        "no_abi": attr.bool(default = False),
        "platform_independent": attr.bool(default = False),
        # When False (and not editable), a source distribution is also built.
        "build_wheel_only": attr.bool(default = True),
        "editable": attr.bool(default = False),
        # (platform_name, cpu) must be a key of PLATFORM_TAGS_DICT for
        # platform-specific wheels.
        "cpu": attr.string(mandatory = True),
        "platform_name": attr.string(mandatory = True),
        "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
        "source_files": attr.label_list(allow_files = True),
        "output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
        "enable_cuda": attr.bool(default = False),
        # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
        # Deliberately not `mandatory`: with mandatory = True Bazel requires an
        # explicit value and the "" default could never apply. _jax_wheel_impl
        # rejects an empty value whenever enable_cuda/enable_rocm is set.
        "platform_version": attr.string(default = ""),
        "skip_gpu_kernels": attr.bool(default = False),
        "enable_rocm": attr.bool(default = False),
        "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
        "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
        "py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")),
    },
    implementation = _jax_wheel_impl,
    executable = False,
)
def jax_wheel(
        name,
        wheel_binary,
        wheel_name,
        no_abi = False,
        platform_independent = False,
        build_wheel_only = True,
        editable = False,
        enable_cuda = False,
        enable_rocm = False,
        platform_version = "",
        source_files = [],
        skip_gpu_kernels = False):
    """Create jax artifact wheels.

    Common artifact attributes are grouped within a single macro, which defines
    a `_jax_wheel` target named `name`. Its outputs are the wheel file (or a
    directory, for editable builds). When `build_wheel_only` is False and the
    build is not editable, the outputs also include a `.tar.gz` source
    distribution. The macro itself returns nothing.

    Args:
      name: the name of the Bazel target.
      wheel_binary: the binary to use to build the wheel.
      wheel_name: the package name used in the wheel/sdist file names (and the
        output directory name for editable builds).
      no_abi: whether to build a wheel without ABI tag (`py<major>-none-<platform>`).
      platform_independent: whether to build a wheel without platform tag
        (`py<major>-none-any`).
      build_wheel_only: whether to build only the wheel; if False, a source
        distribution is built too (ignored for editable builds).
      editable: whether to build an editable wheel.
      enable_cuda: whether to build a cuda wheel.
      enable_rocm: whether to build a rocm wheel.
      platform_version: the cuda or rocm version for gpu wheels. It must be
        non-empty when `enable_cuda` or `enable_rocm` is set and may be "" for
        cpu wheels.
      source_files: the source files to include in the wheel.
      skip_gpu_kernels: whether to pass `--skip_gpu_kernels` to the wheel binary.
    """

    # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
    # flag in bazel command to pass the git hash for nightly or release builds.
    _jax_wheel(
        name = name,
        wheel_binary = wheel_binary,
        wheel_name = wheel_name,
        no_abi = no_abi,
        platform_independent = platform_independent,
        build_wheel_only = build_wheel_only,
        editable = editable,
        enable_cuda = enable_cuda,
        enable_rocm = enable_rocm,
        platform_version = platform_version,
        skip_gpu_kernels = skip_gpu_kernels,
        platform_name = select({
            "@platforms//os:osx": "Darwin",
            "@platforms//os:macos": "Darwin",
            "@platforms//os:windows": "Windows",
            "@platforms//os:linux": "Linux",
        }),
        # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
        cpu = select({
            "//jaxlib/tools:macos_arm64": "arm64",
            "//jaxlib/tools:macos_x86_64": "x86_64",
            "//jaxlib/tools:win_amd64": "AMD64",
            "//jaxlib/tools:linux_aarch64": "aarch64",
            "//jaxlib/tools:linux_x86_64": "x86_64",
        }),
        source_files = source_files,
    )
# Visibility lists; both empty in this file.
jax_test_file_visibility = []
jax_export_file_visibility = []
def xla_py_proto_library(*args, **kw):  # buildifier: disable=unused-variable
    """No-op placeholder: accepts any arguments, defines no targets, returns None."""
    return None
def jax_py_test(
        name,
        env = {},
        **kwargs):
    """Defines a py_test that, unless told otherwise, turns Python warnings into errors.

    Args:
      name: target name.
      env: test environment. It is copied, and PYTHONWARNINGS defaults to
        "error" when the caller did not set it.
      **kwargs: forwarded to py_test unchanged.
    """
    test_env = dict(env)
    test_env.setdefault("PYTHONWARNINGS", "error")
    py_test(name = name, env = test_env, **kwargs)