fix rocm_amdgpu_targets for rocm

This commit is contained in:
Reza Rahimi 2021-08-17 20:42:02 +00:00
parent 476642578b
commit f454f6b7b8
2 changed files with 2 additions and 12 deletions

View File

@ -46,7 +46,7 @@ build:native_arch_posix --host_copt=-march=native
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES="3.5,5.2,6.0,6.1,7.0"
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="3.5,5.2,6.0,6.1,7.0"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda
build:cuda --define=xla_python_enable_gpu=true
@ -55,7 +55,7 @@ build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --define=xla_python_enable_gpu=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env=TF_ROCM_AMDGPU_TARGETS="gfx803,gfx900,gfx906,gfx1010"
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908"
build:nonccl --define=no_nccl_support=true

View File

@ -374,18 +374,10 @@ def main():
"--cudnn_version",
default=None,
help="CUDNN version, e.g., 8")
parser.add_argument(
"--cuda_compute_capabilities",
default="3.5,5.2,6.0,6.1,7.0",
help="A comma-separated list of CUDA compute capabilities to support.")
parser.add_argument(
"--rocm_path",
default=None,
help="Path to the ROCm toolkit.")
parser.add_argument(
"--rocm_amdgpu_targets",
default="gfx803,gfx900,gfx906,gfx1010",
help="A comma-separated list of ROCm amdgpu targets to support.")
parser.add_argument(
"--bazel_startup_options",
action="append", default=[],
@ -457,7 +449,6 @@ def main():
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
if cudnn_install_path:
print("CUDNN library path: {}".format(cudnn_install_path))
print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities))
if args.cuda_version:
print("CUDA version: {}".format(args.cuda_version))
if args.cudnn_version:
@ -470,7 +461,6 @@ def main():
if args.enable_rocm:
if rocm_toolkit_path:
print("ROCm toolkit path: {}".format(rocm_toolkit_path))
print("ROCm amdgpu targets: {}".format(args.rocm_amdgpu_targets))
write_bazelrc(
python_bin_path=python_bin_path,