diff --git a/mlir/g3doc/Dialects/Standard.md b/mlir/g3doc/Dialects/Standard.md index 70a6d4c4889b..9b1648ef5d11 100644 --- a/mlir/g3doc/Dialects/Standard.md +++ b/mlir/g3doc/Dialects/Standard.md @@ -452,15 +452,38 @@ Example: tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0> ``` +## Unary Operations + +### 'exp' operation + +Syntax: + +``` {.ebnf} +operation ::= ssa-id `=` `exp` ssa-use `:` type +``` + +Examples: + +```mlir {.mlir} +// Scalar natural exponential. +%a = exp %b : f64 + +// SIMD vector element-wise natural exponential. +%f = exp %g : vector<4xf32> + +// Tensor element-wise natural exponential. +%x = exp %y : tensor<4x?xf8> +``` + +The `exp` operation takes one operand and returns one result of the same type. +This type may be a float scalar type, a vector whose element type is float, or a +tensor of floats. It has no standard attributes. + ## Arithmetic Operations Basic arithmetic in MLIR is specified by standard operations described in this section. -TODO: "sub" etc. Let's not get excited about filling this out yet, we can define -these on demand. We should be highly informed by and learn from the operations -supported by HLO and LLVM. - ### 'addi' operation Syntax: @@ -478,7 +501,7 @@ Examples: // SIMD vector element-wise addition, e.g. for Intel SSE. %f = addi %g, %h : vector<4xi32> -// Tensor element-wise addition, analogous to HLO's add operation. +// Tensor element-wise addition. %x = addi %y, %z : tensor<4x?xi8> ``` @@ -504,7 +527,7 @@ Examples: // SIMD vector addition, e.g. for Intel SSE. %f = addf %g, %h : vector<4xf32> -// Tensor addition, analogous to HLO's add operation. +// Tensor addition. %x = addf %y, %z : tensor<4x?xbf16> ``` @@ -757,7 +780,7 @@ Examples: // SIMD pointwise vector multiplication, e.g. for Intel SSE. %f = mulf %g, %h : vector<4xf32> -// Tensor pointwise multiplication, analogous to HLO's pointwise multiply operation. +// Tensor pointwise multiplication. %x = mulf %y, %z : tensor<4x?xbf16> ``` diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 0e1f1a90d985..dd02f751382c 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -72,6 +72,27 @@ class CastOp traits = []> : let hasFolder = 1; } +// Base class for unary ops. Requires single operand and result. Individual +// classes will have `operand` accessor. +class UnaryOp traits = []> : + Op { + let results = (outs AnyType); + let printer = [{ + return printStandardUnaryOp(this->getOperation(), p); + }]; +} + +class UnaryOpSameOperandAndResultType traits = []> : + UnaryOp { + let parser = [{ + return impl::parseOneResultSameOperandTypeOp(parser, result); + }]; +} + +class FloatUnaryOp traits = []> : + UnaryOpSameOperandAndResultType, + Arguments<(ins FloatLike:$operand)>; + // Base class for standard arithmetic operations. Requires operands and // results to be of the same type, but does not constrain them to specific // types. Individual classes will have `lhs` and `rhs` accessor to operands. @@ -597,6 +618,10 @@ def DivIUOp : IntArithmeticOp<"diviu"> { let hasFolder = 1; } +def ExpOp : FloatUnaryOp<"exp"> { + let summary = "base-e exponential of the specified value"; +} + def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { let summary = "element extract operation"; let description = [{ diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index c500e7364fa4..65033b612df1 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1141,13 +1141,17 @@ private: Concept *impl; }; -// These functions are out-of-line implementations of the methods in BinaryOp, -// which avoids them being template instantiated/duplicated. +// These functions are out-of-line implementations of the methods in UnaryOp and +// BinaryOp, which avoids them being template instantiated/duplicated. namespace impl { +ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser, + OperationState &result); + void buildBinaryOp(Builder *builder, OperationState &result, Value *lhs, Value *rhs); ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result); + // Prints the given binary `op` in custom assembly form if both the two operands // and the result have the same time. Otherwise, prints the generic assembly // form. diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 76de499592eb..206dde773e31 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -443,28 +443,43 @@ static SmallVector getCoordinates(ArrayRef basis, return res; } +template struct OpCountValidator { + static_assert( + std::is_base_of< + typename OpTrait::NOperands::template Impl, + SourceOp>::value, + "wrong operand count"); +}; + +template struct OpCountValidator { + static_assert(std::is_base_of, SourceOp>::value, + "expected a single operand"); +}; + +template void ValidateOpCount() { + OpCountValidator(); +} + // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect -// Ops for binary ops with one result. This supports higher-dimensional vector +// Ops for N-ary ops with one result. This supports higher-dimensional vector // types. -template -struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern { +template +struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - using Super = BinaryOpLLVMOpLowering; + using Super = NaryOpLLVMOpLowering; // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - static_assert( - std::is_base_of::Impl, SourceOp>::value, - "expected binary op"); + ValidateOpCount(); static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); static_assert(std::is_base_of, SourceOp>::value, - "expected single result op"); + "expected same operands and result type"); // Cannot convert ops if their operands are not of LLVM type. for (Value *operand : operands) { @@ -489,7 +504,7 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern { arraySizes.push_back(llvmTy.getArrayNumElements()); llvmTy = llvmTy.getArrayElementType(); } - assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type"); + assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type"); auto llvmVectorTy = llvmTy; // Iteratively extract a position coordinates with basis `arraySize` from a @@ -511,13 +526,13 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors - Value *extractedLHS = rewriter.create( - loc, llvmVectorTy, operands[0], position); - Value *extractedRHS = rewriter.create( - loc, llvmVectorTy, operands[1], position); + SmallVector extractedOperands; + for (unsigned i = 0; i < OpCount; ++i) { + extractedOperands.push_back(rewriter.create( + loc, llvmVectorTy, operands[i], position)); + } Value *newVal = rewriter.create( - loc, llvmVectorTy, ArrayRef{extractedLHS, extractedRHS}, - op->getAttrs()); + loc, llvmVectorTy, extractedOperands, op->getAttrs()); desc = rewriter.create(loc, llvmArrayTy, desc, newVal, position); } @@ -526,8 +541,16 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern { } }; +template +using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering; +template +using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering; + // Specific lowerings. // FIXME: this should be tablegen'ed. +struct ExpOpLowering : public UnaryOpLLVMOpLowering { + using Super::Super; +}; struct AddIOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; @@ -1301,18 +1324,49 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed + // clang-format off patterns.insert< - AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, - BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, - CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, - DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering, - DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering, - MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, - RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, - SelectOpLowering, SIToFPLowering, FPExtLowering, FPTruncLowering, - SignExtendIOpLowering, SplatOpLowering, StoreOpLowering, SubFOpLowering, - SubIOpLowering, TruncateIOpLowering, XOrOpLowering, + AddFOpLowering, + AddIOpLowering, + AllocOpLowering, + AndOpLowering, + BranchOpLowering, + CallIndirectOpLowering, + CallOpLowering, + CmpFOpLowering, + CmpIOpLowering, + CondBranchOpLowering, + ConstLLVMOpLowering, + DeallocOpLowering, + DimOpLowering, + DivFOpLowering, + DivISOpLowering, + DivIUOpLowering, + ExpOpLowering, + FPExtLowering, + FPTruncLowering, + FuncOpConversion, + IndexCastOpLowering, + LoadOpLowering, + MemRefCastOpLowering, + MulFOpLowering, + MulIOpLowering, + OrOpLowering, + RemFOpLowering, + RemISOpLowering, + RemIUOpLowering, + ReturnOpLowering, + SIToFPLowering, + SelectOpLowering, + SignExtendIOpLowering, + SplatOpLowering, + StoreOpLowering, + SubFOpLowering, + SubIOpLowering, + TruncateIOpLowering, + XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter); + // clang-format on } // Convert types using the stored LLVM IR module. diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 5cbdb674ef2f..443aa64c5260 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -124,6 +124,19 @@ struct StdInlinerInterface : public DialectInlinerInterface { // StandardOpsDialect //===----------------------------------------------------------------------===// +/// A custom unary operation printer that omits the "std." prefix from the +/// operation names. +static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) { + assert(op->getNumOperands() == 1 && "unary op should have one operand"); + assert(op->getNumResults() == 1 && "unary op should have one result"); + + const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' + << *op->getOperand(0); + p.printOptionalAttrDict(op->getAttrs()); + p << " : " << op->getOperand(0)->getType(); +} + /// A custom binary operation printer that omits the "std." prefix from the /// operation names. static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { @@ -139,7 +152,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { return; } - p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' + const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1); p.printOptionalAttrDict(op->getAttrs()); @@ -150,7 +164,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { /// A custom cast operation printer that omits the "std." prefix from the /// operation names. static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { - p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' + const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 5fd51bd1f06f..fb23a76cf255 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -421,6 +421,8 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) { %12 = or %arg2, %arg3 : i32 // CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32 %13 = xor %arg2, %arg3 : i32 +// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float + %14 = std.exp %arg0 : f32 return %0, %4 : f32, i32 } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index abb731d25d63..417068a7facf 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -351,6 +351,14 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) { // CHECK: = fptrunc {{.*}} : f32 to f16 %95 = fptrunc %f : f32 to f16 + // CHECK: %{{[0-9]+}} = exp %arg1 : f32 + %96 = "std.exp"(%f) : (f32) -> f32 + + // CHECK: %{{[0-9]+}} = exp %arg1 : f32 + %97 = exp %f : f32 + + // CHECK: %{{[0-9]+}} = exp %arg0 : tensor<4x4x?xf32> + %98 = exp %t : tensor<4x4x?xf32> return }