mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-14 15:46:32 +00:00
[mlir][python] fix value-builder generation for snake_case ops (#135302)
Ops that are already snake case (like [`ROCDL_wmma_*`
ops](66b0b0466b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (L411)
))
produce python "value-builders" that collide with the class names:
```python
class wmma_bf16_16x16x16_bf16(_ods_ir.OpView):
OPERATION_NAME = "rocdl.wmma.bf16.16x16x16.bf16"
...
def wmma_bf16_16x16x16_bf16(res, args, *, loc=None, ip=None) -> _ods_ir.Value:
return wmma_bf16_16x16x16_bf16(res=res, args=args, loc=loc, ip=ip).result
```
and thus cannot be emitted (because of recursive self-calls).
This PR fixes that by affixing `_` to the value builder names.
I would've preferred to just rename the ops but that would be a breaking
change 🤷.
This commit is contained in:
parent
dda53bef35
commit
c12cb0ccbb
@ -654,3 +654,9 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
|
||||
|
||||
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
|
||||
// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
|
||||
|
||||
// CHECK: class snake_case(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.snake_case"
|
||||
def already_snake_case : TestOp<"snake_case"> {}
|
||||
// CHECK: def snake_case_(*, loc=None, ip=None)
|
||||
// CHECK: return snake_case(loc=loc, ip=ip)
|
||||
|
@ -1,8 +1,10 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This is just a smoke test that the dialect is functional.
|
||||
from array import array
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import rocdl
|
||||
from mlir.dialects import rocdl, arith
|
||||
from mlir.extras import types as T
|
||||
|
||||
|
||||
def constructAndPrintInModule(f):
|
||||
@ -18,5 +20,22 @@ def constructAndPrintInModule(f):
|
||||
# CHECK-LABEL: testSmoke
|
||||
@constructAndPrintInModule
|
||||
def testSmoke():
|
||||
# CHECK: rocdl.barrier
|
||||
rocdl.BarrierOp()
|
||||
v_len = 16
|
||||
f32 = F32Type.get()
|
||||
# Note: this isn't actually the right type for the intrinsic (should be f16)
|
||||
# but array doesn't support f16.
|
||||
v16f32 = T.vector(v_len, f32)
|
||||
f32_array = array("f", [0.0] * v_len)
|
||||
a_frag = arith.constant(v16f32, f32_array)
|
||||
b_frag = arith.constant(v16f32, f32_array)
|
||||
c_frag = arith.constant(v16f32, f32_array)
|
||||
false = arith.constant(T.bool(), False)
|
||||
|
||||
c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false])
|
||||
# CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16
|
||||
print(c_frag)
|
||||
assert isinstance(c_frag, OpView)
|
||||
# CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16
|
||||
c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false])
|
||||
print(c_frag)
|
||||
assert isinstance(c_frag, Value)
|
||||
|
@ -1000,6 +1000,8 @@ static void emitValueBuilder(const Operator &op,
|
||||
});
|
||||
std::string nameWithoutDialect = sanitizeName(
|
||||
op.getOperationName().substr(op.getOperationName().find('.') + 1));
|
||||
if (nameWithoutDialect == op.getCppClassName())
|
||||
nameWithoutDialect += "_";
|
||||
std::string params = llvm::join(valueBuilderParams, ", ");
|
||||
std::string args = llvm::join(opBuilderArgs, ", ");
|
||||
const char *type =
|
||||
|
Loading…
x
Reference in New Issue
Block a user