mirror of
https://github.com/ROCm/jax.git
synced 2025-04-13 02:16:06 +00:00
Update jax_wheel
target to produce both wheel and source distribution files.
This change replicates the old method of building `jax` wheel via `python -m build`, which produced `.tar.gz` and `.whl` files. PiperOrigin-RevId: 731721522
This commit is contained in:
parent
a8738a069e
commit
4eb782e402
@ -70,6 +70,7 @@ py_binary(
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_wheel",
|
||||
build_wheel_only = False,
|
||||
platform_independent = True,
|
||||
source_files = [
|
||||
":transitive_py_data",
|
||||
|
@ -94,6 +94,7 @@ try:
|
||||
args.output_path,
|
||||
package_name="jax",
|
||||
git_hash=args.jaxlib_git_hash,
|
||||
build_wheel_only=False,
|
||||
)
|
||||
finally:
|
||||
if tmpdir:
|
||||
|
@ -340,6 +340,12 @@ def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_na
|
||||
),
|
||||
)
|
||||
|
||||
def _get_source_distribution_name(package_name, wheel_version):
|
||||
return "{package_name}-{wheel_version}.tar.gz".format(
|
||||
package_name = package_name,
|
||||
wheel_version = wheel_version,
|
||||
)
|
||||
|
||||
def _jax_wheel_impl(ctx):
|
||||
include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
|
||||
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
|
||||
@ -367,6 +373,7 @@ def _jax_wheel_impl(ctx):
|
||||
cpu = ctx.attr.cpu
|
||||
no_abi = ctx.attr.no_abi
|
||||
platform_independent = ctx.attr.platform_independent
|
||||
build_wheel_only = ctx.attr.build_wheel_only
|
||||
platform_name = ctx.attr.platform_name
|
||||
wheel_name = _get_full_wheel_name(
|
||||
package_name = ctx.attr.wheel_name,
|
||||
@ -376,9 +383,18 @@ def _jax_wheel_impl(ctx):
|
||||
cpu_name = cpu,
|
||||
wheel_version = full_wheel_version,
|
||||
)
|
||||
output_file = ctx.actions.declare_file(output_path +
|
||||
"/" + wheel_name)
|
||||
wheel_dir = output_file.path[:output_file.path.rfind("/")]
|
||||
wheel_file = ctx.actions.declare_file(output_path +
|
||||
"/" + wheel_name)
|
||||
wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")]
|
||||
outputs = [wheel_file]
|
||||
if not build_wheel_only:
|
||||
source_distribution_name = _get_source_distribution_name(
|
||||
package_name = ctx.attr.wheel_name,
|
||||
wheel_version = full_wheel_version,
|
||||
)
|
||||
source_distribution_file = ctx.actions.declare_file(output_path +
|
||||
"/" + source_distribution_name)
|
||||
outputs.append(source_distribution_file)
|
||||
|
||||
args.add("--output_path", wheel_dir) # required argument
|
||||
if not platform_independent:
|
||||
@ -409,13 +425,13 @@ def _jax_wheel_impl(ctx):
|
||||
ctx.actions.run(
|
||||
arguments = [args],
|
||||
inputs = srcs,
|
||||
outputs = [output_file],
|
||||
outputs = outputs,
|
||||
executable = executable,
|
||||
env = env,
|
||||
mnemonic = "BuildJaxWheel",
|
||||
)
|
||||
|
||||
return [DefaultInfo(files = depset(direct = [output_file]))]
|
||||
return [DefaultInfo(files = depset(direct = outputs))]
|
||||
|
||||
_jax_wheel = rule(
|
||||
attrs = {
|
||||
@ -428,6 +444,7 @@ _jax_wheel = rule(
|
||||
"wheel_name": attr.string(mandatory = True),
|
||||
"no_abi": attr.bool(default = False),
|
||||
"platform_independent": attr.bool(default = False),
|
||||
"build_wheel_only": attr.bool(default = True),
|
||||
"cpu": attr.string(mandatory = True),
|
||||
"platform_name": attr.string(mandatory = True),
|
||||
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
|
||||
@ -451,6 +468,7 @@ def jax_wheel(
|
||||
wheel_name,
|
||||
no_abi = False,
|
||||
platform_independent = False,
|
||||
build_wheel_only = True,
|
||||
enable_cuda = False,
|
||||
platform_version = "",
|
||||
source_files = []):
|
||||
@ -463,6 +481,7 @@ def jax_wheel(
|
||||
wheel_binary: the binary to use to build the wheel
|
||||
wheel_name: the name of the wheel
|
||||
no_abi: whether to build a wheel without ABI
|
||||
build_wheel_only: whether to build a wheel without source distribution
|
||||
platform_independent: whether to build a wheel without platform tag
|
||||
enable_cuda: whether to build a cuda wheel
|
||||
platform_version: the cuda version to use for the wheel
|
||||
@ -477,6 +496,7 @@ def jax_wheel(
|
||||
wheel_name = wheel_name,
|
||||
no_abi = no_abi,
|
||||
platform_independent = platform_independent,
|
||||
build_wheel_only = build_wheel_only,
|
||||
enable_cuda = enable_cuda,
|
||||
platform_version = platform_version,
|
||||
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
|
||||
|
@ -60,14 +60,23 @@ def platform_tag(cpu: str) -> str:
|
||||
|
||||
|
||||
def build_wheel(
|
||||
sources_path: str, output_path: str, package_name: str, git_hash: str = ""
|
||||
sources_path: str,
|
||||
output_path: str,
|
||||
package_name: str,
|
||||
git_hash: str = "",
|
||||
build_wheel_only: bool = True,
|
||||
) -> None:
|
||||
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
|
||||
env = dict(os.environ)
|
||||
if git_hash:
|
||||
env["JAX_GIT_HASH"] = git_hash
|
||||
subprocess.run([sys.executable, "-m", "build", "-n", "-w"],
|
||||
check=True, cwd=sources_path, env=env)
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "build", "-n"]
|
||||
+ (["-w"] if build_wheel_only else []),
|
||||
check=True,
|
||||
cwd=sources_path,
|
||||
env=env,
|
||||
)
|
||||
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
|
||||
output_file = os.path.join(output_path, os.path.basename(wheel))
|
||||
sys.stderr.write(f"Output wheel: {output_file}\n\n")
|
||||
@ -82,6 +91,12 @@ def build_wheel(
|
||||
sys.stderr.write(" bazel run //build:requirements.update" +
|
||||
f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n")
|
||||
shutil.copy(wheel, output_path)
|
||||
if not build_wheel_only:
|
||||
for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")):
|
||||
output_file = os.path.join(output_path, os.path.basename(dist))
|
||||
sys.stderr.write(f"Output source distribution: {output_file}\n\n")
|
||||
shutil.copy(dist, output_path)
|
||||
|
||||
|
||||
def build_editable(
|
||||
sources_path: str, output_path: str, package_name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user