mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add flags to configure the cuda_compute_capability and rocm_amd_targets
This commit is contained in:
parent
f47926a23d
commit
d0acd9f343
@ -214,7 +214,8 @@ def check_bazel_version(bazel_path):
|
||||
def write_bazelrc(python_bin_path=None, remote_build=None,
|
||||
cuda_toolkit_path=None, cudnn_install_path=None,
|
||||
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None,
|
||||
cpu=None):
|
||||
cpu=None, cuda_compute_capabilities=None,
|
||||
rocm_amdgpu_targets=None):
|
||||
tf_cuda_paths = []
|
||||
|
||||
with open("../.jax_configure.bazelrc", "w") as f:
|
||||
@ -245,9 +246,15 @@ def write_bazelrc(python_bin_path=None, remote_build=None,
|
||||
if cudnn_version:
|
||||
f.write("build --action_env TF_CUDNN_VERSION=\"{cudnn_version}\"\n"
|
||||
.format(cudnn_version=cudnn_version))
|
||||
if cuda_compute_capabilities:
|
||||
f.write(
|
||||
f'build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"')
|
||||
if rocm_toolkit_path:
|
||||
f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n"
|
||||
.format(rocm_toolkit_path=rocm_toolkit_path))
|
||||
if rocm_amdgpu_targets:
|
||||
f.write(
|
||||
f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"')
|
||||
if cpu is not None:
|
||||
f.write("build --distinct_host_configuration=true\n")
|
||||
f.write(f"build --cpu={cpu}\n")
|
||||
@ -366,6 +373,14 @@ def main():
|
||||
"--cudnn_version",
|
||||
default=None,
|
||||
help="CUDNN version, e.g., 8")
|
||||
parser.add_argument(
|
||||
"--cuda_compute_capabilities",
|
||||
default="3.5,5.2,6.0,6.1,7.0",
|
||||
help="A comma-separated list of CUDA compute capabilities to support.")
|
||||
parser.add_argument(
|
||||
"--rocm_amdgpu_targets",
|
||||
default="gfx900,gfx906,gfx90",
|
||||
help="A comma-separated list of ROCm amdgpu targets to support.")
|
||||
parser.add_argument(
|
||||
"--rocm_path",
|
||||
default=None,
|
||||
@ -439,6 +454,7 @@ def main():
|
||||
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
|
||||
if cudnn_install_path:
|
||||
print("CUDNN library path: {}".format(cudnn_install_path))
|
||||
print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities))
|
||||
if args.cuda_version:
|
||||
print("CUDA version: {}".format(args.cuda_version))
|
||||
if args.cudnn_version:
|
||||
@ -451,6 +467,7 @@ def main():
|
||||
if args.enable_rocm:
|
||||
if rocm_toolkit_path:
|
||||
print("ROCm toolkit path: {}".format(rocm_toolkit_path))
|
||||
print("ROCm amdgpu targets: {}".format(args.rocm_amdgpu_targets))
|
||||
|
||||
write_bazelrc(
|
||||
python_bin_path=python_bin_path,
|
||||
@ -461,6 +478,8 @@ def main():
|
||||
cudnn_version=args.cudnn_version,
|
||||
rocm_toolkit_path=rocm_toolkit_path,
|
||||
cpu=args.target_cpu,
|
||||
cuda_compute_capabilities=args.cuda_compute_capabilities,
|
||||
rocm_amdgpu_targets=args.rocm_amdgpu_targets,
|
||||
)
|
||||
|
||||
print("\nBuilding XLA and installing it in the jaxlib source tree...")
|
||||
|
Loading…
x
Reference in New Issue
Block a user