mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Update XLA version to include CUDA version check fix and newer XLA.
Make specification of CUDA environment variables optional. Fixes #627. Fixes #276, although the fix requires a new Jaxlib release.
This commit is contained in:
parent
b2160fdc03
commit
f60d927df8
@ -23,10 +23,10 @@ http_archive(
|
||||
# and update the sha256 with the result.
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
sha256 = "6248e6c47b18ab0f496edfa88cac58208189b8b26c7efd2f37ce6a4433159da9",
|
||||
strip_prefix = "tensorflow-0e6271d916cbca4e12d61c94dae0230e3a47cd39",
|
||||
sha256 = "732ccdb272a275014c28e757c30a4575dc100865a8cb30ffc4c39b7d4bc761a2",
|
||||
strip_prefix = "tensorflow-afeb3a07811b201f53e365807ac5173e4bad1c2e",
|
||||
urls = [
|
||||
"https://github.com/tensorflow/tensorflow/archive/0e6271d916cbca4e12d61c94dae0230e3a47cd39.tar.gz",
|
||||
"https://github.com/tensorflow/tensorflow/archive/afeb3a07811b201f53e365807ac5173e4bad1c2e.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -161,8 +161,6 @@ BAZELRC_TEMPLATE = """
|
||||
build --action_env PYTHON_BIN_PATH="{python_bin_path}"
|
||||
build --python_path="{python_bin_path}"
|
||||
build --action_env TF_NEED_CUDA="{tf_need_cuda}"
|
||||
build --action_env CUDA_TOOLKIT_PATH="{cuda_toolkit_path}"
|
||||
build --action_env CUDNN_INSTALL_PATH="{cudnn_install_path}"
|
||||
build --distinct_host_configuration=false
|
||||
build --copt=-Wno-sign-compare
|
||||
build -c opt
|
||||
@ -186,9 +184,16 @@ build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
|
||||
"""
|
||||
|
||||
|
||||
def write_bazelrc(**kwargs):
|
||||
|
||||
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
|
||||
f = open("../.bazelrc", "w")
|
||||
f.write(BAZELRC_TEMPLATE.format(**kwargs))
|
||||
if cuda_toolkit_path:
|
||||
f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
|
||||
.format(cuda_toolkit_path=cuda_toolkit_path))
|
||||
if cudnn_install_path:
|
||||
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
|
||||
.format(cudnn_install_path=cudnn_install_path))
|
||||
f.close()
|
||||
|
||||
|
||||
@ -265,11 +270,11 @@ def main():
|
||||
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
|
||||
parser.add_argument(
|
||||
"--cuda_path",
|
||||
default="/usr/local/cuda",
|
||||
default=None,
|
||||
help="Path to the CUDA toolkit.")
|
||||
parser.add_argument(
|
||||
"--cudnn_path",
|
||||
default="/usr/local/cuda",
|
||||
default=None,
|
||||
help="Path to CUDNN libraries.")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -291,8 +296,10 @@ def main():
|
||||
cudnn_install_path = args.cudnn_path
|
||||
print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no"))
|
||||
if args.enable_cuda:
|
||||
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
|
||||
print("CUDNN library path: {}".format(cudnn_install_path))
|
||||
if cuda_toolkit_path:
|
||||
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
|
||||
if cudnn_install_path:
|
||||
print("CUDNN library path: {}".format(cudnn_install_path))
|
||||
write_bazelrc(
|
||||
python_bin_path=python_bin_path,
|
||||
tf_need_cuda=1 if args.enable_cuda else 0,
|
||||
|
Loading…
x
Reference in New Issue
Block a user