//===- Traits.cpp - Common op traits shared by dialects ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the broadcasted result shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions and working backward.
  // Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1.
  // The result shape has the maximum of the two inputs at every dimension
  // index.

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check that each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (*i1 == -1 || *i2 == -1) {
      // One or both dimensions are unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = -1;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}
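// Illustrative sketch, not part of the original file: a hypothetical caller
// of getBroadcastedShape with made-up shapes, showing the trailing-dimension
// alignment described above. A dynamic dimension (-1) paired with 1 stays
// dynamic, while a dynamic dimension paired with a size > 1 is assumed to
// broadcast to it.
//
//   SmallVector<int64_t, 4> result;
//   bool ok = OpTrait::util::getBroadcastedShape({2, 1, 7, 1}, {3, 7, 5},
//                                                result);
//   // ok == true, result == {2, 3, 7, 5}
//   ok = OpTrait::util::getBroadcastedShape({-1, 3}, {1, 3}, result);
//   // ok == true, result == {-1, 3}
//   ok = OpTrait::util::getBroadcastedShape({2, 3}, {4, 3}, result);
//   // ok == false (2 vs. 4 is incompatible), result is cleared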
/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = type.dyn_cast<ShapedType>())
    return sType.getShape();
  return {};
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. The returned type may have dynamic
/// shape if either of the input types has dynamic shape. Returns a null type
/// if the two given types are not broadcast-compatible.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
  // Returns the scalar type out of the given type.
  auto getScalarType = [](Type type) -> Type {
    if (auto shapedType = type.dyn_cast<ShapedType>())
      return shapedType.getElementType();
    return type;
  };

  // Make sure the underlying scalar type is the same.
  auto scalarType = getScalarType(type1);
  if (scalarType != getScalarType(type2))
    return {};

  // If one of the types is an unranked tensor, then the other type shouldn't
  // be a vector, and the result should have unranked tensor type.
  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
    if (type1.isa<VectorType>() || type2.isa<VectorType>())
      return {};
    return UnrankedTensorType::get(scalarType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns llvm::None otherwise.
  auto getCompositeTypeKind = [](Type type) -> llvm::Optional<unsigned> {
    if (type.isa<VectorType>() || type.isa<RankedTensorType>())
      return static_cast<unsigned>(type.getKind());
    return llvm::None;
  };

  // Make sure the composite type, if present, is consistent.
  auto compositeKind1 = getCompositeTypeKind(type1);
  auto compositeKind2 = getCompositeTypeKind(type2);
  llvm::Optional<unsigned> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Compute the broadcasted shape from the shapes of the two types.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type.
  if (resultCompositeKind == StandardTypes::Vector)
    return VectorType::get(resultShape, scalarType);
  if (resultCompositeKind == StandardTypes::RankedTensor)
    return RankedTensorType::get(resultShape, scalarType);
  return scalarType;
}
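// Illustrative sketch, not part of the original file: expected results of
// getBroadcastedType on a few hypothetical type pairs. Element types must
// match, vectors and tensors never mix, and an unranked tensor operand makes
// the result unranked.
//
//   (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
//   (tensor<?x3xf32>, f32)             -> tensor<?x3xf32>
//   (tensor<2x3xf32>, tensor<*xf32>)   -> tensor<*xf32>
//   (vector<2xf32>,   tensor<2xf32>)   -> null Type (vector/tensor mix)
//   (tensor<2xf32>,   tensor<2xi32>)   -> null Type (element type mismatch)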
/// Returns true if the given list of types contains both vector types and
/// tensor types.
static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
}

static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
                                ArrayRef<int64_t> shape2) {
  auto isCompatible = [](int64_t dim1, int64_t dim2) {
    return dim1 == dim2 || dim1 == -1 || dim2 == -1;
  };
  if (shape1.size() != shape2.size())
    return false;
  for (const auto &p : llvm::zip(shape1, shape2))
    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
      return false;
  return true;
}

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  assert(op->getNumOperands() == 2 &&
         "only support broadcast check on two operands");
  assert(op->getNumResults() == 1 &&
         "only support broadcast check on one result");

  auto type1 = op->getOperand(0)->getType();
  auto type2 = op->getOperand(1)->getType();
  auto retType = op->getResult(0)->getType();

  // We forbid broadcasting vector with tensor.
  if (hasBothVectorAndTensorType({type1, type2, retType}))
    return op->emitError("cannot broadcast vector with tensor");

  if (retType.isa<UnrankedTensorType>())
    return success();

  bool isUnranked1 = type1.isa<UnrankedTensorType>();
  bool isUnranked2 = type2.isa<UnrankedTensorType>();

  // If both operands are unranked, then all result shapes are possible.
  if (isUnranked1 && isUnranked2)
    return success();

  // If one of the operands is unranked, then the known dimensions in the
  // result should be compatible with the other, ranked operand.
  if (isUnranked1 || isUnranked2) {
    // The result should have at least the rank of the ranked operand, and its
    // trailing dimensions should be compatible with that operand's shape.
    ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
    ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
    if (!areCompatibleShapes(actualSuffix, shape))
      return op->emitOpError()
             << "result type " << retType
             << " has shape incompatible with a ranked operand type";
    return success();
  }

  // If both operands are ranked, then the computed broadcasted shape should be
  // compatible with the result shape.
  SmallVector<int64_t, 4> resultShape;
  if (!util::getBroadcastedShape(getShape(type1), getShape(type2),
                                 resultShape))
    return op->emitOpError("operands don't have broadcast-compatible shapes");
  if (!areCompatibleShapes(resultShape, getShape(retType)))
    return op->emitOpError()
           << "result type " << retType
           << " does not have shape compatible with the one "
              "computed from the operand types";

  return success();
}
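// Illustrative sketch, not part of the original file: a hypothetical trait
// (MyBroadcastableTrait is a made-up name) showing how an op definition could
// route its verification through verifyCompatibleOperandBroadcast above via
// the standard verifyTrait hook.
//
//   template <typename ConcreteType>
//   class MyBroadcastableTrait
//       : public OpTrait::TraitBase<ConcreteType, MyBroadcastableTrait> {
//   public:
//     static LogicalResult verifyTrait(Operation *op) {
//       return OpTrait::impl::verifyCompatibleOperandBroadcast(op);
//     }
//   };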