diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 9c7f61fc2..4c50cff16 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -70,6 +70,12 @@ def build_wheel( env = dict(os.environ) if git_hash: env["JAX_GIT_HASH"] = git_hash + if is_windows() and ( + "USERPROFILE" not in env + and "HOMEDRIVE" not in env + and "HOMEPATH" not in env + ): + env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:") subprocess.run( [sys.executable, "-m", "build", "-n"] + (["-w"] if build_wheel_only else []),