Fix build rule for free-threaded python builds.

PiperOrigin-RevId: 733857126
This commit is contained in:
jax authors 2025-03-05 13:53:44 -08:00
parent 3edc068f8c
commit 0913cd7583

View File

@ -325,11 +325,18 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)
def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version):
def _get_full_wheel_name(
package_name,
no_abi,
platform_independent,
platform_name,
cpu_name,
wheel_version,
py_freethreaded):
if no_abi or platform_independent:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}{free_threaded_suffix}-{wheel_platform_tag}.whl"
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
return wheel_name_template.format(
package_name = package_name,
@ -339,6 +346,7 @@ def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_na
wheel_platform_tag = "any" if platform_independent else "_".join(
PLATFORM_TAGS_DICT[platform_name, cpu_name],
),
free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "",
)
def _get_source_distribution_name(package_name, wheel_version):
@ -352,6 +360,7 @@ def _jax_wheel_impl(ctx):
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
output_path = ctx.attr.output_path[BuildSettingInfo].value
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
py_freethreaded = ctx.attr.py_freethreaded[BuildSettingInfo].value
executable = ctx.executable.wheel_binary
if include_cuda_libs and not override_include_cuda_libs:
@ -387,6 +396,7 @@ def _jax_wheel_impl(ctx):
platform_name = platform_name,
cpu_name = cpu,
wheel_version = full_wheel_version,
py_freethreaded = py_freethreaded,
)
wheel_file = ctx.actions.declare_file(output_path +
"/" + wheel_name)
@ -463,6 +473,7 @@ _jax_wheel = rule(
"enable_rocm": attr.bool(default = False),
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
"py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")),
},
implementation = _jax_wheel_impl,
executable = False,