mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add targets for jaxlib
, jax-cuda-plugin
and jax-cuda-pjrt
editable wheels.
PiperOrigin-RevId: 731737119
This commit is contained in:
parent
f93c2a1aa5
commit
401d315091
@ -374,27 +374,34 @@ def _jax_wheel_impl(ctx):
|
||||
no_abi = ctx.attr.no_abi
|
||||
platform_independent = ctx.attr.platform_independent
|
||||
build_wheel_only = ctx.attr.build_wheel_only
|
||||
editable = ctx.attr.editable
|
||||
platform_name = ctx.attr.platform_name
|
||||
wheel_name = _get_full_wheel_name(
|
||||
package_name = ctx.attr.wheel_name,
|
||||
no_abi = no_abi,
|
||||
platform_independent = platform_independent,
|
||||
platform_name = platform_name,
|
||||
cpu_name = cpu,
|
||||
wheel_version = full_wheel_version,
|
||||
)
|
||||
wheel_file = ctx.actions.declare_file(output_path +
|
||||
"/" + wheel_name)
|
||||
wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")]
|
||||
outputs = [wheel_file]
|
||||
if not build_wheel_only:
|
||||
source_distribution_name = _get_source_distribution_name(
|
||||
if editable:
|
||||
output_dir = ctx.actions.declare_directory(output_path + "/" + ctx.attr.wheel_name)
|
||||
wheel_dir = output_dir.path
|
||||
outputs = [output_dir]
|
||||
args.add("--editable")
|
||||
else:
|
||||
wheel_name = _get_full_wheel_name(
|
||||
package_name = ctx.attr.wheel_name,
|
||||
no_abi = no_abi,
|
||||
platform_independent = platform_independent,
|
||||
platform_name = platform_name,
|
||||
cpu_name = cpu,
|
||||
wheel_version = full_wheel_version,
|
||||
)
|
||||
source_distribution_file = ctx.actions.declare_file(output_path +
|
||||
"/" + source_distribution_name)
|
||||
outputs.append(source_distribution_file)
|
||||
wheel_file = ctx.actions.declare_file(output_path +
|
||||
"/" + wheel_name)
|
||||
wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")]
|
||||
outputs = [wheel_file]
|
||||
if not build_wheel_only:
|
||||
source_distribution_name = _get_source_distribution_name(
|
||||
package_name = ctx.attr.wheel_name,
|
||||
wheel_version = full_wheel_version,
|
||||
)
|
||||
source_distribution_file = ctx.actions.declare_file(output_path +
|
||||
"/" + source_distribution_name)
|
||||
outputs.append(source_distribution_file)
|
||||
|
||||
args.add("--output_path", wheel_dir) # required argument
|
||||
if not platform_independent:
|
||||
@ -445,6 +452,7 @@ _jax_wheel = rule(
|
||||
"no_abi": attr.bool(default = False),
|
||||
"platform_independent": attr.bool(default = False),
|
||||
"build_wheel_only": attr.bool(default = True),
|
||||
"editable": attr.bool(default = False),
|
||||
"cpu": attr.string(mandatory = True),
|
||||
"platform_name": attr.string(mandatory = True),
|
||||
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
|
||||
@ -469,6 +477,7 @@ def jax_wheel(
|
||||
no_abi = False,
|
||||
platform_independent = False,
|
||||
build_wheel_only = True,
|
||||
editable = False,
|
||||
enable_cuda = False,
|
||||
platform_version = "",
|
||||
source_files = []):
|
||||
@ -482,6 +491,7 @@ def jax_wheel(
|
||||
wheel_name: the name of the wheel
|
||||
no_abi: whether to build a wheel without ABI
|
||||
build_wheel_only: whether to build a wheel without source distribution
|
||||
editable: whether to build an editable wheel
|
||||
platform_independent: whether to build a wheel without platform tag
|
||||
enable_cuda: whether to build a cuda wheel
|
||||
platform_version: the cuda version to use for the wheel
|
||||
@ -497,6 +507,7 @@ def jax_wheel(
|
||||
no_abi = no_abi,
|
||||
platform_independent = platform_independent,
|
||||
build_wheel_only = build_wheel_only,
|
||||
editable = editable,
|
||||
enable_cuda = enable_cuda,
|
||||
platform_version = platform_version,
|
||||
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
|
||||
|
@ -211,6 +211,13 @@ jax_wheel(
|
||||
wheel_name = "jaxlib",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jaxlib_wheel_editable",
|
||||
editable = True,
|
||||
wheel_binary = ":build_wheel",
|
||||
wheel_name = "jaxlib",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_plugin_wheel",
|
||||
enable_cuda = True,
|
||||
@ -221,6 +228,16 @@ jax_wheel(
|
||||
wheel_name = "jax_cuda12_plugin",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_plugin_wheel_editable",
|
||||
editable = True,
|
||||
enable_cuda = True,
|
||||
# TODO(b/371217563) May use hermetic cuda version here.
|
||||
platform_version = "12",
|
||||
wheel_binary = ":build_gpu_kernels_wheel",
|
||||
wheel_name = "jax_cuda12_plugin",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_pjrt_wheel",
|
||||
enable_cuda = True,
|
||||
@ -231,6 +248,16 @@ jax_wheel(
|
||||
wheel_name = "jax_cuda12_pjrt",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_pjrt_wheel_editable",
|
||||
editable = True,
|
||||
enable_cuda = True,
|
||||
# TODO(b/371217563) May use hermetic cuda version here.
|
||||
platform_version = "12",
|
||||
wheel_binary = ":build_gpu_plugin_wheel",
|
||||
wheel_name = "jax_cuda12_pjrt",
|
||||
)
|
||||
|
||||
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])
|
||||
|
||||
PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])
|
||||
|
Loading…
x
Reference in New Issue
Block a user