From c12cb0ccbb408c5e65801a6aa5a8e53b8b36c646 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 11 Apr 2025 08:55:38 -0400 Subject: [PATCH] [mlir][python] fix value-builder generation for snake_case ops (#135302) Ops that are already snake case (like [`ROCDL_wmma_*` ops](https://github.com/makslevental/llvm-project/blob/66b0b0466bbd995146aadaf2cd18de5476c19941/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td#L411)) produce python "value-builders" that collide with the class names: ```python class wmma_bf16_16x16x16_bf16(_ods_ir.OpView): OPERATION_NAME = "rocdl.wmma.bf16.16x16x16.bf16" ... def wmma_bf16_16x16x16_bf16(res, args, *, loc=None, ip=None) -> _ods_ir.Value: return wmma_bf16_16x16x16_bf16(res=res, args=args, loc=loc, ip=ip).result ``` and thus cannot be emitted (because of recursive self-calls). This PR fixes that by affixing `_` to the value builder names. I would've preferred to just rename the ops but that would be a breaking change :shrug:. --- mlir/test/mlir-tblgen/op-python-bindings.td | 6 +++++ mlir/test/python/dialects/rocdl.py | 25 ++++++++++++++++--- mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 2 ++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 72963cac64d5..c2bd86819666 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -654,3 +654,9 @@ def WithSuccessorsOp : TestOp<"with_successors"> { // CHECK: def with_successors(successor, successors, *, loc=None, ip=None) // CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip) + +// CHECK: class snake_case(_ods_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.snake_case" +def already_snake_case : TestOp<"snake_case"> {} +// CHECK: def snake_case_(*, loc=None, ip=None) +// CHECK: return snake_case(loc=loc, ip=ip) diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py index a4eca2766899..a4a50afa966c 100644 --- a/mlir/test/python/dialects/rocdl.py +++ b/mlir/test/python/dialects/rocdl.py @@ -1,8 +1,10 @@ # RUN: %PYTHON %s | FileCheck %s # This is just a smoke test that the dialect is functional. +from array import array from mlir.ir import * -from mlir.dialects import rocdl +from mlir.dialects import rocdl, arith +from mlir.extras import types as T def constructAndPrintInModule(f): @@ -18,5 +20,22 @@ def constructAndPrintInModule(f): # CHECK-LABEL: testSmoke @constructAndPrintInModule def testSmoke(): - # CHECK: rocdl.barrier - rocdl.BarrierOp() + v_len = 16 + f32 = F32Type.get() + # Note: this isn't actually the right type for the intrinsic (should be f16) + # but array doesn't support f16. + v16f32 = T.vector(v_len, f32) + f32_array = array("f", [0.0] * v_len) + a_frag = arith.constant(v16f32, f32_array) + b_frag = arith.constant(v16f32, f32_array) + c_frag = arith.constant(v16f32, f32_array) + false = arith.constant(T.bool(), False) + + c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false]) + # CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16 + print(c_frag) + assert isinstance(c_frag, OpView) + # CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16 + c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false]) + print(c_frag) + assert isinstance(c_frag, Value) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 604d2376052a..d2e38e9d2319 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -1000,6 +1000,8 @@ static void emitValueBuilder(const Operator &op, }); std::string nameWithoutDialect = sanitizeName( op.getOperationName().substr(op.getOperationName().find('.') + 1)); + if (nameWithoutDialect == op.getCppClassName()) + nameWithoutDialect += "_"; std::string params = llvm::join(valueBuilderParams, ", "); std::string args = llvm::join(opBuilderArgs, ", "); const char *type =