Don't build the Triton MLIR dialect on Windows

This dialect doesn't build on Windows, but we don't support GPUs on Windows anyway, so we can simply exclude it from the build.

CI failures look like this:
```
C:\npm\prefix\bazel.CMD run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=C:\a\jax\jax\jax\dist --jaxlib_git_hash=5f19f7712b485493ac141c44eea3b3eb1ffdfb59 --cpu=AMD64
b"external/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2672: 'mlir::OpState::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2783: 'enable_if<llvm::function_traits<decay<FnT>::type,std::is_class<T>::value>::num_args==1,RetT>::type mlir::OpState::walk(FnT &&)': could not deduce template argument for 'RetT'\r\n        with\r\n        [\r\n            T=decay<FnT>::type\r\n        ]\r\nexternal/llvm-project/mlir/include\\mlir/IR/OpDefinition.h(165): note: see declaration of 'mlir::OpState::walk'\r\nexternal/llvm-project/mlir/include\\mlir/IR/PatternMatch.h(357): error C2872: 'detail': ambiguous symbol\r\nexternal/llvm-project/mlir/include\\mlir/Rewrite/FrozenRewritePatternSet.h(15): note: could be 'mlir::detail'\r\nbazel-out/x64_windows-opt/bin/external/triton/include\\triton/Dialect/Triton/IR/Ops.h.inc(5826): note: or       'mlir::triton::detail'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(712): note: see reference to class template instantiation 'mlir::OpRewritePattern<mlir::scf::ForOp>' being compiled\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\n"
    output = subprocess.check_output(cmd)
```
PiperOrigin-RevId: 609153322
This commit is contained in:
Peter Hawkins 2024-02-21 16:02:14 -08:00 committed by jax authors
parent 63a0f19941
commit ef40b85c8b
3 changed files with 21 additions and 8 deletions

View File

@ -14,6 +14,7 @@
load(
"//jaxlib:jax.bzl",
"if_windows",
"py_extension",
"pybind_extension",
"windows_cc_shared_mlir_library",
@ -250,7 +251,6 @@ cc_library(
name = "jaxlib_mlir_capi_objects",
deps = [
"//jaxlib/mosaic:tpu_dialect_capi_objects",
"//jaxlib/triton:triton_dialect_capi_objects",
"@llvm-project//mlir:CAPIArithObjects",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:CAPIMathObjects",
@ -263,7 +263,12 @@ cc_library(
"@stablehlo//:chlo_capi_objects",
"@stablehlo//:stablehlo_capi_objects",
"@xla//xla/mlir_hlo:CAPIObjects",
],
] + if_windows(
[],
[
"//jaxlib/triton:triton_dialect_capi_objects",
],
),
)
cc_binary(

View File

@ -327,11 +327,17 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}",
"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi",
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}",
],
]
+ (
[]
if build_utils.is_windows()
else [
f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}",
"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi",
]
),
)
triton_dir = jaxlib_dir / "triton"

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:jax.bzl", "pytype_strict_library")
load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
licenses(["notice"])
@ -33,8 +33,10 @@ pytype_strict_library(
deps = [
"//jaxlib/mlir:core",
"//jaxlib/mlir:ir",
"//jaxlib/mlir/_mlir_libs:_triton_ext",
],
] + if_windows(
[],
["//jaxlib/mlir/_mlir_libs:_triton_ext"],
),
)
genrule(