The Triton MLIR bindings now include auto-generated wrappers for enums

PiperOrigin-RevId: 596873541
This commit is contained in:
Sergei Lebedev 2024-01-09 03:00:04 -08:00 committed by jax authors
parent df0f1e06e0
commit f219482212
2 changed files with 25 additions and 6 deletions

View File

@ -39,23 +39,41 @@ pytype_strict_library(
genrule(
name = "_triton_gen",
srcs = [":_triton_gen_raw"],
outs = ["_triton_gen.py"],
srcs = [
"_triton_ops_gen_raw.py",
"_triton_enum_gen_raw.py",
],
outs = [
"_triton_ops_gen.py",
"_triton_enum_gen.py",
],
cmd = """
echo '# pytype: skip-file' > $@ && \
cat $(location :_triton_gen_raw) | sed -e 's/^from \\./from mlir\\.dialects\\./g' >> $@
for src in $(SRCS); do
out=$${src//_raw/}
echo '# pytype: skip-file' > $${out} && \
cat $${src} |
sed -e 's/^from \\.\\./from mlir\\./g' |
sed -e 's/^from \\./from mlir\\.dialects\\./g' >> $${out}
done
""",
)
gentbl_filegroup(
name = "_triton_gen_raw",
tbl_outs = [
(
[
"-gen-python-enum-bindings",
"-bind-dialect=tt",
],
"_triton_enum_gen_raw.py",
),
(
[
"-gen-python-op-bindings",
"-bind-dialect=tt",
],
"_triton_gen_raw.py",
"_triton_ops_gen_raw.py",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",

View File

@ -16,5 +16,6 @@
"""Python bindings for the MLIR Triton dialect."""
from ._triton_enum_gen import * # pylint: disable=wildcard-import
from ._triton_ext import register_dialect, PointerType
from ._triton_gen import * # pylint: disable=wildcard-import
from ._triton_ops_gen import * # pylint: disable=wildcard-import