Update XLA version to include CUDA version check fix and newer XLA.

Make specification of CUDA environment variables optional.

Fixes #627.
Fixes #276, although the fix requires a new Jaxlib release.
This commit is contained in:
Peter Hawkins 2019-04-25 16:19:53 -07:00
parent b2160fdc03
commit f60d927df8
2 changed files with 17 additions and 10 deletions

View File

@ -23,10 +23,10 @@ http_archive(
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "6248e6c47b18ab0f496edfa88cac58208189b8b26c7efd2f37ce6a4433159da9",
strip_prefix = "tensorflow-0e6271d916cbca4e12d61c94dae0230e3a47cd39",
sha256 = "732ccdb272a275014c28e757c30a4575dc100865a8cb30ffc4c39b7d4bc761a2",
strip_prefix = "tensorflow-afeb3a07811b201f53e365807ac5173e4bad1c2e",
urls = [
"https://github.com/tensorflow/tensorflow/archive/0e6271d916cbca4e12d61c94dae0230e3a47cd39.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/afeb3a07811b201f53e365807ac5173e4bad1c2e.tar.gz",
],
)

View File

@ -161,8 +161,6 @@ BAZELRC_TEMPLATE = """
build --action_env PYTHON_BIN_PATH="{python_bin_path}"
build --python_path="{python_bin_path}"
build --action_env TF_NEED_CUDA="{tf_need_cuda}"
build --action_env CUDA_TOOLKIT_PATH="{cuda_toolkit_path}"
build --action_env CUDNN_INSTALL_PATH="{cudnn_install_path}"
build --distinct_host_configuration=false
build --copt=-Wno-sign-compare
build -c opt
@ -186,9 +184,16 @@ build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
"""
def write_bazelrc(**kwargs):
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
f = open("../.bazelrc", "w")
f.write(BAZELRC_TEMPLATE.format(**kwargs))
if cuda_toolkit_path:
f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
.format(cuda_toolkit_path=cuda_toolkit_path))
if cudnn_install_path:
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
.format(cudnn_install_path=cudnn_install_path))
f.close()
@ -265,11 +270,11 @@ def main():
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
parser.add_argument(
"--cuda_path",
default="/usr/local/cuda",
default=None,
help="Path to the CUDA toolkit.")
parser.add_argument(
"--cudnn_path",
default="/usr/local/cuda",
default=None,
help="Path to CUDNN libraries.")
args = parser.parse_args()
@ -291,8 +296,10 @@ def main():
cudnn_install_path = args.cudnn_path
print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no"))
if args.enable_cuda:
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
print("CUDNN library path: {}".format(cudnn_install_path))
if cuda_toolkit_path:
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
if cudnn_install_path:
print("CUDNN library path: {}".format(cudnn_install_path))
write_bazelrc(
python_bin_path=python_bin_path,
tf_need_cuda=1 if args.enable_cuda else 0,