Fix ambiguous cpu definition for JAX wheels.

Should fix the error in https://github.com/jax-ml/jax/actions/runs/13682579939/job/38258344926.

PiperOrigin-RevId: 733838895
This commit is contained in:
jax authors 2025-03-05 12:58:37 -08:00
parent 8df00e2666
commit 3edc068f8c
2 changed files with 27 additions and 2 deletions

View File

@ -522,9 +522,10 @@ def jax_wheel(
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
cpu = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:macos_x86_64": "x86_64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:arm64": "aarch64",
"@platforms//cpu:x86_64": "x86_64",
"//jaxlib/tools:linux_aarch64": "aarch64",
"//jaxlib/tools:linux_x86_64": "x86_64",
}),
source_files = source_files,
)

View File

@ -186,6 +186,14 @@ selects.config_setting_group(
],
)
selects.config_setting_group(
name = "macos_x86_64",
match_all = [
"@platforms//cpu:x86_64",
":macos",
],
)
selects.config_setting_group(
name = "win_amd64",
match_all = [
@ -194,6 +202,22 @@ selects.config_setting_group(
],
)
selects.config_setting_group(
name = "linux_x86_64",
match_all = [
"@platforms//cpu:x86_64",
"@platforms//os:linux",
],
)
selects.config_setting_group(
name = "linux_aarch64",
match_all = [
":arm64",
"@platforms//os:linux",
],
)
string_flag(
name = "jaxlib_git_hash",
build_setting_default = "",