mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Adds a wrapper to sparse tensor dialect, as part of an
an initial prototype of an alternate JAX compilation path that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO together with sparse tensor types. PiperOrigin-RevId: 443438043
This commit is contained in:
parent
bef5e02816
commit
c1261ccd27
@ -38,6 +38,7 @@ py_binary(
|
||||
"//jaxlib/mlir:ir",
|
||||
"//jaxlib/mlir:mhlo_dialect",
|
||||
"//jaxlib/mlir:func_dialect",
|
||||
"//jaxlib/mlir:sparse_tensor_dialect",
|
||||
"//jaxlib:pocketfft_flatbuffers_py",
|
||||
] + if_windows([
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
|
@ -245,6 +245,8 @@ def prepare_wheel(sources_path):
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_ods_common.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_func_ops_ext.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_func_ops_gen.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/sparse_tensor.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/builtin.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/chlo.py"), dst_dir=mlir_dialects_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/dialects/mhlo.py"), dst_dir=mlir_dialects_dir)
|
||||
@ -255,16 +257,22 @@ def prepare_wheel(sources_path):
|
||||
if _is_windows():
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlir.pyd"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.pyd"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.pyd"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.pyd"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
|
||||
elif _is_mac():
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlir.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
|
||||
else:
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlir.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so"), dst_dir=mlir_libs_dir)
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
|
||||
patch_copy_xla_extension_stubs(jaxlib_dir)
|
||||
|
@ -16,4 +16,10 @@
|
||||
import jaxlib.mlir.dialects.builtin as builtin
|
||||
import jaxlib.mlir.dialects.chlo as chlo
|
||||
import jaxlib.mlir.dialects.mhlo as mhlo
|
||||
import jaxlib.mlir.dialects.func as func
|
||||
import jaxlib.mlir.dialects.func as func
|
||||
try:
|
||||
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# TODO(ajcbik,phawkins): make this unconditional when jaxlib > 0.3.7
|
||||
# is the minimum version.
|
||||
pass
|
||||
|
@ -81,6 +81,21 @@ symlink_inputs(
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "sparse_tensor_dialect",
|
||||
rule = py_library,
|
||||
symlinked_inputs = {"srcs": {"dialects": [
|
||||
"@llvm-project//mlir/python:SparseTensorOpsPyFiles",
|
||||
]}},
|
||||
deps = [
|
||||
":core",
|
||||
":ir",
|
||||
":mlir",
|
||||
"//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor",
|
||||
"//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "mhlo_dialect",
|
||||
rule = py_library,
|
||||
|
@ -43,6 +43,33 @@ py_extension(
|
||||
],
|
||||
)
|
||||
|
||||
py_extension(
|
||||
name = "_mlirDialectsSparseTensor",
|
||||
srcs = [
|
||||
"@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPISparseTensorHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_extension(
|
||||
name = "_mlirSparseTensorPasses",
|
||||
srcs = [
|
||||
"@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp",
|
||||
],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"@llvm-project//mlir:CAPISparseTensorHeaders",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
##---------------------------------------------------------------------------##
|
||||
# MHLO Extensions
|
||||
##---------------------------------------------------------------------------##
|
||||
@ -80,6 +107,7 @@ cc_library(
|
||||
name = "jaxlib_mlir_capi_objects",
|
||||
deps = [
|
||||
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
|
||||
"@llvm-project//mlir:CAPISparseTensorObjects",
|
||||
"@org_tensorflow//tensorflow/compiler/mlir/hlo:CAPIObjects",
|
||||
],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user