fix mlir capi dll building and linking

This commit is contained in:
Cloud Han 2021-11-25 00:07:25 +08:00
parent 9781f365a1
commit 317edcdacd
3 changed files with 79 additions and 6 deletions

View File

@ -16,7 +16,7 @@
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows", "if_windows")
licenses(["notice"]) # Apache 2
@ -39,7 +39,9 @@ py_binary(
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:standard_dialect",
"//jaxlib:pocketfft_flatbuffers_py",
] + if_not_windows([
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + if_not_windows([
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
]) + if_cuda([
"//jaxlib:gpu_support",

View File

@ -33,3 +33,69 @@ flatbuffer_py_library = _flatbuffer_py_library
def py_extension(name, srcs, copts, deps):
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)
def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
"""Workaround DLL building issue.
1. cc_binary with linkshared enabled cannot produce DLL with symbol
correctly exported.
2. Even if the DLL is correctly built, the resulting target cannot be
correctly consumed by other targets.
Args:
name: the name of the output target
out: the name of the output DLL filename
deps: deps
srcs: srcs
"""
# create a dummy library to get the *.def file
dummy_library_name = name + ".dummy.dll"
native.cc_binary(
name = dummy_library_name,
linkshared = 1,
linkstatic = 1,
deps = deps,
)
# .def file with all symbols, not usable
full_def_name = name + ".full.def"
native.filegroup(
name = full_def_name,
srcs = [dummy_library_name],
output_group = "def_file",
)
# filtered def_file, only the needed symbols are included
filtered_def_name = name + ".filtered.def"
filtered_def_file = out + ".def"
native.genrule(
name = filtered_def_name,
srcs = [full_def_name],
outs = [filtered_def_file],
cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep '^\\W*mlir' $(location :{}) >> $@""".format(out, full_def_name),
)
# create the desired library
native.cc_binary(
name = out, # this name must be correct, it will be the filename
linkshared = 1,
deps = deps,
win_def_file = filtered_def_file,
)
# however, the created cc_library (a shared library) cannot be correctly
# consumed by other cc_*...
interface_library_file = out + ".if.lib"
native.filegroup(
name = interface_library_file,
srcs = [out],
output_group = "interface_library",
)
# but this one can be correctly consumed, this is our final product
native.cc_import(
name = name,
interface_library = interface_library_file,
shared_library = out,
)

View File

@ -15,6 +15,7 @@
load(
"//jaxlib:jax.bzl",
"py_extension",
"windows_cc_shared_mlir_library",
)
package(
@ -69,6 +70,10 @@ cc_library(
"@org_tensorflow//tensorflow:macos": [":libjaxlib_mlir_capi.dylib"],
"//conditions:default": [":libjaxlib_mlir_capi.so"],
}),
deps = select({
"@org_tensorflow//tensorflow:windows": [":jaxlib_mlir_capi_dll"],
"//conditions:default": [],
}),
)
cc_library(
@ -99,8 +104,8 @@ cc_binary(
deps = [":jaxlib_mlir_capi_objects"],
)
cc_binary(
name = "jaxlib_mlir_capi.dll",
linkshared = 1,
windows_cc_shared_mlir_library(
name = "jaxlib_mlir_capi_dll",
out = "jaxlib_mlir_capi.dll",
deps = [":jaxlib_mlir_capi_objects"],
)
)