mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

This change replicates the old method of building `jax` wheel via `python -m build`, which produced `.tar.gz` and `.whl` files. PiperOrigin-RevId: 731721522
131 lines
4.2 KiB
Python
131 lines
4.2 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Utilities for the building JAX related python packages."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import pathlib
|
|
import platform
|
|
import shutil
|
|
import sys
|
|
import subprocess
|
|
import glob
|
|
from collections.abc import Sequence
|
|
from jaxlib.tools import platform_tags
|
|
|
|
|
|
def is_windows() -> bool:
|
|
return sys.platform.startswith("win32")
|
|
|
|
|
|
def copy_file(
|
|
src_files: str | Sequence[str],
|
|
dst_dir: pathlib.Path,
|
|
dst_filename = None,
|
|
runfiles = None,
|
|
) -> None:
|
|
dst_dir.mkdir(parents=True, exist_ok=True)
|
|
if isinstance(src_files, str):
|
|
src_files = [src_files]
|
|
for src_file in src_files:
|
|
src_file_rloc = runfiles.Rlocation(src_file)
|
|
if src_file_rloc is None:
|
|
raise ValueError(f"Unable to find wheel source file {src_file}")
|
|
src_filename = os.path.basename(src_file_rloc)
|
|
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
|
|
if is_windows():
|
|
shutil.copyfile(src_file_rloc, dst_file)
|
|
else:
|
|
shutil.copy(src_file_rloc, dst_file)
|
|
|
|
|
|
def platform_tag(cpu: str) -> str:
|
|
platform_name, cpu_name = platform_tags.PLATFORM_TAGS_DICT[
|
|
(platform.system(), cpu)
|
|
]
|
|
return f"{platform_name}_{cpu_name}"
|
|
|
|
|
|
def build_wheel(
|
|
sources_path: str,
|
|
output_path: str,
|
|
package_name: str,
|
|
git_hash: str = "",
|
|
build_wheel_only: bool = True,
|
|
) -> None:
|
|
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
|
|
env = dict(os.environ)
|
|
if git_hash:
|
|
env["JAX_GIT_HASH"] = git_hash
|
|
subprocess.run(
|
|
[sys.executable, "-m", "build", "-n"]
|
|
+ (["-w"] if build_wheel_only else []),
|
|
check=True,
|
|
cwd=sources_path,
|
|
env=env,
|
|
)
|
|
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
|
|
output_file = os.path.join(output_path, os.path.basename(wheel))
|
|
sys.stderr.write(f"Output wheel: {output_file}\n\n")
|
|
sys.stderr.write(f"To install the newly-built {package_name} wheel " +
|
|
"on system Python, run:\n")
|
|
sys.stderr.write(f" pip install {output_file} --force-reinstall\n\n")
|
|
|
|
py_version = ".".join(platform.python_version_tuple()[:-1])
|
|
sys.stderr.write(f"To install the newly-built {package_name} wheel " +
|
|
"on hermetic Python, run:\n")
|
|
sys.stderr.write(f' echo -e "\\n{output_file}" >> build/requirements.in\n')
|
|
sys.stderr.write(" bazel run //build:requirements.update" +
|
|
f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n")
|
|
shutil.copy(wheel, output_path)
|
|
if not build_wheel_only:
|
|
for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")):
|
|
output_file = os.path.join(output_path, os.path.basename(dist))
|
|
sys.stderr.write(f"Output source distribution: {output_file}\n\n")
|
|
shutil.copy(dist, output_path)
|
|
|
|
|
|
def build_editable(
|
|
sources_path: str, output_path: str, package_name: str
|
|
) -> None:
|
|
sys.stderr.write(
|
|
f"To install the editable {package_name} build, run:\n\n"
|
|
f" pip install -e {output_path}\n\n"
|
|
)
|
|
shutil.rmtree(output_path, ignore_errors=True)
|
|
shutil.copytree(sources_path, output_path)
|
|
|
|
|
|
def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str):
|
|
src_file = file_dir / "setup.py"
|
|
with open(src_file) as f:
|
|
content = f.read()
|
|
content = content.replace(
|
|
"cuda_version = 0 # placeholder", f"cuda_version = {cuda_version}"
|
|
)
|
|
with open(src_file, "w") as f:
|
|
f.write(content)
|
|
|
|
def update_setup_with_rocm_version(file_dir: pathlib.Path, rocm_version: str):
|
|
src_file = file_dir / "setup.py"
|
|
with open(src_file) as f:
|
|
content = f.read()
|
|
content = content.replace(
|
|
"rocm_version = 0 # placeholder", f"rocm_version = {rocm_version}"
|
|
)
|
|
with open(src_file, "w") as f:
|
|
f.write(content)
|