Disable avxvnniint8 when building with Clang version < 19, or GCC < 13.

PiperOrigin-RevId: 712516025
This commit is contained in:
Vladimir Belitskiy 2025-01-06 07:05:20 -08:00 committed by jax authors
parent 512d5450ae
commit f2e210b315
2 changed files with 19 additions and 0 deletions

View File

@ -468,6 +468,9 @@ async def main():
# Enable clang settings that are needed for the build to work with newer
# versions of Clang.
wheel_build_command_base.append("--config=clang")
if clang_major_version < 19:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
else:
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
logging.debug(
@ -477,6 +480,10 @@ async def main():
wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")
gcc_major_version = utils.get_gcc_major_version(gcc_path)
if gcc_major_version < 13:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
if target_cpu == "aarch64":

View File

@ -201,6 +201,18 @@ def get_clang_major_version(clang_path):
return major_version
def get_gcc_major_version(gcc_path: str):
gcc_version_proc = subprocess.run(
[gcc_path, "-dumpversion"],
check=True,
capture_output=True,
text=True,
)
major_version = int(gcc_version_proc.stdout)
return major_version
def get_jax_configure_bazel_options(bazel_command: list[str]):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so