From 0913cd7583ca927b1df22589aa7fd2e169b1245a Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 5 Mar 2025 13:53:44 -0800 Subject: [PATCH] Fix build rule for free-threaded python builds. PiperOrigin-RevId: 733857126 --- jaxlib/jax.bzl | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index a5f02937c..58a83d9b0 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -325,11 +325,18 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) -def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version): +def _get_full_wheel_name( + package_name, + no_abi, + platform_independent, + platform_name, + cpu_name, + wheel_version, + py_freethreaded): if no_abi or platform_independent: wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" else: - wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" + wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}{free_threaded_suffix}-{wheel_platform_tag}.whl" python_version = HERMETIC_PYTHON_VERSION.replace(".", "") return wheel_name_template.format( package_name = package_name, @@ -339,6 +346,7 @@ def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_na wheel_platform_tag = "any" if platform_independent else "_".join( PLATFORM_TAGS_DICT[platform_name, cpu_name], ), + free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "", ) def _get_source_distribution_name(package_name, wheel_version): @@ -352,6 +360,7 @@ def _jax_wheel_impl(ctx): override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value output_path = ctx.attr.output_path[BuildSettingInfo].value git_hash = ctx.attr.git_hash[BuildSettingInfo].value + py_freethreaded = ctx.attr.py_freethreaded[BuildSettingInfo].value executable = ctx.executable.wheel_binary if include_cuda_libs and not override_include_cuda_libs: @@ -387,6 +396,7 @@ def _jax_wheel_impl(ctx): platform_name = platform_name, cpu_name = cpu, wheel_version = full_wheel_version, + py_freethreaded = py_freethreaded, ) wheel_file = ctx.actions.declare_file(output_path + "/" + wheel_name) @@ -463,6 +473,7 @@ _jax_wheel = rule( "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), + "py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")), }, implementation = _jax_wheel_impl, executable = False,