Add flags to configure the cuda_compute_capability and rocm_amd_targets

This commit is contained in:
yashkatariya 2021-09-17 08:43:25 -07:00
parent f47926a23d
commit d0acd9f343

View File

@ -214,7 +214,8 @@ def check_bazel_version(bazel_path):
def write_bazelrc(python_bin_path=None, remote_build=None,
cuda_toolkit_path=None, cudnn_install_path=None,
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None,
cpu=None):
cpu=None, cuda_compute_capabilities=None,
rocm_amdgpu_targets=None):
tf_cuda_paths = []
with open("../.jax_configure.bazelrc", "w") as f:
@ -245,9 +246,15 @@ def write_bazelrc(python_bin_path=None, remote_build=None,
if cudnn_version:
f.write("build --action_env TF_CUDNN_VERSION=\"{cudnn_version}\"\n"
.format(cudnn_version=cudnn_version))
if cuda_compute_capabilities:
f.write(
f'build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"')
if rocm_toolkit_path:
f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n"
.format(rocm_toolkit_path=rocm_toolkit_path))
if rocm_amdgpu_targets:
f.write(
f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"')
if cpu is not None:
f.write("build --distinct_host_configuration=true\n")
f.write(f"build --cpu={cpu}\n")
@ -366,6 +373,14 @@ def main():
"--cudnn_version",
default=None,
help="CUDNN version, e.g., 8")
parser.add_argument(
"--cuda_compute_capabilities",
default="3.5,5.2,6.0,6.1,7.0",
help="A comma-separated list of CUDA compute capabilities to support.")
parser.add_argument(
"--rocm_amdgpu_targets",
default="gfx900,gfx906,gfx90",
help="A comma-separated list of ROCm amdgpu targets to support.")
parser.add_argument(
"--rocm_path",
default=None,
@ -439,6 +454,7 @@ def main():
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
if cudnn_install_path:
print("CUDNN library path: {}".format(cudnn_install_path))
print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities))
if args.cuda_version:
print("CUDA version: {}".format(args.cuda_version))
if args.cudnn_version:
@ -451,6 +467,7 @@ def main():
if args.enable_rocm:
if rocm_toolkit_path:
print("ROCm toolkit path: {}".format(rocm_toolkit_path))
print("ROCm amdgpu targets: {}".format(args.rocm_amdgpu_targets))
write_bazelrc(
python_bin_path=python_bin_path,
@ -461,6 +478,8 @@ def main():
cudnn_version=args.cudnn_version,
rocm_toolkit_path=rocm_toolkit_path,
cpu=args.target_cpu,
cuda_compute_capabilities=args.cuda_compute_capabilities,
rocm_amdgpu_targets=args.rocm_amdgpu_targets,
)
print("\nBuilding XLA and installing it in the jaxlib source tree...")