mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Use TF_CUDA_PATHS
CUDA_TOOLKIT_PATH and CUDNN_INSTALL_PATH are deprecated, see TF 2.0 release notes for more information
This commit is contained in:
parent
4fa79ce1cb
commit
2d321c26e6
@ -317,12 +317,18 @@ def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None,
|
||||
cpu=None, **kwargs):
|
||||
with open("../.bazelrc", "w") as f:
|
||||
f.write(BAZELRC_TEMPLATE.format(**kwargs))
|
||||
tf_cuda_paths = []
|
||||
if cuda_toolkit_path:
|
||||
tf_cuda_paths.append(cuda_toolkit_path)
|
||||
f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
|
||||
.format(cuda_toolkit_path=cuda_toolkit_path))
|
||||
if cudnn_install_path:
|
||||
tf_cuda_paths.append(cudnn_install_path)
|
||||
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
|
||||
.format(cudnn_install_path=cudnn_install_path))
|
||||
if len(tf_cuda_paths):
|
||||
f.write("build --action_env TF_CUDA_PATHS=\"{tf_cuda_paths}\"\n"
|
||||
.format(tf_cuda_paths=",".join(tf_cuda_paths)))
|
||||
if cuda_version:
|
||||
f.write("build --action_env TF_CUDA_VERSION=\"{cuda_version}\"\n"
|
||||
.format(cuda_version=cuda_version))
|
||||
|
Loading…
x
Reference in New Issue
Block a user