Adds a wrapper to sparse tensor dialect, as part of an

an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.

PiperOrigin-RevId: 443438043
This commit is contained in:
Aart Bik 2022-04-21 11:48:16 -07:00 committed by jax authors
parent bef5e02816
commit c1261ccd27
5 changed files with 59 additions and 1 deletions

View File

@ -38,6 +38,7 @@ py_binary(
"//jaxlib/mlir:ir",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:func_dialect",
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib:pocketfft_flatbuffers_py",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",

View File

@ -245,6 +245,8 @@ def prepare_wheel(sources_path):
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_ods_common.py"), dst_dir=mlir_dialects_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_func_ops_ext.py"), dst_dir=mlir_dialects_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_func_ops_gen.py"), dst_dir=mlir_dialects_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py"), dst_dir=mlir_dialects_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/sparse_tensor.py"), dst_dir=mlir_dialects_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/builtin.py"), dst_dir=mlir_dialects_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/chlo.py"), dst_dir=mlir_dialects_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/mhlo.py"), dst_dir=mlir_dialects_dir)
@ -255,16 +257,22 @@ def prepare_wheel(sources_path):
if _is_windows():
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlir.pyd"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.pyd"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.pyd"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.pyd"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
elif _is_mac():
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlir.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
else:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlir.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so"), dst_dir=mlir_libs_dir)
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
patch_copy_xla_extension_stubs(jaxlib_dir)

View File

@ -16,4 +16,10 @@
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.func as func
try:
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
except (ModuleNotFoundError, ImportError):
# TODO(ajcbik,phawkins): make this unconditional when jaxlib > 0.3.7
# is the minimum version.
pass

View File

@ -81,6 +81,21 @@ symlink_inputs(
],
)
symlink_inputs(
name = "sparse_tensor_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:SparseTensorOpsPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
"//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor",
"//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses",
],
)
symlink_inputs(
name = "mhlo_dialect",
rule = py_library,

View File

@ -43,6 +43,33 @@ py_extension(
],
)
py_extension(
name = "_mlirDialectsSparseTensor",
srcs = [
"@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp",
],
copts = COPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPISparseTensorHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11",
],
)
py_extension(
name = "_mlirSparseTensorPasses",
srcs = [
"@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp",
],
copts = COPTS,
deps = [
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPISparseTensorHeaders",
"@pybind11",
],
)
##---------------------------------------------------------------------------##
# MHLO Extensions
##---------------------------------------------------------------------------##
@ -80,6 +107,7 @@ cc_library(
name = "jaxlib_mlir_capi_objects",
deps = [
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
"@llvm-project//mlir:CAPISparseTensorObjects",
"@org_tensorflow//tensorflow/compiler/mlir/hlo:CAPIObjects",
],
)