mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Don't build the Triton MLIR dialect on Windows
This dialect doesn't build on Windows, but we don't support GPUs on Windows anyway, so we can simply exclude it from the build. CI failures look like this: ``` C:\npm\prefix\bazel.CMD run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=C:\a\jax\jax\jax\dist --jaxlib_git_hash=5f19f7712b485493ac141c44eea3b3eb1ffdfb59 --cpu=AMD64 b"external/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2672: 'mlir::OpState::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2783: 'enable_if<llvm::function_traits<decay<FnT>::type,std::is_class<T>::value>::num_args==1,RetT>::type mlir::OpState::walk(FnT &&)': could not deduce template argument for 'RetT'\r\n with\r\n [\r\n T=decay<FnT>::type\r\n ]\r\nexternal/llvm-project/mlir/include\\mlir/IR/OpDefinition.h(165): note: see declaration of 'mlir::OpState::walk'\r\nexternal/llvm-project/mlir/include\\mlir/IR/PatternMatch.h(357): error C2872: 'detail': ambiguous symbol\r\nexternal/llvm-project/mlir/include\\mlir/Rewrite/FrozenRewritePatternSet.h(15): note: could be 'mlir::detail'\r\nbazel-out/x64_windows-opt/bin/external/triton/include\\triton/Dialect/Triton/IR/Ops.h.inc(5826): note: or 'mlir::triton::detail'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(712): note: see reference to class template instantiation 'mlir::OpRewritePattern<mlir::scf::ForOp>' being compiled\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\n" output = subprocess.check_output(cmd) ``` PiperOrigin-RevId: 609153322
This commit is contained in:
parent
63a0f19941
commit
ef40b85c8b
@ -14,6 +14,7 @@
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"if_windows",
|
||||
"py_extension",
|
||||
"pybind_extension",
|
||||
"windows_cc_shared_mlir_library",
|
||||
@ -250,7 +251,6 @@ cc_library(
|
||||
name = "jaxlib_mlir_capi_objects",
|
||||
deps = [
|
||||
"//jaxlib/mosaic:tpu_dialect_capi_objects",
|
||||
"//jaxlib/triton:triton_dialect_capi_objects",
|
||||
"@llvm-project//mlir:CAPIArithObjects",
|
||||
"@llvm-project//mlir:CAPIIRObjects",
|
||||
"@llvm-project//mlir:CAPIMathObjects",
|
||||
@ -263,7 +263,12 @@ cc_library(
|
||||
"@stablehlo//:chlo_capi_objects",
|
||||
"@stablehlo//:stablehlo_capi_objects",
|
||||
"@xla//xla/mlir_hlo:CAPIObjects",
|
||||
],
|
||||
] + if_windows(
|
||||
[],
|
||||
[
|
||||
"//jaxlib/triton:triton_dialect_capi_objects",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
|
@ -327,11 +327,17 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}",
|
||||
"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}",
|
||||
],
|
||||
]
|
||||
+ (
|
||||
[]
|
||||
if build_utils.is_windows()
|
||||
else [
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}",
|
||||
"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
triton_dir = jaxlib_dir / "triton"
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//jaxlib:jax.bzl", "pytype_strict_library")
|
||||
load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library")
|
||||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
|
||||
|
||||
licenses(["notice"])
|
||||
@ -33,8 +33,10 @@ pytype_strict_library(
|
||||
deps = [
|
||||
"//jaxlib/mlir:core",
|
||||
"//jaxlib/mlir:ir",
|
||||
"//jaxlib/mlir/_mlir_libs:_triton_ext",
|
||||
],
|
||||
] + if_windows(
|
||||
[],
|
||||
["//jaxlib/mlir/_mlir_libs:_triton_ext"],
|
||||
),
|
||||
)
|
||||
|
||||
genrule(
|
||||
|
Loading…
x
Reference in New Issue
Block a user