mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
fix mlir capi dll building and linking
This commit is contained in:
parent
9781f365a1
commit
317edcdacd
@ -16,7 +16,7 @@
|
||||
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows", "if_windows")
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
|
||||
@ -39,7 +39,9 @@ py_binary(
|
||||
"//jaxlib/mlir:mhlo_dialect",
|
||||
"//jaxlib/mlir:standard_dialect",
|
||||
"//jaxlib:pocketfft_flatbuffers_py",
|
||||
] + if_not_windows([
|
||||
] + if_windows([
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]) + if_not_windows([
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
|
||||
]) + if_cuda([
|
||||
"//jaxlib:gpu_support",
|
||||
|
@ -33,3 +33,69 @@ flatbuffer_py_library = _flatbuffer_py_library
|
||||
|
||||
def py_extension(name, srcs, copts, deps):
|
||||
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)
|
||||
|
||||
def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
|
||||
"""Workaround DLL building issue.
|
||||
|
||||
1. cc_binary with linkshared enabled cannot produce DLL with symbol
|
||||
correctly exported.
|
||||
2. Even if the DLL is correctly built, the resulting target cannot be
|
||||
correctly consumed by other targets.
|
||||
|
||||
Args:
|
||||
name: the name of the output target
|
||||
out: the name of the output DLL filename
|
||||
deps: deps
|
||||
srcs: srcs
|
||||
"""
|
||||
|
||||
# create a dummy library to get the *.def file
|
||||
dummy_library_name = name + ".dummy.dll"
|
||||
native.cc_binary(
|
||||
name = dummy_library_name,
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = deps,
|
||||
)
|
||||
|
||||
# .def file with all symbols, not usable
|
||||
full_def_name = name + ".full.def"
|
||||
native.filegroup(
|
||||
name = full_def_name,
|
||||
srcs = [dummy_library_name],
|
||||
output_group = "def_file",
|
||||
)
|
||||
|
||||
# filtered def_file, only the needed symbols are included
|
||||
filtered_def_name = name + ".filtered.def"
|
||||
filtered_def_file = out + ".def"
|
||||
native.genrule(
|
||||
name = filtered_def_name,
|
||||
srcs = [full_def_name],
|
||||
outs = [filtered_def_file],
|
||||
cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep '^\\W*mlir' $(location :{}) >> $@""".format(out, full_def_name),
|
||||
)
|
||||
|
||||
# create the desired library
|
||||
native.cc_binary(
|
||||
name = out, # this name must be correct, it will be the filename
|
||||
linkshared = 1,
|
||||
deps = deps,
|
||||
win_def_file = filtered_def_file,
|
||||
)
|
||||
|
||||
# however, the created cc_library (a shared library) cannot be correctly
|
||||
# consumed by other cc_*...
|
||||
interface_library_file = out + ".if.lib"
|
||||
native.filegroup(
|
||||
name = interface_library_file,
|
||||
srcs = [out],
|
||||
output_group = "interface_library",
|
||||
)
|
||||
|
||||
# but this one can be correctly consumed, this is our final product
|
||||
native.cc_import(
|
||||
name = name,
|
||||
interface_library = interface_library_file,
|
||||
shared_library = out,
|
||||
)
|
||||
|
@ -15,6 +15,7 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_extension",
|
||||
"windows_cc_shared_mlir_library",
|
||||
)
|
||||
|
||||
package(
|
||||
@ -69,6 +70,10 @@ cc_library(
|
||||
"@org_tensorflow//tensorflow:macos": [":libjaxlib_mlir_capi.dylib"],
|
||||
"//conditions:default": [":libjaxlib_mlir_capi.so"],
|
||||
}),
|
||||
deps = select({
|
||||
"@org_tensorflow//tensorflow:windows": [":jaxlib_mlir_capi_dll"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -99,8 +104,8 @@ cc_binary(
|
||||
deps = [":jaxlib_mlir_capi_objects"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "jaxlib_mlir_capi.dll",
|
||||
linkshared = 1,
|
||||
windows_cc_shared_mlir_library(
|
||||
name = "jaxlib_mlir_capi_dll",
|
||||
out = "jaxlib_mlir_capi.dll",
|
||||
deps = [":jaxlib_mlir_capi_objects"],
|
||||
)
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user