From d0acd9f3435a7eff924c9be937dc2abb1716e79a Mon Sep 17 00:00:00 2001 From: yashkatariya Date: Fri, 17 Sep 2021 08:43:25 -0700 Subject: [PATCH] Add flags to configure the cuda_compute_capability and rocm_amd_targets --- build/build.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/build/build.py b/build/build.py index 48f4532f7..f1b544d44 100755 --- a/build/build.py +++ b/build/build.py @@ -214,7 +214,8 @@ def check_bazel_version(bazel_path): def write_bazelrc(python_bin_path=None, remote_build=None, cuda_toolkit_path=None, cudnn_install_path=None, cuda_version=None, cudnn_version=None, rocm_toolkit_path=None, - cpu=None): + cpu=None, cuda_compute_capabilities=None, + rocm_amdgpu_targets=None): tf_cuda_paths = [] with open("../.jax_configure.bazelrc", "w") as f: @@ -245,9 +246,15 @@ def write_bazelrc(python_bin_path=None, remote_build=None, if cudnn_version: f.write("build --action_env TF_CUDNN_VERSION=\"{cudnn_version}\"\n" .format(cudnn_version=cudnn_version)) + if cuda_compute_capabilities: + f.write( + f'build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"') if rocm_toolkit_path: f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n" .format(rocm_toolkit_path=rocm_toolkit_path)) + if rocm_amdgpu_targets: + f.write( + f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"') if cpu is not None: f.write("build --distinct_host_configuration=true\n") f.write(f"build --cpu={cpu}\n") @@ -366,6 +373,14 @@ def main(): "--cudnn_version", default=None, help="CUDNN version, e.g., 8") + parser.add_argument( + "--cuda_compute_capabilities", + default="3.5,5.2,6.0,6.1,7.0", + help="A comma-separated list of CUDA compute capabilities to support.") + parser.add_argument( + "--rocm_amdgpu_targets", + default="gfx900,gfx906,gfx90", + help="A comma-separated list of ROCm amdgpu targets to support.") parser.add_argument( "--rocm_path", default=None, @@ -439,6 +454,7 @@ def main(): print("CUDA toolkit path: {}".format(cuda_toolkit_path)) if cudnn_install_path: print("CUDNN library path: {}".format(cudnn_install_path)) + print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities)) if args.cuda_version: print("CUDA version: {}".format(args.cuda_version)) if args.cudnn_version: @@ -451,6 +467,7 @@ def main(): if args.enable_rocm: if rocm_toolkit_path: print("ROCm toolkit path: {}".format(rocm_toolkit_path)) + print("ROCm amdgpu targets: {}".format(args.rocm_amdgpu_targets)) write_bazelrc( python_bin_path=python_bin_path, @@ -461,6 +478,8 @@ def main(): cudnn_version=args.cudnn_version, rocm_toolkit_path=rocm_toolkit_path, cpu=args.target_cpu, + cuda_compute_capabilities=args.cuda_compute_capabilities, + rocm_amdgpu_targets=args.rocm_amdgpu_targets, ) print("\nBuilding XLA and installing it in the jaxlib source tree...")