From ef40b85c8b2686f64bc9ca67de267a6b1a7935bb Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 21 Feb 2024 16:02:14 -0800 Subject: [PATCH] Don't build the Triton MLIR dialect on Windows This dialect doesn't build on Windows, but we don't support GPUs on Windows anyway, so we can simply exclude it from the build. CI failures look like this: ``` C:\npm\prefix\bazel.CMD run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=C:\a\jax\jax\jax\dist --jaxlib_git_hash=5f19f7712b485493ac141c44eea3b3eb1ffdfb59 --cpu=AMD64 b"external/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2672: 'mlir::OpState::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2783: 'enable_if::type,std::is_class::value>::num_args==1,RetT>::type mlir::OpState::walk(FnT &&)': could not deduce template argument for 'RetT'\r\n with\r\n [\r\n T=decay::type\r\n ]\r\nexternal/llvm-project/mlir/include\\mlir/IR/OpDefinition.h(165): note: see declaration of 'mlir::OpState::walk'\r\nexternal/llvm-project/mlir/include\\mlir/IR/PatternMatch.h(357): error C2872: 'detail': ambiguous symbol\r\nexternal/llvm-project/mlir/include\\mlir/Rewrite/FrozenRewritePatternSet.h(15): note: could be 'mlir::detail'\r\nbazel-out/x64_windows-opt/bin/external/triton/include\\triton/Dialect/Triton/IR/Ops.h.inc(5826): note: or 'mlir::triton::detail'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(712): note: see reference to class template instantiation 'mlir::OpRewritePattern' being compiled\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\n" output = subprocess.check_output(cmd) ``` PiperOrigin-RevId: 609153322 --- jaxlib/mlir/_mlir_libs/BUILD.bazel | 9 +++++++-- jaxlib/tools/build_wheel.py | 12 +++++++++--- jaxlib/triton/BUILD | 8 +++++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index c634c52e9..082e0f765 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -14,6 +14,7 @@ load( "//jaxlib:jax.bzl", + "if_windows", "py_extension", "pybind_extension", "windows_cc_shared_mlir_library", @@ -250,7 +251,6 @@ cc_library( name = "jaxlib_mlir_capi_objects", deps = [ "//jaxlib/mosaic:tpu_dialect_capi_objects", - "//jaxlib/triton:triton_dialect_capi_objects", "@llvm-project//mlir:CAPIArithObjects", "@llvm-project//mlir:CAPIIRObjects", "@llvm-project//mlir:CAPIMathObjects", @@ -263,7 +263,12 @@ cc_library( "@stablehlo//:chlo_capi_objects", "@stablehlo//:stablehlo_capi_objects", "@xla//xla/mlir_hlo:CAPIObjects", - ], + ] + if_windows( + [], + [ + "//jaxlib/triton:triton_dialect_capi_objects", + ], + ), ) cc_binary( diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 52a146642..4ef295c39 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -327,11 +327,17 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", - "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", - ], + ] + + ( + [] + if build_utils.is_windows() + else [ + f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", + "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", + ] + ), ) triton_dir = jaxlib_dir / "triton" diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 2f913b54a..ac2a43255 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "pytype_strict_library") +load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") licenses(["notice"]) @@ -33,8 +33,10 @@ pytype_strict_library( deps = [ "//jaxlib/mlir:core", "//jaxlib/mlir:ir", - "//jaxlib/mlir/_mlir_libs:_triton_ext", - ], + ] + if_windows( + [], + ["//jaxlib/mlir/_mlir_libs:_triton_ext"], + ), ) genrule(