[mlir] DenseStringElementsAttr added to default attribute types

Summary:
Implemented a DenseStringsElements attr for handling arrays / tensors of strings. This includes the
necessary logic for parsing and printing the attribute from MLIR's text format.

To store the attribute we perform a single allocation that includes all wrapped string data tightly packed.
This means no padding characters and no null terminators (as they could be present in the string). This
buffer includes a first chunk of data that represents an array of StringRefs, that contain address pointers
into the string data, with the length of each string wrapped. At this point there is no Sparse representation
however strings are not typically represented sparsely.

Differential Revision: https://reviews.llvm.org/D78600
This commit is contained in:
Rob Suderman 2020-04-23 19:01:51 -07:00 committed by River Riddle
parent 2c3ee8812c
commit 5b89c1dd68
11 changed files with 480 additions and 95 deletions

View File

@ -40,7 +40,8 @@ struct SymbolRefAttributeStorage;
struct TypeAttributeStorage;
/// Elements Attributes.
struct DenseElementsAttributeStorage;
struct DenseIntOrFPElementsAttributeStorage;
struct DenseStringElementsAttributeStorage;
struct OpaqueElementsAttributeStorage;
struct SparseElementsAttributeStorage;
} // namespace detail
@ -141,10 +142,11 @@ enum Kind {
Unit,
/// Elements Attributes.
DenseElements,
DenseIntOrFPElements,
DenseStringElements,
OpaqueElements,
SparseElements,
FIRST_ELEMENTS_ATTR = DenseElements,
FIRST_ELEMENTS_ATTR = DenseIntOrFPElements,
LAST_ELEMENTS_ATTR = SparseElements,
/// Locations.
@ -671,15 +673,14 @@ protected:
/// An attribute that represents a reference to a dense vector or tensor object.
///
class DenseElementsAttr
: public Attribute::AttrBase<DenseElementsAttr, ElementsAttr,
detail::DenseElementsAttributeStorage> {
class DenseElementsAttr : public ElementsAttr {
public:
using Base::Base;
using ElementsAttr::ElementsAttr;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() == StandardAttributes::DenseElements;
return attr.getKind() == StandardAttributes::DenseIntOrFPElements ||
attr.getKind() == StandardAttributes::DenseStringElements;
}
/// Constructs a dense elements attribute from an array of element values.
@ -712,6 +713,10 @@ public:
/// Overload of the above 'get' method that is specialized for boolean values.
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
/// Overload of the above 'get' method that is specialized for StringRef
/// values.
static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values);
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
@ -882,6 +887,14 @@ public:
ElementIterator<T>(rawData, splat, getNumElements())};
}
llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
auto stringRefs = getRawStringData();
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
bool splat = isSplat();
return {ElementIterator<StringRef>(ptr, splat, 0),
ElementIterator<StringRef>(ptr, splat, getNumElements())};
}
/// Return the held element values as a range of Attributes.
llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
template <typename T, typename = typename std::enable_if<
@ -942,6 +955,9 @@ public:
/// form the user might expect.
ArrayRef<char> getRawData() const;
/// Return the raw StringRef data held by this attribute.
ArrayRef<StringRef> getRawStringData() const;
//===--------------------------------------------------------------------===//
// Mutation Utilities
//===--------------------------------------------------------------------===//
@ -973,6 +989,60 @@ protected:
return IntElementIterator(*this, getNumElements());
}
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
int64_t dataEltSize, bool isInt,
bool isSigned);
/// Check the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
};
/// An attribute class for representing dense arrays of strings. The structure
/// storing and querying a list of densely packed strings.
class DenseStringElementsAttr
: public Attribute::AttrBase<DenseStringElementsAttr, DenseElementsAttr,
detail::DenseStringElementsAttributeStorage> {
public:
using Base::Base;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::DenseStringElements;
}
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
static DenseStringElementsAttr get(ShapedType type, ArrayRef<StringRef> data);
protected:
friend DenseElementsAttr;
};
/// An attribute class for specializing behavior of Int and Floating-point
/// densely packed string arrays.
class DenseIntOrFPElementsAttr
: public Attribute::AttrBase<DenseIntOrFPElementsAttr, DenseElementsAttr,
detail::DenseIntOrFPElementsAttributeStorage> {
public:
using Base::Base;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::DenseIntOrFPElements;
}
protected:
friend DenseElementsAttr;
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'. 'type' must be a vector or tensor with static shape.
@ -990,20 +1060,15 @@ protected:
ArrayRef<char> data,
int64_t dataEltSize, bool isInt,
bool isSigned);
/// Check the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
};
/// An attribute that represents a reference to a dense float vector or tensor
/// object. Each element is stored as a double.
class DenseFPElementsAttr : public DenseElementsAttr {
class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
public:
using iterator = DenseElementsAttr::FloatElementIterator;
using DenseElementsAttr::DenseElementsAttr;
using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
/// Get an instance of a DenseFPElementsAttr with the given arguments. This
/// simply wraps the DenseElementsAttr::get calls.
@ -1035,13 +1100,13 @@ public:
/// An attribute that represents a reference to a dense integer vector or tensor
/// object.
class DenseIntElementsAttr : public DenseElementsAttr {
class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
public:
/// DenseIntElementsAttr iterates on APInt, so we can use the raw element
/// iterator directly.
using iterator = DenseElementsAttr::IntElementIterator;
using DenseElementsAttr::DenseElementsAttr;
using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
/// Get an instance of a DenseIntElementsAttr with the given arguments. This
/// simply wraps the DenseElementsAttr::get calls.
@ -1266,7 +1331,7 @@ class ElementsAttrIterator
typename... Args>
RetT process(Args &... args) const {
switch (attrKind) {
case StandardAttributes::DenseElements:
case StandardAttributes::DenseIntOrFPElements:
return ProcessFn<DenseIteratorT>()(args...);
case StandardAttributes::SparseElements:
return ProcessFn<SparseIteratorT>()(args...);

View File

@ -1307,6 +1307,16 @@ class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
class RankedF32ElementsAttr<list<int> dims> : RankedFloatElementsAttr<32, dims>;
class RankedF64ElementsAttr<list<int> dims> : RankedFloatElementsAttr<64, dims>;
def StringElementsAttr : ElementsAttrBase<
CPred<"$_self.isa<DenseStringElementsAttr>()" >,
"string elements attribute"> {
let storageType = [{ DenseElementsAttr }];
let returnType = [{ DenseElementsAttr }];
let convertFromStorage = "$_self";
}
// Base class for array attributes.
class ArrayAttrBase<Pred condition, string description> :
Attr<condition, description> {

View File

@ -1316,7 +1316,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
<< opType << ") does not match value type (" << valueType << ")";
return success();
} break;
case StandardAttributes::DenseElements:
case StandardAttributes::DenseIntOrFPElements:
case StandardAttributes::SparseElements: {
if (valueType == opType)
break;

View File

@ -973,6 +973,9 @@ protected:
/// used instead of individual elements when the elements attr is large.
void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
/// Print a dense string elements attribute.
void printDenseStringElementsAttr(DenseStringElementsAttr attr);
void printDialectAttribute(Attribute attr);
void printDialectType(Type type);
@ -1392,7 +1395,7 @@ void ModulePrinter::printAttribute(Attribute attr,
os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
break;
}
case StandardAttributes::DenseElements: {
case StandardAttributes::DenseIntOrFPElements: {
auto eltsAttr = attr.cast<DenseElementsAttr>();
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
printElidedElementsAttr(os);
@ -1403,6 +1406,17 @@ void ModulePrinter::printAttribute(Attribute attr,
os << '>';
break;
}
case StandardAttributes::DenseStringElements: {
auto eltsAttr = attr.cast<DenseStringElementsAttr>();
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
printElidedElementsAttr(os);
break;
}
os << "dense<";
printDenseStringElementsAttr(eltsAttr);
os << '>';
break;
}
case StandardAttributes::SparseElements: {
auto elementsAttr = attr.cast<SparseElementsAttr>();
if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
@ -1454,6 +1468,13 @@ static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
printFloatValue(value, os);
}
static void printDenseStringElement(DenseStringElementsAttr attr,
raw_ostream &os, unsigned index) {
os << "\"";
printEscapedString(attr.getRawStringData()[index], os);
os << "\"";
}
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
bool allowHex) {
auto type = attr.getType();
@ -1526,6 +1547,63 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
os << ']';
}
void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
auto type = attr.getType();
auto shape = type.getShape();
auto rank = type.getRank();
// Special case for 0-d and splat tensors.
if (attr.isSplat()) {
printDenseStringElement(attr, os, 0);
return;
}
// Special case for degenerate tensors.
auto numElements = type.getNumElements();
if (numElements == 0) {
for (int i = 0; i < rank; ++i)
os << '[';
for (int i = 0; i < rank; ++i)
os << ']';
return;
}
// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.
// The mixed-radix counter, with radices in 'shape'.
SmallVector<unsigned, 4> counter(rank, 0);
// The number of brackets that have been opened and not closed.
unsigned openBrackets = 0;
auto bumpCounter = [&]() {
// Bump the least significant digit.
++counter[rank - 1];
// Iterate backwards bubbling back the increment.
for (unsigned i = rank - 1; i > 0; --i)
if (counter[i] >= shape[i]) {
// Index 'i' is rolled over. Bump (i-1) and close a bracket.
counter[i] = 0;
++counter[i - 1];
--openBrackets;
os << ']';
}
};
for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
if (idx != 0)
os << ", ";
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
printDenseStringElement(attr, os, idx);
bumpCounter();
}
while (openBrackets-- > 0)
os << ']';
}
void ModulePrinter::printType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";

View File

@ -385,6 +385,20 @@ inline size_t getDenseElementBitWidth(Type eltType) {
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public AttributeStorage {
public:
DenseElementsAttributeStorage(ShapedType ty, bool isSplat)
: AttributeStorage(ty), isSplat(isSplat) {}
bool isSplat;
};
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseIntOrFPElementsAttributeStorage
: public DenseElementsAttributeStorage {
DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
bool isSplat = false)
: DenseElementsAttributeStorage(ty, isSplat), data(data) {}
struct KeyTy {
KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
bool isSplat = false)
@ -403,10 +417,6 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
bool isSplat;
};
DenseElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
bool isSplat = false)
: AttributeStorage(ty), data(data), isSplat(isSplat) {}
/// Compare this storage instance with the provided key.
bool operator==(const KeyTy &key) const {
if (key.type != getType())
@ -512,7 +522,7 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
}
/// Construct a new storage instance.
static DenseElementsAttributeStorage *
static DenseIntOrFPElementsAttributeStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
// If the data buffer is non-empty, we copy it into the allocator with a
// 64-bit alignment.
@ -528,12 +538,129 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
copy = ArrayRef<char>(rawData, data.size());
}
return new (allocator.allocate<DenseElementsAttributeStorage>())
DenseElementsAttributeStorage(key.type, copy, key.isSplat);
return new (allocator.allocate<DenseIntOrFPElementsAttributeStorage>())
DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat);
}
ArrayRef<char> data;
bool isSplat;
};
/// An attribute representing a reference to a dense vector or tensor object
/// containing strings.
struct DenseStringElementsAttributeStorage
: public DenseElementsAttributeStorage {
DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef<StringRef> data,
bool isSplat = false)
: DenseElementsAttributeStorage(ty, isSplat), data(data) {}
struct KeyTy {
KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
bool isSplat = false)
: type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
/// The type of the dense elements.
ShapedType type;
/// The raw buffer for the data storage.
ArrayRef<StringRef> data;
/// The computed hash code for the storage data.
llvm::hash_code hashCode;
/// A boolean that indicates if this data is a splat or not.
bool isSplat;
};
/// Compare this storage instance with the provided key.
bool operator==(const KeyTy &key) const {
if (key.type != getType())
return false;
// Otherwise, we can default to just checking the data. StringRefs compare
// by contents.
return key.data == data;
}
/// Construct a key from a shaped type, StringRef data buffer, and a flag that
/// signals if the data is already known to be a splat. Callers to this
/// function are expected to tag preknown splat values when possible, e.g. one
/// element shapes.
static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
bool isKnownSplat) {
// Handle an empty storage instance.
if (data.empty())
return KeyTy(ty, data, 0);
// If the data is already known to be a splat, the key hash value is
// directly the data buffer.
if (isKnownSplat)
return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
// Handle the simple case of only one element.
size_t numElements = ty.getNumElements();
assert(numElements != 1 && "splat of 1 element should already be detected");
// Create the initial hash value with just the first element.
const auto &firstElt = data.front();
auto hashVal = llvm::hash_value(firstElt);
// Check to see if this storage represents a splat. If it doesn't then
// combine the hash for the data starting with the first non splat element.
for (size_t i = 1, e = data.size(); i != e; i++)
if (!firstElt.equals(data[i]))
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
// Otherwise, this is a splat so just return the hash of the first element.
return KeyTy(ty, {firstElt}, hashVal, /*isSplat=*/true);
}
/// Hash the key for the storage.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(key.type, key.hashCode);
}
/// Construct a new storage instance.
static DenseStringElementsAttributeStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
// If the data buffer is non-empty, we copy it into the allocator with a
// 64-bit alignment.
ArrayRef<StringRef> copy, data = key.data;
if (data.empty()) {
return new (allocator.allocate<DenseStringElementsAttributeStorage>())
DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
}
int numEntries = key.isSplat ? 1 : data.size();
// Compute the amount data needed to store the ArrayRef and StringRef
// contents.
size_t dataSize = sizeof(StringRef) * numEntries;
for (int i = 0; i < numEntries; i++)
dataSize += data[i].size();
char *rawData = reinterpret_cast<char *>(
allocator.allocate(dataSize, alignof(uint64_t)));
// Setup a mutable array ref of our string refs so that we can update their
// contents.
auto mutableCopy = MutableArrayRef<StringRef>(
reinterpret_cast<StringRef *>(rawData), numEntries);
auto stringData = rawData + numEntries * sizeof(StringRef);
for (int i = 0; i < numEntries; i++) {
memcpy(stringData, data[i].data(), data[i].size());
mutableCopy[i] = StringRef(stringData, data[i].size());
stringData += data[i].size();
}
copy =
ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
return new (allocator.allocate<DenseStringElementsAttributeStorage>())
DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
}
ArrayRef<StringRef> data;
};
/// An attribute representing a reference to a tensor constant with opaque

View File

@ -411,7 +411,7 @@ int64_t ElementsAttr::getNumElements() const {
/// element, then a null attribute is returned.
Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
switch (getKind()) {
case StandardAttributes::DenseElements:
case StandardAttributes::DenseIntOrFPElements:
return cast<DenseElementsAttr>().getValue(index);
case StandardAttributes::OpaqueElements:
return cast<OpaqueElementsAttr>().getValue(index);
@ -442,7 +442,7 @@ ElementsAttr
ElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APInt &)> mapping) const {
switch (getKind()) {
case StandardAttributes::DenseElements:
case StandardAttributes::DenseIntOrFPElements:
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
default:
llvm_unreachable("unsupported ElementsAttr subtype");
@ -453,7 +453,7 @@ ElementsAttr
ElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APFloat &)> mapping) const {
switch (getKind()) {
case StandardAttributes::DenseElements:
case StandardAttributes::DenseIntOrFPElements:
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
default:
llvm_unreachable("unsupported ElementsAttr subtype");
@ -643,7 +643,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
"expected value to have same bitwidth as element type");
writeBits(data.data(), i * storageBitWidth, intVal);
}
return getRaw(type, data, /*isSplat=*/(values.size() == 1));
return DenseIntOrFPElementsAttr::getRaw(type, data,
/*isSplat=*/(values.size() == 1));
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@ -654,7 +655,14 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
for (int i = 0, e = values.size(); i != e; ++i)
setBit(buff.data(), i, values[i]);
return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
return DenseIntOrFPElementsAttr::getRaw(type, buff,
/*isSplat=*/(values.size() == 1));
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<StringRef> values) {
assert(!type.getElementType().isIntOrFloat());
return DenseStringElementsAttr::get(type, values);
}
/// Constructs a dense integer elements attribute from an array of APInt
@ -663,7 +671,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APInt> values) {
assert(type.getElementType().isIntOrIndex());
return getRaw(type, values);
return DenseIntOrFPElementsAttr::getRaw(type, values);
}
// Constructs a dense float elements attribute from an array of APFloat
@ -677,7 +685,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
std::vector<APInt> intValues(values.size());
for (unsigned i = 0, e = values.size(); i != e; ++i)
intValues[i] = values[i].bitcastToAPInt();
return getRaw(type, intValues);
return DenseIntOrFPElementsAttr::getRaw(type, intValues);
}
/// Construct a dense elements attribute from a raw buffer representing the
@ -686,34 +694,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
ArrayRef<char> rawBuffer,
bool isSplatBuffer) {
return getRaw(type, rawBuffer, isSplatBuffer);
}
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'.
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
ArrayRef<APInt> values) {
assert(hasSameElementsOrSplat(type, values));
size_t bitWidth = getDenseElementBitWidth(type.getElementType());
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
values.size());
for (unsigned i = 0, e = values.size(); i != e; ++i) {
assert(values[i].getBitWidth() == bitWidth);
writeBits(elementData.data(), i * storageBitWidth, values[i]);
}
return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
}
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data, bool isSplat) {
assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
data, isSplat);
return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
}
/// Check the information for a C++ data type, check if this type is valid for
@ -743,19 +724,14 @@ static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
return intType.isSigned() ? isSigned : !isSigned;
}
/// Overload of the 'getRaw' method that asserts that the given type is of
/// integer type. This method is used to verify type invariants that the
/// templatized 'get' method cannot.
/// Defaults down the subclass implementation.
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
int64_t dataEltSize,
bool isInt,
bool isSigned) {
assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned));
int64_t numElements = data.size() / dataEltSize;
assert(numElements == 1 || numElements == type.getNumElements());
return getRaw(type, data, /*isSplat=*/numElements == 1);
return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
isInt, isSigned);
}
/// A method used to verify specific type invariants that the templatized 'get'
@ -767,7 +743,9 @@ bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
/// Returns if this attribute corresponds to a splat, i.e. if all element
/// values are the same.
bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
bool DenseElementsAttr::isSplat() const {
return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
}
/// Return the held element values as a range of Attributes.
auto DenseElementsAttr::getAttributeValues() const
@ -827,7 +805,11 @@ auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
/// Return the raw storage data held by this attribute.
ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(impl)->data;
return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
}
ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
}
/// Return a new DenseElementsAttr that has the same data as the current
@ -843,7 +825,7 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
"expected the same element type");
assert(newType.getNumElements() == curType.getNumElements() &&
"expected the same number of elements");
return getRaw(newType, getRawData(), isSplat());
return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
}
DenseElementsAttr
@ -857,6 +839,63 @@ DenseElementsAttr DenseElementsAttr::mapValues(
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
}
//===----------------------------------------------------------------------===//
// DenseStringElementsAttr
//===----------------------------------------------------------------------===//
DenseStringElementsAttr
DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
type, values, (values.size() == 1));
}
//===----------------------------------------------------------------------===//
// DenseIntOrFPElementsAttr
//===----------------------------------------------------------------------===//
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'.
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
ArrayRef<APInt> values) {
assert(hasSameElementsOrSplat(type, values));
size_t bitWidth = getDenseElementBitWidth(type.getElementType());
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
values.size());
for (unsigned i = 0, e = values.size(); i != e; ++i) {
assert(values[i].getBitWidth() == bitWidth);
writeBits(elementData.data(), i * storageBitWidth, values[i]);
}
return DenseIntOrFPElementsAttr::getRaw(type, elementData,
/*isSplat=*/(values.size() == 1));
}
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data,
bool isSplat) {
assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
type, data, isSplat);
}
/// Overload of the 'getRaw' method that asserts that the given type is of
/// integer type. This method is used to verify type invariants that the
/// templatized 'get' method cannot.
DenseElementsAttr
DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
int64_t dataEltSize, bool isInt,
bool isSigned) {
assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned));
int64_t numElements = data.size() / dataEltSize;
assert(numElements == 1 || numElements == type.getNumElements());
return getRaw(type, data, /*isSplat=*/numElements == 1);
}
//===----------------------------------------------------------------------===//
// DenseFPElementsAttr
//===----------------------------------------------------------------------===//

View File

@ -82,10 +82,11 @@ namespace {
/// the IR.
struct BuiltinDialect : public Dialect {
BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseElementsAttr,
DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
UnitAttr>();
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
UnknownLoc>();

View File

@ -1953,7 +1953,7 @@ public:
ArrayRef<int64_t> getShape() const { return shape; }
private:
enum class ElementKind { Boolean, Integer, Float };
enum class ElementKind { Boolean, Integer, Float, String };
/// Return a string to represent the given element kind.
const char *getElementKindStr(ElementKind kind) {
@ -1964,6 +1964,8 @@ private:
return "'integer'";
case ElementKind::Float:
return "'float'";
case ElementKind::String:
return "'string'";
}
llvm_unreachable("unknown element kind");
}
@ -1975,6 +1977,9 @@ private:
DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
FloatType eltTy);
/// Build a Dense String attribute for the given type.
DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
/// Build a Dense attribute with hex data for the given type.
DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
@ -2030,8 +2035,10 @@ ParseResult TensorLiteralParser::parse(bool allowHex) {
/// shaped type.
DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
ShapedType type) {
// Check to see if we parsed the literal from a hex string.
if (hexStorage.hasValue())
Type eltType = type.getElementType();
// Check to see if we parse the literal from a hex string.
if (hexStorage.hasValue() && eltType.isIntOrFloat())
return getHexAttr(loc, type);
// Check that the parsed storage size has the same number of elements to the
@ -2044,20 +2051,17 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
// If the type is an integer, build a set of APInt values from the storage
// with the correct bitwidth.
Type eltType = type.getElementType();
if (auto intTy = eltType.dyn_cast<IntegerType>())
return getIntAttr(loc, type, intTy);
if (auto indexTy = eltType.dyn_cast<IndexType>())
return getIntAttr(loc, type, indexTy);
// Otherwise, this must be a floating point type.
auto floatTy = eltType.dyn_cast<FloatType>();
if (!floatTy) {
p.emitError(loc) << "expected floating-point or integer element type, got "
<< eltType;
return nullptr;
}
return getFloatAttr(loc, type, floatTy);
// If parsing a floating point type.
if (auto floatTy = eltType.dyn_cast<FloatType>())
return getFloatAttr(loc, type, floatTy);
// Other types are assumed to be string representations.
return getStringAttr(loc, type, type.getElementType());
}
/// Build a Dense Integer attribute for the given type.
@ -2163,6 +2167,28 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
return DenseElementsAttr::get(type, floatValues);
}
/// Build a Dense String attribute for the given type.
DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
ShapedType type,
Type eltTy) {
if (hexStorage.hasValue()) {
auto stringValue = hexStorage.getValue().getStringValue();
return DenseStringElementsAttr::get(type, {stringValue});
}
std::vector<std::string> stringValues;
std::vector<StringRef> stringRefValues;
stringValues.reserve(storage.size());
stringRefValues.reserve(storage.size());
for (auto val : storage) {
stringValues.push_back(val.second.getStringValue());
stringRefValues.push_back(stringValues.back());
}
return DenseStringElementsAttr::get(type, stringRefValues);
}
/// Build a Dense attribute with hex data for the given type.
DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
ShapedType type) {
@ -2214,6 +2240,10 @@ ParseResult TensorLiteralParser::parseElement() {
p.consumeToken();
break;
case Token::string:
storage.emplace_back(/*isNegative=*/ false, p.getToken());
p.consumeToken();
break;
default:
return p.emitError("expected element literal of primitive type");
}

View File

@ -390,6 +390,40 @@ func @correct_type_pass() {
// -----
//===----------------------------------------------------------------------===//
// Test StringElementsAttr
//===----------------------------------------------------------------------===//
func @simple_scalar_example() {
"test.string_elements_attr"() {
// CHECK: dense<"example">
scalar_string_attr = dense<"example"> : tensor<2x!unknown<"">>
} : () -> ()
return
}
// -----
func @escape_string_example() {
"test.string_elements_attr"() {
// CHECK: dense<"new\0Aline">
scalar_string_attr = dense<"new\nline"> : tensor<2x!unknown<"">>
} : () -> ()
return
}
// -----
func @simple_scalar_example() {
"test.string_elements_attr"() {
// CHECK: dense<["example1", "example2"]>
scalar_string_attr = dense<["example1", "example2"]> : tensor<2x!unknown<"">>
} : () -> ()
return
}
// -----
//===----------------------------------------------------------------------===//
// Test SymbolRefAttr
//===----------------------------------------------------------------------===//

View File

@ -22,10 +22,5 @@
// -----
// expected-error@+1 {{expected floating-point or integer element type, got '!unknown<"">'}}
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2x!unknown<"">>} : () -> ()
// -----
// expected-error@+1 {{elements hex data size is invalid for provided type}}
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<4xf64>} : () -> ()

View File

@ -245,6 +245,12 @@ def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
"$_builder.getI32IntegerAttr($_self)">;
}
def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
let arguments = (ins
StringElementsAttr:$scalar_string_attr
);
}
//===----------------------------------------------------------------------===//
// Test Attribute Constraints
//===----------------------------------------------------------------------===//