From 4eb782e4027a463b64dea08583dd0c757cf761bc Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Feb 2025 07:39:15 -0800 Subject: [PATCH] Update `jax_wheel` target to produce both wheel and source distribution files. This change replicates the old method of building `jax` wheel via `python -m build`, which produced `.tar.gz` and `.whl` files. PiperOrigin-RevId: 731721522 --- BUILD.bazel | 1 + build_wheel.py | 1 + jaxlib/jax.bzl | 30 +++++++++++++++++++++++++----- jaxlib/tools/build_utils.py | 21 ++++++++++++++++++--- 4 files changed, 45 insertions(+), 8 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index ec6c87166..617e39e73 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -70,6 +70,7 @@ py_binary( jax_wheel( name = "jax_wheel", + build_wheel_only = False, platform_independent = True, source_files = [ ":transitive_py_data", diff --git a/build_wheel.py b/build_wheel.py index ec93b4155..f8e1595d3 100644 --- a/build_wheel.py +++ b/build_wheel.py @@ -94,6 +94,7 @@ try: args.output_path, package_name="jax", git_hash=args.jaxlib_git_hash, + build_wheel_only=False, ) finally: if tmpdir: diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 19f8c4258..14f92058d 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -340,6 +340,12 @@ def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_na ), ) +def _get_source_distribution_name(package_name, wheel_version): + return "{package_name}-{wheel_version}.tar.gz".format( + package_name = package_name, + wheel_version = wheel_version, + ) + def _jax_wheel_impl(ctx): include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value @@ -367,6 +373,7 @@ def _jax_wheel_impl(ctx): cpu = ctx.attr.cpu no_abi = ctx.attr.no_abi platform_independent = ctx.attr.platform_independent + build_wheel_only = ctx.attr.build_wheel_only platform_name = ctx.attr.platform_name wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, @@ -376,9 +383,18 @@ def _jax_wheel_impl(ctx): cpu_name = cpu, wheel_version = full_wheel_version, ) - output_file = ctx.actions.declare_file(output_path + - "/" + wheel_name) - wheel_dir = output_file.path[:output_file.path.rfind("/")] + wheel_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")] + outputs = [wheel_file] + if not build_wheel_only: + source_distribution_name = _get_source_distribution_name( + package_name = ctx.attr.wheel_name, + wheel_version = full_wheel_version, + ) + source_distribution_file = ctx.actions.declare_file(output_path + + "/" + source_distribution_name) + outputs.append(source_distribution_file) args.add("--output_path", wheel_dir) # required argument if not platform_independent: @@ -409,13 +425,13 @@ def _jax_wheel_impl(ctx): ctx.actions.run( arguments = [args], inputs = srcs, - outputs = [output_file], + outputs = outputs, executable = executable, env = env, mnemonic = "BuildJaxWheel", ) - return [DefaultInfo(files = depset(direct = [output_file]))] + return [DefaultInfo(files = depset(direct = outputs))] _jax_wheel = rule( attrs = { @@ -428,6 +444,7 @@ _jax_wheel = rule( "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), "platform_independent": attr.bool(default = False), + "build_wheel_only": attr.bool(default = True), "cpu": attr.string(mandatory = True), "platform_name": attr.string(mandatory = True), "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), @@ -451,6 +468,7 @@ def jax_wheel( wheel_name, no_abi = False, platform_independent = False, + build_wheel_only = True, enable_cuda = False, platform_version = "", source_files = []): @@ -463,6 +481,7 @@ def jax_wheel( wheel_binary: the binary to use to build the wheel wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI + build_wheel_only: whether to build a wheel without source distribution platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel platform_version: the cuda version to use for the wheel @@ -477,6 +496,7 @@ def jax_wheel( wheel_name = wheel_name, no_abi = no_abi, platform_independent = platform_independent, + build_wheel_only = build_wheel_only, enable_cuda = enable_cuda, platform_version = platform_version, # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)` diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 0db7c7072..9c7f61fc2 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -60,14 +60,23 @@ def platform_tag(cpu: str) -> str: def build_wheel( - sources_path: str, output_path: str, package_name: str, git_hash: str = "" + sources_path: str, + output_path: str, + package_name: str, + git_hash: str = "", + build_wheel_only: bool = True, ) -> None: """Builds a wheel in `output_path` using the source tree in `sources_path`.""" env = dict(os.environ) if git_hash: env["JAX_GIT_HASH"] = git_hash - subprocess.run([sys.executable, "-m", "build", "-n", "-w"], - check=True, cwd=sources_path, env=env) + subprocess.run( + [sys.executable, "-m", "build", "-n"] + + (["-w"] if build_wheel_only else []), + check=True, + cwd=sources_path, + env=env, + ) for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")): output_file = os.path.join(output_path, os.path.basename(wheel)) sys.stderr.write(f"Output wheel: {output_file}\n\n") @@ -82,6 +91,12 @@ def build_wheel( sys.stderr.write(" bazel run //build:requirements.update" + f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n") shutil.copy(wheel, output_path) + if not build_wheel_only: + for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")): + output_file = os.path.join(output_path, os.path.basename(dist)) + sys.stderr.write(f"Output source distribution: {output_file}\n\n") + shutil.copy(dist, output_path) + def build_editable( sources_path: str, output_path: str, package_name: str