Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies.

PiperOrigin-RevId: 595529249
This commit is contained in:
Parker Schuh 2024-01-03 16:11:44 -08:00 committed by jax authors
parent f1be301049
commit 23b9c2a22f
4 changed files with 37 additions and 3 deletions

View File

@ -93,6 +93,15 @@ def check_numpy_version(python_bin_path):
sys.exit(-1)
return version
def get_githash():
try:
return subprocess.run(
["git", "rev-parse", "HEAD"],
encoding='utf-8',
capture_output=True).stdout.strip()
except OSError:
return ""
# Bazel
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.1.2/"
@ -583,6 +592,7 @@ def main():
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_wheel", "--",
f"--output_path={output_path}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}"])
if args.build_gpu_plugin:
command.append("--include_gpu_plugin_extension")
@ -596,6 +606,7 @@ def main():
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_cuda_kernels_wheel", "--",
f"--output_path={output_path}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"])
if args.editable:
@ -608,6 +619,7 @@ def main():
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_gpu_plugin_wheel", "--",
f"--output_path={output_path}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"])
if args.editable:

View File

@ -63,10 +63,15 @@ def platform_tag(cpu: str) -> str:
return f"{platform_name}_{cpu_name}"
def build_wheel(sources_path: str, output_path: str, package_name: str) -> None:
def build_wheel(
sources_path: str, output_path: str, package_name: str, git_hash: str = ""
) -> None:
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
env = dict(os.environ)
if git_hash:
env["JAX_GIT_HASH"] = git_hash
subprocess.run([sys.executable, "-m", "build", "-n", "-w"],
check=True, cwd=sources_path)
check=True, cwd=sources_path, env=env)
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
output_file = os.path.join(output_path, os.path.basename(wheel))
sys.stderr.write(f"Output wheel: {output_file}\n\n")

View File

@ -26,6 +26,9 @@ _version = "0.4.24"
# releases. Do not modify this manually, or jax/jaxlib build will fail.
_release_version: str | None = None
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
_git_hash: str | None = None
def _get_version_string() -> str:
# The build/source distribution for jax & jaxlib overwrites _release_version.
@ -94,6 +97,14 @@ def _write_version(fname: str) -> None:
if contents.count(old_version_string) != 2:
raise RuntimeError(f"Build: could not find {old_version_string!r} in {fname}")
contents = contents.replace(old_version_string, new_version_string)
githash = os.environ.get("JAX_GIT_HASH")
if githash:
old_githash_string = "_git_hash: str | None = None"
new_githash_string = f"_git_hash: str = {githash!r}"
if contents.count(old_githash_string) != 2:
raise RuntimeError(f"Build: could not find {old_githash_string!r} in {fname}")
contents = contents.replace(old_githash_string, new_githash_string)
fhandle.write_text(contents)

View File

@ -42,6 +42,12 @@ parser.add_argument(
required=True,
help="Path to which the output wheel should be written. Required.",
)
parser.add_argument(
"--jaxlib_git_hash",
default="",
required=True,
help="Git hash. Empty if unknown. Optional.",
)
parser.add_argument(
"--cpu", default=None, required=True, help="Target CPU architecture. Required."
)
@ -344,7 +350,7 @@ try:
if args.editable:
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
build_utils.build_wheel(sources_path, args.output_path, package_name)
build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=args.jaxlib_git_hash)
finally:
if tmpdir:
tmpdir.cleanup()