mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies.
PiperOrigin-RevId: 595529249
This commit is contained in:
parent
f1be301049
commit
23b9c2a22f
@ -93,6 +93,15 @@ def check_numpy_version(python_bin_path):
|
||||
sys.exit(-1)
|
||||
return version
|
||||
|
||||
def get_githash():
|
||||
try:
|
||||
return subprocess.run(
|
||||
["git", "rev-parse", "HEAD"],
|
||||
encoding='utf-8',
|
||||
capture_output=True).stdout.strip()
|
||||
except OSError:
|
||||
return ""
|
||||
|
||||
# Bazel
|
||||
|
||||
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.1.2/"
|
||||
@ -583,6 +592,7 @@ def main():
|
||||
["run", "--verbose_failures=true"] +
|
||||
["//jaxlib/tools:build_wheel", "--",
|
||||
f"--output_path={output_path}",
|
||||
f"--jaxlib_git_hash={get_githash()}",
|
||||
f"--cpu={wheel_cpu}"])
|
||||
if args.build_gpu_plugin:
|
||||
command.append("--include_gpu_plugin_extension")
|
||||
@ -596,6 +606,7 @@ def main():
|
||||
["run", "--verbose_failures=true"] +
|
||||
["//jaxlib/tools:build_cuda_kernels_wheel", "--",
|
||||
f"--output_path={output_path}",
|
||||
f"--jaxlib_git_hash={get_githash()}",
|
||||
f"--cpu={wheel_cpu}",
|
||||
f"--cuda_version={args.gpu_plugin_cuda_version}"])
|
||||
if args.editable:
|
||||
@ -608,6 +619,7 @@ def main():
|
||||
["run", "--verbose_failures=true"] +
|
||||
["//jaxlib/tools:build_gpu_plugin_wheel", "--",
|
||||
f"--output_path={output_path}",
|
||||
f"--jaxlib_git_hash={get_githash()}",
|
||||
f"--cpu={wheel_cpu}",
|
||||
f"--cuda_version={args.gpu_plugin_cuda_version}"])
|
||||
if args.editable:
|
||||
|
@ -63,10 +63,15 @@ def platform_tag(cpu: str) -> str:
|
||||
return f"{platform_name}_{cpu_name}"
|
||||
|
||||
|
||||
def build_wheel(sources_path: str, output_path: str, package_name: str) -> None:
|
||||
def build_wheel(
|
||||
sources_path: str, output_path: str, package_name: str, git_hash: str = ""
|
||||
) -> None:
|
||||
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
|
||||
env = dict(os.environ)
|
||||
if git_hash:
|
||||
env["JAX_GIT_HASH"] = git_hash
|
||||
subprocess.run([sys.executable, "-m", "build", "-n", "-w"],
|
||||
check=True, cwd=sources_path)
|
||||
check=True, cwd=sources_path, env=env)
|
||||
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
|
||||
output_file = os.path.join(output_path, os.path.basename(wheel))
|
||||
sys.stderr.write(f"Output wheel: {output_file}\n\n")
|
||||
|
@ -26,6 +26,9 @@ _version = "0.4.24"
|
||||
# releases. Do not modify this manually, or jax/jaxlib build will fail.
|
||||
_release_version: str | None = None
|
||||
|
||||
# The following line is overwritten by build scripts in distributions &
|
||||
# releases. Do not modify this manually, or jax/jaxlib build will fail.
|
||||
_git_hash: str | None = None
|
||||
|
||||
def _get_version_string() -> str:
|
||||
# The build/source distribution for jax & jaxlib overwrites _release_version.
|
||||
@ -94,6 +97,14 @@ def _write_version(fname: str) -> None:
|
||||
if contents.count(old_version_string) != 2:
|
||||
raise RuntimeError(f"Build: could not find {old_version_string!r} in {fname}")
|
||||
contents = contents.replace(old_version_string, new_version_string)
|
||||
|
||||
githash = os.environ.get("JAX_GIT_HASH")
|
||||
if githash:
|
||||
old_githash_string = "_git_hash: str | None = None"
|
||||
new_githash_string = f"_git_hash: str = {githash!r}"
|
||||
if contents.count(old_githash_string) != 2:
|
||||
raise RuntimeError(f"Build: could not find {old_githash_string!r} in {fname}")
|
||||
contents = contents.replace(old_githash_string, new_githash_string)
|
||||
fhandle.write_text(contents)
|
||||
|
||||
|
||||
|
@ -42,6 +42,12 @@ parser.add_argument(
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--jaxlib_git_hash",
|
||||
default="",
|
||||
required=True,
|
||||
help="Git hash. Empty if unknown. Optional.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu", default=None, required=True, help="Target CPU architecture. Required."
|
||||
)
|
||||
@ -344,7 +350,7 @@ try:
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
build_utils.build_wheel(sources_path, args.output_path, package_name)
|
||||
build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=args.jaxlib_git_hash)
|
||||
finally:
|
||||
if tmpdir:
|
||||
tmpdir.cleanup()
|
||||
|
Loading…
x
Reference in New Issue
Block a user