mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

This lets us avoid bundling a whole other copy of LLVM with JAX packages, so we can finally start building Mosaic GPU by default. PiperOrigin-RevId: 638569750
344 lines
17 KiB
Plaintext
344 lines
17 KiB
Plaintext
############################################################################
|
|
# All default build options below.
|
|
|
|
# Required by OpenXLA
|
|
# https://github.com/openxla/xla/issues/1323
|
|
build --nocheck_visibility
|
|
|
|
# Sets the default Apple platform to macOS.
|
|
build --apple_platform_type=macos
|
|
build --macos_minimum_os=10.14
|
|
|
|
# Make Bazel print out all options from rc files.
|
|
build --announce_rc
|
|
|
|
build --define open_source_build=true
|
|
|
|
build --spawn_strategy=standalone
|
|
|
|
build --enable_platform_specific_config
|
|
|
|
build --experimental_cc_shared_library
|
|
|
|
# Disable enabled-by-default TensorFlow features that we don't care about.
|
|
build --define=no_aws_support=true
|
|
build --define=no_gcp_support=true
|
|
build --define=no_hdfs_support=true
|
|
build --define=no_kafka_support=true
|
|
build --define=no_ignite_support=true
|
|
|
|
build --define=grpc_no_ares=true
|
|
|
|
build --define=tsl_link_protobuf=true
|
|
|
|
build -c opt
|
|
|
|
build --config=short_logs
|
|
|
|
build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
|
|
|
|
# Later Bazel flag values override earlier values; if CUDA/ROCm/TPU support is
|
|
# enabled, that backend's config overrides this value.
|
|
build --@xla//xla/python:enable_gpu=false
|
|
|
|
###########################################################################
|
|
|
|
build:posix --copt=-fvisibility=hidden
|
|
build:posix --copt=-Wno-sign-compare
|
|
build:posix --cxxopt=-std=c++17
|
|
build:posix --host_cxxopt=-std=c++17
|
|
|
|
build:avx_posix --copt=-mavx
|
|
build:avx_posix --host_copt=-mavx
|
|
|
|
build:avx_windows --copt=/arch=AVX
|
|
|
|
build:avx_linux --copt=-mavx
|
|
build:avx_linux --host_copt=-mavx
|
|
|
|
build:native_arch_posix --copt=-march=native
|
|
build:native_arch_posix --host_copt=-march=native
|
|
|
|
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
|
|
|
|
build:cuda --repo_env TF_NEED_CUDA=1
|
|
build:cuda --repo_env TF_NCCL_USE_STUB=1
|
|
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
|
|
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
|
|
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
|
|
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
|
|
build:cuda --@local_config_cuda//:enable_cuda
|
|
build:cuda --@xla//xla/python:enable_gpu=true
|
|
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
|
|
build:cuda --define=xla_python_enable_gpu=true
|
|
|
|
# Build with nvcc for CUDA and clang for host
|
|
build:nvcc_clang --config=cuda
|
|
# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang
|
|
build:nvcc_clang --action_env=TF_CUDA_CLANG="1"
|
|
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
|
|
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
|
|
|
|
# Requires MSVC and LLVM to be installed
|
|
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
|
|
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
|
|
build:win_clang --compiler=clang-cl
|
|
|
|
# Later Bazel flag values override earlier values.
|
|
# TODO(jieying): remove enable_gpu and xla_python_enable_gpu from build:cuda
|
|
# after the plugin is released.
|
|
build:cuda_plugin --@xla//xla/python:enable_gpu=false
|
|
build:cuda_plugin --define=xla_python_enable_gpu=false
|
|
|
|
# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
|
|
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
|
|
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
|
|
# packages.
|
|
# This has pros and cons:
|
|
# * pro: we'll ignore other CUDA installations, which has frequently confused
|
|
# users in the past. By setting RPATH, we'll always use the NVIDIA pip
|
|
# packages if they are installed.
|
|
# * con: the user cannot override the CUDA installation location
|
|
# via LD_LIBRARY_PATH, if the nvidia-... pip packages are installed. This is
|
|
# acceptable, because the workaround is "remove the nvidia-..." pip packages.
|
|
# The CUDA pip packages that JAX depends on are listed in setup.py.
|
|
build:cuda --linkopt=-Wl,--disable-new-dtags
|
|
|
|
build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
|
|
build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-17/bin/clang"
|
|
build:cuda_clang --action_env=TF_CUDA_CLANG="1"
|
|
# Disable the clang diagnostic that rejects type definitions within offsetof.
|
|
# This was added in clang-16 by https://reviews.llvm.org/D133574.
|
|
# Can be removed once upb is updated, since the current version of upb uses a
|
|
# type definition within offsetof.
|
|
# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
|
|
build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
|
|
# Suppress clang warnings about driver arguments that are unused during compilation.
|
|
build:cuda_clang --copt=-Qunused-arguments
|
|
|
|
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
|
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
|
|
build:rocm --@xla//xla/python:enable_gpu=true
|
|
build:rocm --define=xla_python_enable_gpu=true
|
|
build:rocm --repo_env TF_NEED_ROCM=1
|
|
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
|
|
|
|
build:nonccl --define=no_nccl_support=true
|
|
|
|
# Windows has a relatively short command line limit, which JAX has begun to hit.
|
|
# See https://docs.bazel.build/versions/main/windows.html
|
|
build:windows --features=compiler_param_file
|
|
build:windows --features=archive_param_file
|
|
|
|
# TensorFlow uses M_* math constants that only get defined by MSVC headers if
|
|
# _USE_MATH_DEFINES is defined.
|
|
build:windows --copt=/D_USE_MATH_DEFINES
|
|
build:windows --host_copt=/D_USE_MATH_DEFINES
|
|
# Make sure to include as little of windows.h as possible
|
|
build:windows --copt=-DWIN32_LEAN_AND_MEAN
|
|
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
|
|
build:windows --copt=-DNOGDI
|
|
build:windows --host_copt=-DNOGDI
|
|
# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/
|
|
# Enable MSVC's standard-conforming preprocessor; otherwise, compilation fails with preprocessing errors.
|
|
build:windows --copt=/Zc:preprocessor
|
|
build:windows --cxxopt=/std:c++17
|
|
build:windows --host_cxxopt=/std:c++17
|
|
# Generate PDB files. To get useful PDBs in opt compilation_mode,
|
|
# --copt /Z7 is also needed.
|
|
build:windows --linkopt=/DEBUG
|
|
build:windows --host_linkopt=/DEBUG
|
|
build:windows --linkopt=/OPT:REF
|
|
build:windows --host_linkopt=/OPT:REF
|
|
build:windows --linkopt=/OPT:ICF
|
|
build:windows --host_linkopt=/OPT:ICF
|
|
build:windows --incompatible_strict_action_env=true
|
|
|
|
build:linux --config=posix
|
|
build:linux --copt=-Wno-unknown-warning-option
|
|
# Workaround for gcc 10+ warnings related to upb.
|
|
# See https://github.com/tensorflow/tensorflow/issues/39467
|
|
build:linux --copt=-Wno-stringop-truncation
|
|
build:linux --copt=-Wno-array-parameter
|
|
|
|
build:macos --config=posix
|
|
|
|
# Public cache for macOS builds. The "oct2023" in the URL is just the
|
|
# date when the bucket was created and can be disregarded. It still contains the
|
|
# latest cache that is being used.
|
|
build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false
|
|
# Cache pushes are limited to JAX's CI system.
|
|
build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials
|
|
|
|
# Suppress all warning messages.
|
|
build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
|
|
|
#########################################################################
|
|
# RBE config options below.
|
|
# Flag to enable remote config
|
|
common --experimental_repo_remote_exec
|
|
|
|
build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
|
|
build:rbe --google_default_credentials
|
|
build:rbe --bes_backend=buildeventservice.googleapis.com
|
|
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
|
|
build:rbe --bes_timeout=600s
|
|
build:rbe --define=EXECUTOR=remote
|
|
build:rbe --flaky_test_attempts=3
|
|
build:rbe --jobs=200
|
|
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
|
|
build:rbe --remote_timeout=3600
|
|
build:rbe --spawn_strategy=remote,worker,standalone,local
|
|
test:rbe --test_env=USER=anon
|
|
# Attempt to minimize the amount of data transfer between bazel and the remote
|
|
# workers:
|
|
build:rbe --remote_download_toplevel
|
|
|
|
build:rbe_linux --config=rbe
|
|
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
|
|
build:rbe_linux --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8
|
|
build:rbe_linux --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8
|
|
build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
|
|
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
|
|
|
|
# Non-rbe settings we should include because we do not run configure
|
|
build:rbe_linux --config=avx_linux
|
|
build:rbe_linux --linkopt=-lrt
|
|
build:rbe_linux --host_linkopt=-lrt
|
|
build:rbe_linux --linkopt=-lm
|
|
build:rbe_linux --host_linkopt=-lm
|
|
|
|
# Use the GPU toolchain until the CPU one is ready.
|
|
# https://github.com/bazelbuild/bazel/issues/13623
|
|
build:rbe_cpu_linux_base --config=rbe_linux
|
|
build:rbe_cpu_linux_base --config=cuda_clang
|
|
build:rbe_cpu_linux_base --action_env=TF_NVCC_CLANG="1"
|
|
build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
|
|
build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
|
|
build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64"
|
|
build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
|
|
build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
|
|
build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
|
|
|
|
build:rbe_cpu_linux_py3.9 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9"
|
|
build:rbe_cpu_linux_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9"
|
|
build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
|
|
build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
|
|
build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
|
|
build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
|
|
build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12"
|
|
build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"
|
|
|
|
build:rbe_linux_cuda_base --config=rbe_linux
|
|
build:rbe_linux_cuda_base --config=cuda
|
|
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
|
|
|
|
build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
|
|
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang
|
|
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1"
|
|
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDA_VERSION=12
|
|
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=9
|
|
build:rbe_linux_cuda12.3_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12"
|
|
build:rbe_linux_cuda12.3_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
|
|
build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
|
|
build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain"
|
|
build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64"
|
|
build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
|
|
build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
|
|
build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
|
|
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda"
|
|
build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl"
|
|
# RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility
|
|
build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat
|
|
build:rbe_linux_cuda12.3_nvcc_py3.9 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.9"
|
|
build:rbe_linux_cuda12.3_nvcc_py3.9 --repo_env HERMETIC_PYTHON_VERSION="3.9"
|
|
build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10"
|
|
build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10"
|
|
build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11"
|
|
build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
|
|
build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12"
|
|
build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"
|
|
|
|
# You may need to change these for your own GCP project.
|
|
build:tensorflow_testing_rbe --project_id=tensorflow-testing
|
|
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
|
|
build:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe
|
|
|
|
# START CROSS-COMPILE CONFIGS
|
|
|
|
# Set execution platform to Linux x86
|
|
# Note: Many of the "host_" flags, such as "host_cpu" and "host_crosstool_top",
|
|
# actually seem to be used to specify the execution platform details. It
|
|
# seems it is this way because these flags are old and predate the distinction
|
|
# between host and execution platform.
|
|
build:cross_compile_base --host_cpu=k8
|
|
build:cross_compile_base --host_crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite
|
|
build:cross_compile_base --extra_execution_platforms=@xla//tools/toolchains/cross_compile/config:linux_x86_64
|
|
|
|
# START LINUX AARCH64 CROSS-COMPILE CONFIGS
|
|
build:cross_compile_linux_arm64 --config=cross_compile_base
|
|
|
|
# Set the target CPU to Aarch64
|
|
build:cross_compile_linux_arm64 --platforms=@xla//tools/toolchains/cross_compile/config:linux_aarch64
|
|
build:cross_compile_linux_arm64 --cpu=aarch64
|
|
build:cross_compile_linux_arm64 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite
|
|
|
|
build:rbe_cross_compile_base --config=rbe
|
|
|
|
# RBE cross-compile configs for Linux Aarch64
|
|
build:rbe_cross_compile_linux_arm64 --config=cross_compile_linux_arm64
|
|
build:rbe_cross_compile_linux_arm64 --config=rbe_cross_compile_base
|
|
# END LINUX AARCH64 CROSS-COMPILE CONFIGS
|
|
|
|
# START MACOS CROSS-COMPILE CONFIGS
|
|
build:cross_compile_macos_x86 --config=cross_compile_base
|
|
build:cross_compile_macos_x86 --config=nonccl
|
|
# Target Catalina (10.15) as the minimum supported OS
|
|
build:cross_compile_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15
|
|
|
|
# Set the target CPU to Darwin x86
|
|
build:cross_compile_macos_x86 --platforms=@xla//tools/toolchains/cross_compile/config:darwin_x86_64
|
|
build:cross_compile_macos_x86 --cpu=darwin
|
|
build:cross_compile_macos_x86 --crosstool_top=@xla//tools/toolchains/cross_compile/cc:cross_compile_toolchain_suite
|
|
# When RBE cross-compiling for macOS, we need to explicitly register the
|
|
# toolchain. Otherwise, oddly, RBE complains that a "docker container must be
|
|
# specified".
|
|
build:cross_compile_macos_x86 --extra_toolchains=@xla//tools/toolchains/cross_compile/config:macos-x86-cross-compile-cc-toolchain
|
|
# Map --platforms=darwin_x86_64 to --cpu=darwin and vice-versa to make selects()
|
|
# and transitions that use these flags work. The flag --platform_mappings needs
|
|
# to be set to a file that exists relative to the package path roots.
|
|
build:cross_compile_macos_x86 --platform_mappings=platform_mappings
|
|
|
|
# RBE cross-compile configs for Darwin x86
|
|
build:rbe_cross_compile_macos_x86 --config=cross_compile_macos_x86
|
|
build:rbe_cross_compile_macos_x86 --config=rbe_cross_compile_base
|
|
# END MACOS CROSS-COMPILE CONFIGS
|
|
|
|
# END CROSS-COMPILE CONFIGS
|
|
|
|
#############################################################################
|
|
|
|
#############################################################################
|
|
# Some configs for producing various kinds of debug builds. In general, the
|
|
# codebase is only regularly built with optimizations. Use 'debug_symbols' to
|
|
# just get symbols for the parts of XLA/PJRT that jaxlib uses.
|
|
# Or try 'debug' to get a build with assertions enabled and minimal
|
|
# optimizations.
|
|
# Include these in a local .bazelrc.user file as:
|
|
# build --config=debug_symbols
|
|
# Or:
|
|
# build --config=debug
|
|
#
|
|
# Additional files can be opted in for debug symbols by adding patterns
|
|
# to a per_file_copt similar to below.
|
|
#############################################################################
|
|
|
|
build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3"
|
|
build:debug --config debug_symbols -c fastbuild
|
|
|
|
# Load `.jax_configure.bazelrc` file written by build.py
|
|
try-import %workspace%/.jax_configure.bazelrc
|
|
|
|
# Load rc file with user-specific options.
|
|
try-import %workspace%/.bazelrc.user
|