125 lines
3.2 KiB
Python
Raw Normal View History

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library")
# All sources in this package are under the Apache 2.0 license (see header).
licenses(["notice"])
# Package-wide defaults: no default applicable license targets, and every
# target is visible only to //jax:internal unless it sets its own
# `visibility` (the public `triton` library below does).
package(
default_applicable_licenses = [],
default_visibility = ["//jax:internal"],
)
# Public Python library for the Triton MLIR dialect: the hand-written
# `__init__.py` / `dialect.py` plus the generated op/enum binding modules
# produced by the `:_triton_gen` genrule.
pytype_strict_library(
    name = "triton",
    srcs = [
        "__init__.py",
        "dialect.py",
        ":_triton_gen",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//jaxlib/mlir:core",
        "//jaxlib/mlir:ir",
    ] + if_windows(
        # The Triton dialect does not build on Windows, and GPUs are not
        # supported on Windows anyway, so the native extension is simply
        # excluded from the build there.
        [],
        ["//jaxlib/mlir/_mlir_libs:_triton_ext"],
    ),
)
genrule(
    name = "_triton_gen",
    srcs = [
        "_triton_ops_gen_raw.py",
        "_triton_enum_gen_raw.py",
    ],
    outs = [
        "_triton_ops_gen.py",
        "_triton_enum_gen.py",
    ],
    # For every `*_raw.py` input, write a file of the same name with `_raw`
    # stripped. The output starts with a `# pytype: skip-file` marker, followed
    # by the input with its relative imports rewritten to absolute ones:
    #   `from ..X` -> `from jaxlib.mlir.X`
    #   `from .X`  -> `from jaxlib.mlir.dialects.X`
    # Outputs are placed under $(RULEDIR) so the command does not implicitly
    # depend on whether the inputs live in bin or genfiles.
    cmd = """
for raw in $(SRCS); do
  name=$$(basename "$$raw")
  dest="$(RULEDIR)/$${name//_raw/}"
  {
    echo '# pytype: skip-file'
    sed \\
      -e 's/^from \\.\\./from jaxlib\\.mlir\\./g' \\
      -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' \\
      "$$raw"
  } > "$$dest"
done
""",
)
# Runs mlir-tblgen over triton.td to emit Python bindings for the `tt`
# dialect. The outputs keep relative imports and are consumed (and rewritten)
# by the `_triton_gen` genrule above.
gentbl_filegroup(
name = "_triton_gen_raw",
tbl_outs = [
# Enum bindings for the `tt` dialect.
(
[
"-gen-python-enum-bindings",
"-bind-dialect=tt",
],
"_triton_enum_gen_raw.py",
),
# Op bindings for the `tt` dialect. Note: the flag says `op` while the
# output file name says `ops`.
(
[
"-gen-python-op-bindings",
"-bind-dialect=tt",
],
"_triton_ops_gen_raw.py",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "triton.td",
# TableGen include dependencies: MLIR's OpBase definitions and Triton's
# own .td files.
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@triton//:td_files",
],
)
# C API for the Triton MLIR dialect (declared in triton_dialect_capi.h,
# implemented in triton_dialect_capi.cc), linked against the regular MLIR
# C API library.
cc_library(
name = "triton_dialect_capi",
srcs = ["triton_dialect_capi.cc"],
hdrs = ["triton_dialect_capi.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@triton//:TritonDialects",
],
)
# Header-only target, used when calling the C API from a separate shared
# library that provides the implementation. It exposes only
# triton_dialect_capi.h and depends on the MLIR C API headers (presumably the
# header-only counterpart of `CAPIIR` — confirm against llvm-project).
cc_library(
name = "triton_dialect_capi_headers",
hdrs = ["triton_dialect_capi.h"],
deps = [
"@llvm-project//mlir:CAPIIRHeaders",
],
)
# Alwayslink target, used when exporting the C API from a shared library.
# Same sources as `triton_dialect_capi`, but depends on `CAPIIRObjects`
# (presumably MLIR's object-file variant of the C API — confirm) instead of
# `CAPIIR`.
cc_library(
name = "triton_dialect_capi_objects",
srcs = ["triton_dialect_capi.cc"],
hdrs = ["triton_dialect_capi.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:IR",
"@triton//:TritonDialects",
],
# Link all of this library's object files into dependents even if none of
# their symbols are referenced, so the exported C API symbols are not
# dropped by the linker.
alwayslink = True,
)