mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fix build rule for free-threaded python builds.
PiperOrigin-RevId: 733857126
This commit is contained in:
parent
3edc068f8c
commit
0913cd7583
@ -325,11 +325,18 @@ def jax_generate_backend_suites(backends = []):
|
||||
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
|
||||
)
|
||||
|
||||
def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version):
|
||||
def _get_full_wheel_name(
|
||||
package_name,
|
||||
no_abi,
|
||||
platform_independent,
|
||||
platform_name,
|
||||
cpu_name,
|
||||
wheel_version,
|
||||
py_freethreaded):
|
||||
if no_abi or platform_independent:
|
||||
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
|
||||
else:
|
||||
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
|
||||
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}{free_threaded_suffix}-{wheel_platform_tag}.whl"
|
||||
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
|
||||
return wheel_name_template.format(
|
||||
package_name = package_name,
|
||||
@ -339,6 +346,7 @@ def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_na
|
||||
wheel_platform_tag = "any" if platform_independent else "_".join(
|
||||
PLATFORM_TAGS_DICT[platform_name, cpu_name],
|
||||
),
|
||||
free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "",
|
||||
)
|
||||
|
||||
def _get_source_distribution_name(package_name, wheel_version):
|
||||
@ -352,6 +360,7 @@ def _jax_wheel_impl(ctx):
|
||||
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
|
||||
output_path = ctx.attr.output_path[BuildSettingInfo].value
|
||||
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
|
||||
py_freethreaded = ctx.attr.py_freethreaded[BuildSettingInfo].value
|
||||
executable = ctx.executable.wheel_binary
|
||||
|
||||
if include_cuda_libs and not override_include_cuda_libs:
|
||||
@ -387,6 +396,7 @@ def _jax_wheel_impl(ctx):
|
||||
platform_name = platform_name,
|
||||
cpu_name = cpu,
|
||||
wheel_version = full_wheel_version,
|
||||
py_freethreaded = py_freethreaded,
|
||||
)
|
||||
wheel_file = ctx.actions.declare_file(output_path +
|
||||
"/" + wheel_name)
|
||||
@ -463,6 +473,7 @@ _jax_wheel = rule(
|
||||
"enable_rocm": attr.bool(default = False),
|
||||
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
|
||||
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
|
||||
"py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")),
|
||||
},
|
||||
implementation = _jax_wheel_impl,
|
||||
executable = False,
|
||||
|
Loading…
x
Reference in New Issue
Block a user