Merge pull request #15205 from yhtang:editable-jaxlib-build

PiperOrigin-RevId: 519704474
This commit is contained in:
jax authors 2023-03-27 06:33:31 -07:00
commit 6715736583
2 changed files with 23 additions and 1 deletions

View File

@ -439,6 +439,10 @@ def main():
default=None,
help="CPU platform to target. Default is the same as the host machine. "
"Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.")
parser.add_argument(
"--editable",
action="store_true",
help="Create an 'editable' jaxlib build instead of a wheel.")
add_boolean_argument(
parser,
"configure_only",
@ -549,6 +553,8 @@ def main():
[":build_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}"])
if args.editable:
command += ["--editable"]
print(" ".join(command))
shell(command)
shell([bazel_path] + args.bazel_startup_options + ["shutdown"])

View File

@ -48,6 +48,10 @@ parser.add_argument(
default=None,
required=True,
help="Target CPU architecture. Required.")
parser.add_argument(
"--editable",
action="store_true",
help="Create an 'editable' jaxlib build instead of a wheel.")
args = parser.parse_args()
r = runfiles.Create()
@ -282,6 +286,15 @@ def build_wheel(sources_path, output_path, cpu):
shutil.copy(wheel, output_path)
def build_editable(sources_path, output_path):
sys.stderr.write(
"To install the editable jaxlib build, run:\n\n"
f" pip install -e {output_path}\n\n"
)
shutil.rmtree(output_path, ignore_errors=True)
shutil.copytree(sources_path, output_path)
tmpdir = None
sources_path = args.sources_path
if sources_path is None:
@ -291,7 +304,10 @@ if sources_path is None:
try:
os.makedirs(args.output_path, exist_ok=True)
prepare_wheel(sources_path)
build_wheel(sources_path, args.output_path, args.cpu)
if args.editable:
build_editable(sources_path, args.output_path)
else:
build_wheel(sources_path, args.output_path, args.cpu)
finally:
if tmpdir:
tmpdir.cleanup()