mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Disable avxvnniint8
when building with Clang version < 19, or GCC < 13.
PiperOrigin-RevId: 712516025
This commit is contained in:
parent
512d5450ae
commit
f2e210b315
@ -468,6 +468,9 @@ async def main():
|
||||
# Enable clang settings that are needed for the build to work with newer
|
||||
# versions of Clang.
|
||||
wheel_build_command_base.append("--config=clang")
|
||||
if clang_major_version < 19:
|
||||
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
|
||||
|
||||
else:
|
||||
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
|
||||
logging.debug(
|
||||
@ -477,6 +480,10 @@ async def main():
|
||||
wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
|
||||
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")
|
||||
|
||||
gcc_major_version = utils.get_gcc_major_version(gcc_path)
|
||||
if gcc_major_version < 13:
|
||||
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
|
||||
|
||||
if not args.disable_mkl_dnn:
|
||||
logging.debug("Enabling MKL DNN")
|
||||
if target_cpu == "aarch64":
|
||||
|
@ -201,6 +201,18 @@ def get_clang_major_version(clang_path):
|
||||
|
||||
return major_version
|
||||
|
||||
def get_gcc_major_version(gcc_path: str):
|
||||
gcc_version_proc = subprocess.run(
|
||||
[gcc_path, "-dumpversion"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
major_version = int(gcc_version_proc.stdout)
|
||||
|
||||
return major_version
|
||||
|
||||
|
||||
def get_jax_configure_bazel_options(bazel_command: list[str]):
|
||||
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
|
||||
# Get the index of the "run" parameter. Build options will come after "run" so
|
||||
|
Loading…
x
Reference in New Issue
Block a user