diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 633cd07ab..a5f02937c 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -522,9 +522,10 @@ def jax_wheel( # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. cpu = select({ "//jaxlib/tools:macos_arm64": "arm64", + "//jaxlib/tools:macos_x86_64": "x86_64", "//jaxlib/tools:win_amd64": "AMD64", - "//jaxlib/tools:arm64": "aarch64", - "@platforms//cpu:x86_64": "x86_64", + "//jaxlib/tools:linux_aarch64": "aarch64", + "//jaxlib/tools:linux_x86_64": "x86_64", }), source_files = source_files, ) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 6eab64823..baf996d50 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -186,6 +186,14 @@ selects.config_setting_group( ], ) +selects.config_setting_group( + name = "macos_x86_64", + match_all = [ + "@platforms//cpu:x86_64", + ":macos", + ], +) + selects.config_setting_group( name = "win_amd64", match_all = [ @@ -194,6 +202,22 @@ selects.config_setting_group( ], ) +selects.config_setting_group( + name = "linux_x86_64", + match_all = [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], +) + +selects.config_setting_group( + name = "linux_aarch64", + match_all = [ + ":arm64", + "@platforms//os:linux", + ], +) + string_flag( name = "jaxlib_git_hash", build_setting_default = "",