//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/IR/Attributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Types.h"

using namespace mlir;
using namespace mlir::detail;

// NOTE(review): template argument lists appear to be missing throughout this
// text (e.g. `dyn_cast()`, `static_cast(attr)`, `SmallVector remappedElts`).
// Presumably they were stripped by a text conversion step; confirm against the
// upstream file before building.

/// Return the kind tag stored in the underlying attribute storage.
Attribute::Kind Attribute::getKind() const { return attr->kind; }

/// Return the cached flag, stored on the attribute storage, that records
/// whether this attribute is or contains a function attribute.
bool Attribute::isOrContainsFunction() const {
  return attr->isOrContainsFunctionCache;
}

// Given an attribute that could refer to a function attribute in the remapping
// table, walk it and rewrite it to use the mapped function.  If it doesn't
// refer to anything in the table, then it is returned unmodified.
//
// `remappingTable` is searched with the function attribute as key, and the
// mapped value replaces it.  `context` is only used when a new array attribute
// has to be built for remapped elements.
Attribute Attribute::remapFunctionAttrs(
    const llvm::DenseMap &remappingTable, MLIRContext *context) const {
  // Most attributes are trivially unrelated to function attributes, skip them
  // rapidly.
  if (!isOrContainsFunction())
    return *this;

  // If we have a function attribute, remap it.  An attribute with no entry in
  // the table is returned as-is.
  if (auto fnAttr = this->dyn_cast()) {
    auto it = remappingTable.find(fnAttr);
    return it != remappingTable.end() ? it->second : *this;
  }

  // Otherwise, we must have an array attribute, remap the elements.
  auto arrayAttr = this->cast();
  SmallVector remappedElts;
  bool anyChange = false;
  for (auto elt : arrayAttr.getValue()) {
    // Recurse so that nested arrays are remapped as well.
    auto newElt = elt.remapFunctionAttrs(remappingTable, context);
    remappedElts.push_back(newElt);
    anyChange |= (elt != newElt);
  }

  // Return the existing attribute when no element changed, so no new array
  // attribute is built.
  if (!anyChange)
    return *this;

  return ArrayAttr::get(remappedElts, context);
}

/// NumericAttr

/// Return the type of the numeric attribute by forwarding to the getType() of
/// whichever subclass this attribute actually is (bool, integer, float or
/// elements, tried in that order).  Any other kind is treated as unreachable.
Type NumericAttr::getType() const {
  if (auto boolAttr = dyn_cast())
    return boolAttr.getType();
  if (auto intAttr = dyn_cast())
    return intAttr.getType();
  if (auto floatAttr = dyn_cast())
    return floatAttr.getType();
  if (auto elemAttr = dyn_cast())
    return elemAttr.getType();
  llvm_unreachable("unhandled NumericAttr subclass");
}

/// Return true if `kind` is accepted by BoolAttr, IntegerAttr, FloatAttr or
/// ElementsAttr.
bool NumericAttr::kindof(Kind kind) {
  return BoolAttr::kindof(kind) || IntegerAttr::kindof(kind) ||
         FloatAttr::kindof(kind) || ElementsAttr::kindof(kind);
}

/// BoolAttr

/// Return the `value` field of the attribute storage.
bool BoolAttr::getValue() const { return static_cast(attr)->value; }

/// Return the `type` field of the attribute storage.
Type BoolAttr::getType() const { return static_cast(attr)->type; }

/// IntegerAttr

/// Return the integer value as an APInt, as produced by the storage's
/// getValue().
APInt IntegerAttr::getValue() const { return static_cast(attr)->getValue(); }

/// Return the value sign-extended to a 64-bit integer (APInt::getSExtValue).
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }

/// Return the `type` field of the attribute storage.
Type IntegerAttr::getType() const { return static_cast(attr)->type; }

/// FloatAttr

/// Return the floating point value as an APFloat, as produced by the storage's
/// getValue().
APFloat FloatAttr::getValue() const { return static_cast(attr)->getValue(); }

/// Return the `type` field of the attribute storage.
Type FloatAttr::getType() const { return static_cast(attr)->type; }

/// Return the value as a native double.  When the float semantics of the
/// attribute's type are not IEEE double, the value is first converted to IEEE
/// double using round-to-nearest-ties-to-even; whether that conversion lost
/// information is not reported to the caller.
double FloatAttr::getValueAsDouble() const {
  const auto &semantics = getType().cast().getFloatSemantics();
  auto value = getValue();
  bool losesInfo = false; // ignored
  // Semantics objects are compared by address.
  if (&semantics != &APFloat::IEEEdouble()) {
    value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
                  &losesInfo);
  }
  return value.convertToDouble();
}

/// StringAttr

/// Return the `value` field of the attribute storage.
StringRef StringAttr::getValue() const { return static_cast(attr)->value; }

/// ArrayAttr

/// Return the `value` field of the attribute storage (the held elements).
ArrayRef ArrayAttr::getValue() const { return static_cast(attr)->value; }

/// AffineMapAttr

/// Return the `value` field of the attribute storage.
AffineMap AffineMapAttr::getValue() const { return static_cast(attr)->value; }

/// IntegerSetAttr

// NOTE(review): template argument lists appear to be missing in the
// definitions below (e.g. `static_cast(attr)`, `cast()`, `ArrayRef index`).
// Presumably they were stripped by a text conversion step; confirm against the
// upstream file before building.

/// Return the `value` field of the attribute storage.
IntegerSet IntegerSetAttr::getValue() const { return static_cast(attr)->value; }

/// TypeAttr

/// Return the `value` field of the attribute storage.
Type TypeAttr::getValue() const { return static_cast(attr)->value; }

/// FunctionAttr

/// Return the function pointer held in the `value` field of the storage.
Function *FunctionAttr::getValue() const { return static_cast(attr)->value; }

/// Return the type of the referenced function.  The held function pointer is
/// dereferenced without a null check.
FunctionType FunctionAttr::getType() const { return getValue()->getType(); }

/// ElementsAttr

/// Return the `type` field of the attribute storage.
VectorOrTensorType ElementsAttr::getType() const {
  return static_cast(attr)->type;
}

/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
///
/// Dispatches on the attribute kind to the matching subclass getValue().  A
/// splat ignores `index` and returns its single element.
Attribute ElementsAttr::getValue(ArrayRef index) const {
  switch (getKind()) {
  case Attribute::Kind::SplatElements:
    return cast().getValue();
  case Attribute::Kind::DenseFPElements:
  case Attribute::Kind::DenseIntElements:
    return cast().getValue(index);
  case Attribute::Kind::OpaqueElements:
    return cast().getValue(index);
  case Attribute::Kind::SparseElements:
    return cast().getValue(index);
  default:
    llvm_unreachable("unknown ElementsAttr kind");
  }
}

/// SplatElementsAttr

/// Return the `elt` field of the attribute storage.
Attribute SplatElementsAttr::getValue() const {
  return static_cast(attr)->elt;
}

/// DenseElementsAttr

/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
///
/// `index` must have one entry per dimension of the held type, each strictly
/// smaller than the corresponding dimension size.  The element is read from
/// the packed raw data and wrapped in an IntegerAttr or FloatAttr depending on
/// the attribute kind.
Attribute DenseElementsAttr::getValue(ArrayRef index) const {
  auto type = getType();

  // Verify that the rank of the indices matches the held type.
  auto rank = type.getRank();
  if (rank != index.size())
    return Attribute();

  // Verify that all of the indices are within the shape dimensions.
  auto shape = type.getShape();
  for (unsigned i = 0; i != rank; ++i)
    if (shape[i] <= index[i])
      return Attribute();

  // Reduce the provided multidimensional index into a 1D index.  The last
  // dimension varies fastest: it gets multiplier 1 and each earlier dimension
  // is scaled by the product of the sizes after it.
  // NOTE(review): the `i >= 0` condition only terminates if `rank` (and hence
  // `i`) has a signed type; getRank()'s return type is not visible here —
  // confirm.
  uint64_t valueIndex = 0;
  uint64_t dimMultiplier = 1;
  for (auto i = rank - 1; i >= 0; --i) {
    valueIndex += index[i] * dimMultiplier;
    dimMultiplier *= shape[i];
  }

  // Return the element stored at the 1D index.
  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
  // with double semantics.
  auto elementType = getType().getElementType();
  size_t bitWidth =
      elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
  APInt rawValueData =
      readBits(getRawData().data(), valueIndex * bitWidth, bitWidth);

  // Convert the raw value data to an attribute value.
  switch (getKind()) {
  case Attribute::Kind::DenseIntElements:
    return IntegerAttr::get(elementType, rawValueData);
  case Attribute::Kind::DenseFPElements:
    return FloatAttr::get(
        elementType,
        APFloat(elementType.cast().getFloatSemantics(), rawValueData));
  default:
    llvm_unreachable("unexpected element type");
  }
}

/// Append every element of this attribute to `values`, each wrapped as an
/// IntegerAttr or FloatAttr of the element type depending on the attribute
/// kind.  `values` is appended to, not cleared.
void DenseElementsAttr::getValues(SmallVectorImpl &values) const {
  auto elementType = getType().getElementType();
  switch (getKind()) {
  case Attribute::Kind::DenseIntElements: {
    // Get the raw APInt values.
    SmallVector intValues;
    cast().getValues(intValues);

    // Convert each to an IntegerAttr.
    for (auto &intVal : intValues)
      values.push_back(IntegerAttr::get(elementType, intVal));
    return;
  }
  case Attribute::Kind::DenseFPElements: {
    // Get the raw APFloat values.
    SmallVector floatValues;
    cast().getValues(floatValues);

    // Convert each to a FloatAttr.
    for (auto &floatVal : floatValues)
      values.push_back(FloatAttr::get(elementType, floatVal));
    return;
  }
  default:
    llvm_unreachable("unexpected element type");
  }
}

/// Return the `data` field of the attribute storage (the packed element bits).
ArrayRef DenseElementsAttr::getRawData() const {
  return static_cast(attr)->data;
}

// Constructs a dense elements attribute from an array of raw APInt values.
// Each APInt value is expected to have the same bitwidth as the element type
// of 'type'.
//
// The values are packed back to back, `bitWidth` bits each, into a buffer
// whose size is rounded up to a whole number of APInt words.
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
                                         ArrayRef values) {
  assert(values.size() == type.getNumElements() &&
         "expected 'values' to contain the same number of elements as 'type'");

  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
  // with double semantics.
  auto eltType = type.getElementType();
  size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
  std::vector elementData(APInt::getNumWords(bitWidth * values.size()) *
                          APInt::APINT_WORD_SIZE);
  for (unsigned i = 0, e = values.size(); i != e; ++i) {
    assert(values[i].getBitWidth() == bitWidth);
    writeBits(elementData.data(), i * bitWidth, values[i]);
  }
  return get(type, elementData);
}

/// Parses the raw integer internal value for each dense element into
/// 'values'.  One APInt of the element bit width is appended per element, in
/// storage order.
void DenseElementsAttr::getRawValues(SmallVectorImpl &values) const {
  auto elementType = getType().getElementType();
  auto elementNum = getType().getNumElements();
  values.reserve(elementNum);

  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
  // with double semantics.
  size_t bitWidth =
      elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
  const auto *rawData = getRawData().data();
  for (size_t i = 0, e = elementNum; i != e; ++i)
    values.push_back(readBits(rawData, i * bitWidth, bitWidth));
}

/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
/// expected to be a 64-bit aligned storage address.
///
/// Three paths are used: single-bit values set or clear one bit; byte-aligned
/// values are copied byte-wise; anything else goes through a temporary APInt
/// covering the affected 64-bit words.
void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) {
  size_t bitWidth = value.getBitWidth();

  // If the bitwidth is 1 we just toggle the specific bit.
  if (bitWidth == 1) {
    auto *rawIntData = reinterpret_cast(rawData);
    if (value.isOneValue())
      APInt::tcSetBit(rawIntData, bitPos);
    else
      APInt::tcClearBit(rawIntData, bitPos);
    return;
  }

  // If the bit position and width are byte aligned, write the storage directly
  // to the data.
  if ((bitWidth % 8) == 0 && (bitPos % 8) == 0) {
    std::copy_n(reinterpret_cast(value.getRawData()), bitWidth / 8,
                rawData + (bitPos / 8));
    return;
  }

  // Otherwise, convert the raw data into an APInt and insert the value at the
  // specified bit position.  The view starts at the 64-bit word containing
  // `bitPos` and spans enough words to hold the in-word offset plus the value.
  size_t totalWords = APInt::getNumWords((bitPos % 64) + bitWidth);
  llvm::MutableArrayRef rawIntData(
      reinterpret_cast(rawData) + (bitPos / 64), totalWords);
  APInt tempStorage(totalWords * 64, rawIntData);
  tempStorage.insertBits(value, bitPos % 64);

  // Copy the value back to the raw data.
  std::copy_n(tempStorage.getRawData(), rawIntData.size(), rawIntData.data());
}

/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address.
/// The bits are returned as an APInt of width `bitWidth`.
APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
                                  size_t bitWidth) {
  // Reinterpret the raw data as a uint64_t word array and extract the value
  // starting at 'bitPos'.  The const_cast lets tcExtract write straight into
  // the words owned by `result`.
  APInt result(bitWidth, 0);
  const uint64_t *intData = reinterpret_cast(rawData);
  APInt::tcExtract(const_cast(result.getRawData()), result.getNumWords(),
                   intData, bitWidth, bitPos);
  return result;
}

/// DenseIntElementsAttr

/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
                                               ArrayRef values) {
  return DenseElementsAttr::get(type, values).cast();
}

/// Constructs a dense integer elements attribute from an array of integer
/// values. Each value is expected to be within the bitwidth of the element
/// type of 'type'.
DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
                                               ArrayRef values) {
  auto eltType = type.getElementType();
  size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();

  // Convert the raw integer values to APInt.
  SmallVector apIntValues;
  apIntValues.reserve(values.size());
  for (auto value : values)
    apIntValues.emplace_back(APInt(bitWidth, value));
  return get(type, apIntValues);
}

/// Append the raw integer value of every element to `values`.
void DenseIntElementsAttr::getValues(SmallVectorImpl &values) const {
  // Simply return the raw integer values.
  getRawValues(values);
}

/// DenseFPElementsAttr

// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
DenseFPElementsAttr DenseFPElementsAttr::get(VectorOrTensorType type,
                                             ArrayRef values) {
  // Convert the APFloat values to APInt and create a dense elements attribute.
  std::vector intValues(values.size());
  for (unsigned i = 0, e = values.size(); i != e; ++i)
    intValues[i] = values[i].bitcastToAPInt();
  return DenseElementsAttr::get(type, intValues).cast();
}

/// Append every element to `values` as an APFloat built from the element's raw
/// bits and the float semantics of the element type.
void DenseFPElementsAttr::getValues(SmallVectorImpl &values) const {
  // Get the raw APInt element values.
  SmallVector intValues;
  getRawValues(intValues);

  // Convert each of the APInt values to an APFloat.
  // NOTE(review): the dyn_cast result is used without a null check; a plain
  // cast would presumably state the intent better — confirm.
  auto elementType = getType().getElementType().dyn_cast();
  const auto &elementSemantics = elementType.getFloatSemantics();
  for (auto &intValue : intValues)
    values.push_back(APFloat(elementSemantics, intValue));
}

/// OpaqueElementsAttr

/// Return the `bytes` field of the attribute storage.
StringRef OpaqueElementsAttr::getValue() const {
  return static_cast(attr)->bytes;
}

/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
///
/// The lookup is delegated to the owning dialect's extractElementHook; a null
/// attribute is returned when no dialect is attached.
Attribute OpaqueElementsAttr::getValue(ArrayRef index) const {
  if (Dialect *dialect = getDialect())
    return dialect->extractElementHook(*this, index);
  return Attribute();
}

/// Return the `dialect` field of the attribute storage; callers in this file
/// treat it as possibly null.
Dialect *OpaqueElementsAttr::getDialect() const {
  return static_cast(attr)->dialect;
}

/// Ask the owning dialect's decodeHook to decode this attribute into `result`
/// and return the hook's result.  Returns true when no dialect is attached.
/// NOTE(review): true presumably signals failure here — confirm against the
/// decodeHook contract.
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
  if (auto *d = getDialect())
    return d->decodeHook(*this, result);
  return true;
}

/// SparseElementsAttr

/// Return the `indices` field of the attribute storage.
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
  return static_cast(attr)->indices;
}

/// Return the `values` field of the attribute storage.
DenseElementsAttr SparseElementsAttr::getValues() const {
  return static_cast(attr)->values;
}

/// Return the value of the element at the given index.
// NOTE(review): template argument lists appear to be missing in the
// definitions below (e.g. `reinterpret_cast(...)`, `isa()`,
// `SmallVector newAttrs`).  Presumably they were stripped by a text conversion
// step; confirm against the upstream file before building.
//
// A null attribute is returned when the number of entries in `index` differs
// from the rank of the held type.  An index with no stored element yields a
// zero attribute of the element type (float or integer).
Attribute SparseElementsAttr::getValue(ArrayRef index) const {
  auto type = getType();

  // Verify that the rank of the indices matches the held type.
  auto rank = type.getRank();
  if (rank != index.size())
    return Attribute();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  const uint64_t *sparseIndexValues =
      reinterpret_cast(sparseIndices.getRawData().data());

  // Build a mapping between known indices and the offset of the stored element.
  // Each key is a view of `rank` consecutive values in the raw index data.
  // The map is a local, so it is rebuilt on every call.
  llvm::SmallDenseMap, size_t> mappedIndices;
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    mappedIndices.try_emplace(
        {sparseIndexValues + (i * rank), static_cast(rank)}, i);

  // Look for the provided index key within the mapped indices. If the provided
  // index is not found, then return a zero attribute.
  auto it = mappedIndices.find(index);
  if (it == mappedIndices.end()) {
    auto eltType = type.getElementType();
    if (eltType.isa())
      return FloatAttr::get(eltType, 0);
    assert(eltType.isa() && "unexpected element type");
    return IntegerAttr::get(eltType, 0);
  }

  // Otherwise, return the held sparse value element.
  return getValues().getValue(it->second);
}

/// NamedAttributeList

/// Construct a list holding `attributes`; see setAttrs for the handling of an
/// empty input.
NamedAttributeList::NamedAttributeList(MLIRContext *context,
                                       ArrayRef attributes) {
  setAttrs(context, attributes);
}

/// Return all of the attributes held by this list, or an empty range when no
/// storage is attached.
ArrayRef NamedAttributeList::getAttrs() const {
  return attrs ? attrs->getElements() : llvm::None;
}

/// Replace the held attributes with the ones provided in 'attributes'.  Every
/// entry must have a non-null attribute value.
void NamedAttributeList::setAttrs(MLIRContext *context, ArrayRef attributes) {
  // Don't create an attribute list if there are no attributes.
  if (attributes.empty()) {
    attrs = nullptr;
    return;
  }

  assert(llvm::all_of(attributes,
                      [](const NamedAttribute &attr) { return attr.second; }) &&
         "attributes cannot have null entries");
  attrs = AttributeListStorage::get(attributes, context);
}

/// Return the specified attribute if present, null otherwise.  The held
/// entries are scanned linearly and the first match by name is returned.
Attribute NamedAttributeList::get(StringRef name) const {
  for (auto elt : getAttrs())
    if (elt.first.is(name))
      return elt.second;
  return nullptr;
}

/// Return the specified attribute if present, null otherwise.  Same linear
/// scan as the StringRef overload, comparing identifiers directly.
Attribute NamedAttributeList::get(Identifier name) const {
  for (auto elt : getAttrs())
    if (elt.first == name)
      return elt.second;
  return nullptr;
}

/// If an attribute exists with the specified name, change it to the new
/// value.  Otherwise, add a new attribute with the specified name/value.
/// `value` must be non-null.
void NamedAttributeList::set(MLIRContext *context, Identifier name,
                             Attribute value) {
  assert(value && "attributes may never be null");

  // If we already have this attribute, replace it.  The held elements are
  // copied into a local vector, edited there, and a storage is then obtained
  // for the edited list.
  auto origAttrs = getAttrs();
  SmallVector newAttrs(origAttrs.begin(), origAttrs.end());
  for (auto &elt : newAttrs)
    if (elt.first == name) {
      elt.second = value;
      attrs = AttributeListStorage::get(newAttrs, context);
      return;
    }

  // Otherwise, add it.
  newAttrs.push_back({name, value});
  attrs = AttributeListStorage::get(newAttrs, context);
}

/// Remove the attribute with the specified name if it exists.  The return
/// value indicates whether the attribute was present or not.
auto NamedAttributeList::remove(MLIRContext *context, Identifier name)
    -> RemoveResult {
  auto origAttrs = getAttrs();
  for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
    if (origAttrs[i].first == name) {
      // Build the new list from the elements before and after position i.
      // NOTE(review): when the last remaining attribute is removed, an empty
      // list is passed to AttributeListStorage::get rather than resetting
      // `attrs` to null as setAttrs does — confirm that get() handles this.
      SmallVector newAttrs;
      newAttrs.reserve(origAttrs.size() - 1);
      newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
      newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
      attrs = AttributeListStorage::get(newAttrs, context);
      return RemoveResult::Removed;
    }
  }
  return RemoveResult::NotFound;
}