2018-10-25 15:46:10 -07:00
|
|
|
//===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
#include "AttributeDetail.h"
|
2018-10-25 22:13:03 -07:00
|
|
|
#include "mlir/IR/AffineMap.h"
|
2019-05-13 09:00:22 -07:00
|
|
|
#include "mlir/IR/Diagnostics.h"
|
2019-02-11 22:51:34 -08:00
|
|
|
#include "mlir/IR/Dialect.h"
|
2018-10-25 15:46:10 -07:00
|
|
|
#include "mlir/IR/Function.h"
|
2018-10-25 22:13:03 -07:00
|
|
|
#include "mlir/IR/IntegerSet.h"
|
2018-10-25 15:46:10 -07:00
|
|
|
#include "mlir/IR/Types.h"
|
2019-05-13 09:00:22 -07:00
|
|
|
#include "llvm/ADT/Twine.h"
|
2018-10-25 15:46:10 -07:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::detail;
|
|
|
|
|
2019-05-02 14:02:57 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AttributeStorage
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-05-22 13:41:23 -07:00
|
|
|
AttributeStorage::AttributeStorage(Type type)
|
|
|
|
: type(type.getAsOpaquePointer()) {}
|
|
|
|
AttributeStorage::AttributeStorage() : type(nullptr) {}
|
2019-05-02 14:02:57 -07:00
|
|
|
|
|
|
|
Type AttributeStorage::getType() const {
|
2019-05-22 13:41:23 -07:00
|
|
|
return Type::getFromOpaquePointer(type);
|
2019-05-02 14:02:57 -07:00
|
|
|
}
|
2019-05-22 13:41:23 -07:00
|
|
|
void AttributeStorage::setType(Type newType) {
|
|
|
|
type = newType.getAsOpaquePointer();
|
2019-05-02 14:02:57 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Attribute
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-04-30 14:26:04 -07:00
|
|
|
/// Return the type of this attribute.
|
2019-05-08 22:25:15 -07:00
|
|
|
Type Attribute::getType() const { return impl->getType(); }
|
2019-04-30 14:26:04 -07:00
|
|
|
|
2019-05-06 12:40:43 -07:00
|
|
|
/// Return the context this attribute belongs to.
|
|
|
|
MLIRContext *Attribute::getContext() const { return getType().getContext(); }
|
|
|
|
|
2019-05-10 15:14:13 -07:00
|
|
|
/// Get the dialect this attribute is registered to.
|
2019-05-22 09:56:11 -07:00
|
|
|
Dialect &Attribute::getDialect() const { return impl->getDialect(); }
|
2019-05-10 15:14:13 -07:00
|
|
|
|
2019-05-15 09:10:52 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-18 16:41:00 -07:00
|
|
|
// AffineMapAttr
|
2019-05-15 09:10:52 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
AffineMapAttr AffineMapAttr::get(AffineMap value) {
|
|
|
|
return Base::get(value.getResult(0).getContext(),
|
|
|
|
StandardAttributes::AffineMap, value);
|
2019-05-15 09:10:52 -07:00
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
|
2019-05-15 09:10:52 -07:00
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ArrayAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
2019-05-15 09:10:52 -07:00
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
|
|
|
|
return Base::get(context, StandardAttributes::Array, value);
|
2019-05-15 09:10:52 -07:00
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// BoolAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
/// Return the boolean value held by this attribute.
bool BoolAttr::getValue() const {
  return getImpl()->value;
}
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-05-31 09:24:48 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DictionaryAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Perform a three-way comparison between the names of the specified
|
|
|
|
/// NamedAttributes.
|
|
|
|
static int compareNamedAttributes(const NamedAttribute *lhs,
|
|
|
|
const NamedAttribute *rhs) {
|
|
|
|
return lhs->first.str().compare(rhs->first.str());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Get the uniqued DictionaryAttr for the given named attributes. The entries
/// are canonicalized by sorting them on their names. In debug builds, the
/// values are checked to be non-null and the names to be unique.
DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
                                   MLIRContext *context) {
  assert(llvm::all_of(value,
                      [](const NamedAttribute &attr) { return attr.second; }) &&
         "value cannot have null entries");

  // Sorting canonicalizes the element list. The input is very often already
  // sorted, so only copy and sort when it is actually needed.
  SmallVector<NamedAttribute, 8> storage;
  const size_t numElements = value.size();

  if (numElements == 2) {
    assert(value[0].first != value[1].first &&
           "DictionaryAttr element names must be unique");

    // Two elements at most need a swap; skip the general sort.
    if (value[1].first.strref() < value[0].first.strref()) {
      storage.push_back(value[1]);
      storage.push_back(value[0]);
      value = storage;
    }
  } else if (numElements > 2) {
    // Locate the first adjacent pair whose names are out of order, if any.
    auto isOutOfOrder = [](const NamedAttribute &lhs,
                           const NamedAttribute &rhs) {
      return lhs.first.strref() > rhs.first.strref();
    };
    if (std::adjacent_find(value.begin(), value.end(), isOutOfOrder) !=
        value.end()) {
      storage.append(value.begin(), value.end());
      llvm::array_pod_sort(storage.begin(), storage.end(),
                           compareNamedAttributes);
      value = storage;
    }

    // Once sorted, any duplicate names would be adjacent.
    assert(std::adjacent_find(value.begin(), value.end(),
                              [](NamedAttribute l, NamedAttribute r) {
                                return l.first == r.first;
                              }) == value.end() &&
           "DictionaryAttr element names must be unique");
  }
  // Zero or one element: already sorted and trivially unique.

  return Base::get(context, StandardAttributes::Dictionary, value);
}
|
|
|
|
|
|
|
|
ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
|
|
|
|
return getImpl()->getElements();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return the specified attribute if present, null otherwise.
|
|
|
|
Attribute DictionaryAttr::get(StringRef name) const {
|
|
|
|
for (auto elt : getValue())
|
|
|
|
if (elt.first.is(name))
|
|
|
|
return elt.second;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
Attribute DictionaryAttr::get(Identifier name) const {
|
|
|
|
for (auto elt : getValue())
|
|
|
|
if (elt.first == name)
|
|
|
|
return elt.second;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
DictionaryAttr::iterator DictionaryAttr::begin() const {
|
|
|
|
return getValue().begin();
|
|
|
|
}
|
|
|
|
DictionaryAttr::iterator DictionaryAttr::end() const {
|
|
|
|
return getValue().end();
|
|
|
|
}
|
|
|
|
size_t DictionaryAttr::size() const { return getValue().size(); }
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// FloatAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
FloatAttr FloatAttr::get(Type type, double value) {
|
2019-05-13 12:34:42 -07:00
|
|
|
return Base::get(type.getContext(), StandardAttributes::Float, type, value);
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
|
2019-05-13 12:34:42 -07:00
|
|
|
return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
|
|
|
|
type, value);
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
|
2019-05-13 12:34:42 -07:00
|
|
|
return Base::get(type.getContext(), StandardAttributes::Float, type, value);
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
|
2019-05-13 12:34:42 -07:00
|
|
|
return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
|
|
|
|
type, value);
|
2018-10-25 15:46:10 -07:00
|
|
|
}
|
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
|
|
|
|
|
2018-12-27 16:51:09 -08:00
|
|
|
double FloatAttr::getValueAsDouble() const {
|
2019-05-04 15:00:27 -07:00
|
|
|
return getValueAsDouble(getValue());
|
|
|
|
}
|
|
|
|
double FloatAttr::getValueAsDouble(APFloat value) {
|
|
|
|
if (&value.getSemantics() != &APFloat::IEEEdouble()) {
|
|
|
|
bool losesInfo = false;
|
2018-12-27 16:51:09 -08:00
|
|
|
value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
|
|
|
|
&losesInfo);
|
|
|
|
}
|
|
|
|
return value.convertToDouble();
|
|
|
|
}
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
/// Verify construction invariants.
|
|
|
|
static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc,
|
|
|
|
Type type) {
|
|
|
|
if (!type.isa<FloatType>()) {
|
|
|
|
if (loc)
|
2019-06-25 21:31:54 -07:00
|
|
|
emitError(*loc, "expected floating point type");
|
2019-05-08 22:25:15 -07:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
LogicalResult FloatAttr::verifyConstructionInvariants(
|
|
|
|
llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) {
|
|
|
|
return verifyFloatTypeInvariants(loc, type);
|
|
|
|
}
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
|
|
|
|
MLIRContext *ctx, Type type,
|
|
|
|
const APFloat &value) {
|
|
|
|
// Verify that the type is correct.
|
|
|
|
if (failed(verifyFloatTypeInvariants(loc, type)))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
// Verify that the type semantics match that of the value.
|
|
|
|
if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
|
|
|
|
if (loc)
|
2019-06-25 21:31:54 -07:00
|
|
|
emitError(*loc,
|
|
|
|
"FloatAttr type doesn't match the type implied by its value");
|
2019-05-08 22:25:15 -07:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-18 16:41:00 -07:00
|
|
|
// FunctionAttr
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
FunctionAttr FunctionAttr::get(Function *value) {
|
|
|
|
assert(value && "Cannot get FunctionAttr for a null function");
|
|
|
|
return get(value->getName(), value->getContext());
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
|
|
|
|
return Base::get(ctx, StandardAttributes::Function, value);
|
|
|
|
}
|
|
|
|
|
|
|
|
StringRef FunctionAttr::getValue() const { return getImpl()->value; }
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-18 16:41:00 -07:00
|
|
|
// IntegerAttr
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
|
|
|
|
return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
|
|
|
|
// This uses 64 bit APInts by default for index type.
|
|
|
|
if (type.isIndex())
|
|
|
|
return get(type, APInt(64, value));
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
auto intType = type.cast<IntegerType>();
|
|
|
|
return get(type, APInt(intType.getWidth(), value));
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
|
|
|
|
|
|
|
|
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// IntegerSetAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 22:13:03 -07:00
|
|
|
|
2019-04-30 10:31:29 -07:00
|
|
|
IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
|
2019-05-08 22:25:15 -07:00
|
|
|
return Base::get(value.getConstraint(0).getContext(),
|
2019-05-13 12:34:42 -07:00
|
|
|
StandardAttributes::IntegerSet, value);
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
|
2018-10-25 22:13:03 -07:00
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-18 16:41:00 -07:00
|
|
|
// OpaqueAttr
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData,
|
|
|
|
MLIRContext *context) {
|
|
|
|
return Base::get(context, StandardAttributes::Opaque, dialect, attrData);
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
|
|
|
|
MLIRContext *context, Location location) {
|
|
|
|
return Base::getChecked(location, context, StandardAttributes::Opaque,
|
|
|
|
dialect, attrData);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns the dialect namespace of the opaque attribute.
|
|
|
|
Identifier OpaqueAttr::getDialectNamespace() const {
|
|
|
|
return getImpl()->dialectNamespace;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns the raw attribute data of the opaque attribute.
|
|
|
|
StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
|
|
|
|
|
|
|
|
/// Verify the construction of an opaque attribute.
|
|
|
|
LogicalResult OpaqueAttr::verifyConstructionInvariants(
|
|
|
|
llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
|
|
|
|
StringRef attrData) {
|
|
|
|
if (!Dialect::isValidNamespace(dialect.strref())) {
|
|
|
|
if (loc)
|
2019-06-25 21:31:54 -07:00
|
|
|
emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
|
2019-06-18 16:41:00 -07:00
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-18 16:41:00 -07:00
|
|
|
// StringAttr
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
|
|
|
|
return Base::get(context, StandardAttributes::String, bytes);
|
2019-04-30 10:31:29 -07:00
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
StringRef StringAttr::getValue() const { return getImpl()->value; }
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TypeAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
TypeAttr TypeAttr::get(Type value) {
|
|
|
|
return Base::get(value.getContext(), StandardAttributes::Type, value);
|
2019-05-20 15:21:22 -07:00
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
Type TypeAttr::getValue() const { return getImpl()->value; }
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ElementsAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-05-16 00:12:45 -07:00
|
|
|
ShapedType ElementsAttr::getType() const {
|
|
|
|
return Attribute::getType().cast<ShapedType>();
|
2018-10-25 15:46:10 -07:00
|
|
|
}
|
|
|
|
|
2019-02-27 16:15:16 -08:00
|
|
|
/// Return the value at the given index. If index does not refer to a valid
|
|
|
|
/// element, then a null attribute is returned.
|
|
|
|
Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|
|
|
switch (getKind()) {
|
2019-06-06 16:15:42 -07:00
|
|
|
case StandardAttributes::DenseElements:
|
2019-02-27 16:15:16 -08:00
|
|
|
return cast<DenseElementsAttr>().getValue(index);
|
2019-05-13 12:34:42 -07:00
|
|
|
case StandardAttributes::OpaqueElements:
|
2019-02-27 16:15:16 -08:00
|
|
|
return cast<OpaqueElementsAttr>().getValue(index);
|
2019-05-13 12:34:42 -07:00
|
|
|
case StandardAttributes::SparseElements:
|
2019-02-27 16:15:16 -08:00
|
|
|
return cast<SparseElementsAttr>().getValue(index);
|
|
|
|
default:
|
|
|
|
llvm_unreachable("unknown ElementsAttr kind");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-05-23 11:22:13 -07:00
|
|
|
ElementsAttr ElementsAttr::mapValues(
|
|
|
|
Type newElementType,
|
|
|
|
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
|
|
|
switch (getKind()) {
|
2019-06-06 16:15:42 -07:00
|
|
|
case StandardAttributes::DenseElements:
|
2019-05-23 11:22:13 -07:00
|
|
|
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
|
|
|
|
default:
|
|
|
|
llvm_unreachable("unsupported ElementsAttr subtype");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
ElementsAttr ElementsAttr::mapValues(
|
|
|
|
Type newElementType,
|
|
|
|
llvm::function_ref<APInt(const APFloat &)> mapping) const {
|
|
|
|
switch (getKind()) {
|
2019-06-06 16:15:42 -07:00
|
|
|
case StandardAttributes::DenseElements:
|
2019-05-23 11:22:13 -07:00
|
|
|
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
|
|
|
|
default:
|
|
|
|
llvm_unreachable("unsupported ElementsAttr subtype");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-11 16:14:17 -07:00
|
|
|
// DenseElementsAttr Utilities
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Return the logical bitwidth of a dense element of type `eltType`.
static size_t getDenseElementBitwidth(Type eltType) {
  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
  // with double semantics.
  if (eltType.isBF16())
    return 64;
  return eltType.getIntOrFloatBitWidth();
}

/// Get the bitwidth of a dense element type within the buffer.
/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8;
/// 1-bit elements keep a width of 1 and are bit-packed.
static size_t getDenseElementStorageWidth(size_t origWidth) {
  if (origWidth == 1)
    return origWidth;
  return llvm::alignTo<8>(origWidth);
}
|
|
|
|
|
|
|
|
/// Set bit `bitPos` of the byte buffer `rawData` to `value`. Bit `bitPos`
/// lives in byte `bitPos / CHAR_BIT`, counted from that byte's least
/// significant bit.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (!value) {
    byte &= ~mask;
    return;
  }
  byte |= mask;
}

/// Return the value of bit `bitPos` of the byte buffer `rawData`, using the
/// same bit numbering as setBit.
static bool getBit(const char *rawData, size_t bitPos) {
  const char byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (byte & mask) != 0;
}
|
|
|
|
|
2019-06-07 12:08:36 -07:00
|
|
|
/// Writes value to the bit position `bitPos` in array `rawData`.
|
|
|
|
static void writeBits(char *rawData, size_t bitPos, APInt value) {
|
|
|
|
size_t bitWidth = value.getBitWidth();
|
|
|
|
|
|
|
|
// If the bitwidth is 1 we just toggle the specific bit.
|
|
|
|
if (bitWidth == 1)
|
|
|
|
return setBit(rawData, bitPos, value.isOneValue());
|
|
|
|
|
|
|
|
// Otherwise, the bit position is guaranteed to be byte aligned.
|
|
|
|
assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
|
|
|
|
std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
|
|
|
|
llvm::divideCeil(bitWidth, CHAR_BIT),
|
|
|
|
rawData + (bitPos / CHAR_BIT));
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
|
|
|
|
/// `rawData`.
|
|
|
|
static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
|
|
|
|
// Handle a boolean bit position.
|
|
|
|
if (bitWidth == 1)
|
|
|
|
return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
|
|
|
|
|
|
|
|
// Otherwise, the bit position must be 8-bit aligned.
|
|
|
|
assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
|
|
|
|
APInt result(bitWidth, 0);
|
2019-06-13 16:02:36 -07:00
|
|
|
std::copy_n(
|
|
|
|
rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT),
|
|
|
|
const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
|
2019-06-07 12:08:36 -07:00
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
2019-06-11 16:14:17 -07:00
|
|
|
/// Returns if 'values' corresponds to a splat, i.e. one element, or has the
|
|
|
|
/// same element count as 'type'.
|
|
|
|
template <typename Values>
|
|
|
|
static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
|
|
|
|
return (values.size() == 1) ||
|
|
|
|
(type.getNumElements() == static_cast<int64_t>(values.size()));
|
|
|
|
}
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
/// Constructs a new iterator.
|
2019-06-13 13:22:32 -07:00
|
|
|
DenseElementsAttr::IntElementIterator::IntElementIterator(
|
2019-04-05 16:11:24 -07:00
|
|
|
DenseElementsAttr attr, size_t index)
|
2019-06-25 12:10:46 -07:00
|
|
|
: indexed_accessor_iterator<IntElementIterator, const char *, APInt, APInt,
|
|
|
|
APInt>(attr.getRawData().data(), index),
|
2019-04-05 16:11:24 -07:00
|
|
|
bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
|
|
|
|
|
|
|
|
/// Accesses the raw APInt value at this iterator position.
|
2019-06-13 13:22:32 -07:00
|
|
|
APInt DenseElementsAttr::IntElementIterator::operator*() const {
|
2019-06-25 12:10:46 -07:00
|
|
|
return readBits(object, index * getDenseElementStorageWidth(bitWidth),
|
2019-06-06 15:55:17 -07:00
|
|
|
bitWidth);
|
2019-04-05 16:11:24 -07:00
|
|
|
}
|
|
|
|
|
2019-06-13 13:22:32 -07:00
|
|
|
DenseElementsAttr::FloatElementIterator::FloatElementIterator(
|
|
|
|
const llvm::fltSemantics &smt, IntElementIterator it)
|
|
|
|
: llvm::mapped_iterator<IntElementIterator,
|
|
|
|
std::function<APFloat(const APInt &)>>(
|
|
|
|
it, [&](const APInt &val) { return APFloat(smt, val); }) {}
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// DenseElementsAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-05-16 00:12:45 -07:00
|
|
|
/// Constructs a dense elements attribute of `type` from attribute `values`.
/// `values` must hold either a single splat value or one value per element.
/// Each value must have the element type as its type and be a FloatAttr (for
/// float element types) or an IntegerAttr/BoolAttr (for integer element types).
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<Attribute> values) {
  assert(type.getElementType().isIntOrFloat() &&
         "expected int or float element type");
  assert(hasSameElementsOrSplat(type, values));

  auto eltType = type.getElementType();
  size_t bitWidth = getDenseElementBitwidth(eltType);
  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);

  // Compress the attribute values into a character buffer. The buffer is sized
  // from the total number of storage bits, so bit-packed (i1) elements take
  // divideCeil(numValues, CHAR_BIT) bytes. This matches the buffer built by
  // the ArrayRef<bool> overload. The raw data is passed to Base::get as part of
  // the uniquing key, so differing sizes would produce distinct attributes for
  // the same booleans.
  SmallVector<char, 8> data(
      llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
  APInt intVal;
  for (unsigned i = 0, e = values.size(); i < e; ++i) {
    assert(eltType == values[i].getType() &&
           "expected attribute value to have element type");

    switch (eltType.getKind()) {
    case StandardTypes::BF16:
    case StandardTypes::F16:
    case StandardTypes::F32:
    case StandardTypes::F64:
      intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
      break;
    case StandardTypes::Integer:
      intVal = values[i].isa<BoolAttr>()
                   ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
                   : values[i].cast<IntegerAttr>().getValue();
      break;
    default:
      llvm_unreachable("unexpected element type");
    }
    assert(intVal.getBitWidth() == bitWidth &&
           "expected value to have same bitwidth as element type");
    writeBits(data.data(), i * storageBitWidth, intVal);
  }
  return getRaw(type, data, /*isSplat=*/(values.size() == 1));
}
|
|
|
|
|
2019-06-07 12:08:36 -07:00
|
|
|
/// Constructs a dense elements attribute of `type`, whose element type must be
/// i1, from boolean `values`. The booleans are bit-packed into the buffer, and
/// a single value produces a splat.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<bool> values) {
  assert(hasSameElementsOrSplat(type, values));
  assert(type.getElementType().isInteger(1));

  std::vector<char> packedBits(llvm::divideCeil(values.size(), CHAR_BIT));
  for (size_t bitIdx = 0; bitIdx < values.size(); ++bitIdx)
    setBit(packedBits.data(), bitIdx, values[bitIdx]);

  const bool isSplat = values.size() == 1;
  return getRaw(type, packedBits, isSplat);
}
|
|
|
|
|
|
|
|
/// Constructs a dense integer elements attribute from an array of APInt
|
|
|
|
/// values. Each APInt value is expected to have the same bitwidth as the
|
|
|
|
/// element type of 'type'.
|
|
|
|
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
|
|
|
ArrayRef<APInt> values) {
|
|
|
|
assert(type.getElementType().isa<IntegerType>());
|
|
|
|
return getRaw(type, values);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Constructs a dense float elements attribute from an array of APFloat
|
|
|
|
// values. Each APFloat value is expected to have the same bitwidth as the
|
|
|
|
// element type of 'type'.
|
|
|
|
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
|
|
|
ArrayRef<APFloat> values) {
|
|
|
|
assert(type.getElementType().isa<FloatType>());
|
|
|
|
|
|
|
|
// Convert the APFloat values to APInt and create a dense elements attribute.
|
|
|
|
std::vector<APInt> intValues(values.size());
|
|
|
|
for (unsigned i = 0, e = values.size(); i != e; ++i)
|
|
|
|
intValues[i] = values[i].bitcastToAPInt();
|
|
|
|
return getRaw(type, intValues);
|
2019-06-07 12:08:36 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
// Constructs a dense elements attribute from an array of raw APInt values.
|
|
|
|
// Each APInt value is expected to have the same bitwidth as the element type
|
|
|
|
// of 'type'.
|
2019-06-11 16:14:17 -07:00
|
|
|
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
|
|
|
|
ArrayRef<APInt> values) {
|
|
|
|
assert(hasSameElementsOrSplat(type, values));
|
2019-06-07 12:08:36 -07:00
|
|
|
|
|
|
|
size_t bitWidth = getDenseElementBitwidth(type.getElementType());
|
|
|
|
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
|
2019-06-13 13:22:32 -07:00
|
|
|
std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
|
|
|
|
values.size());
|
2019-06-07 12:08:36 -07:00
|
|
|
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
|
|
|
assert(values[i].getBitWidth() == bitWidth);
|
|
|
|
writeBits(elementData.data(), i * storageBitWidth, values[i]);
|
|
|
|
}
|
2019-06-11 16:14:17 -07:00
|
|
|
return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
|
2019-06-07 12:08:36 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Constructs the uniqued DenseElementsAttr of `type` from an already encoded
/// raw buffer. `type` must be a statically shaped ranked tensor or vector type.
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
                                            ArrayRef<char> data, bool isSplat) {
  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
         "type must be ranked tensor or vector");
  assert(type.hasStaticShape() && "type must have static shape");

  auto *context = type.getContext();
  return Base::get(context, StandardAttributes::DenseElements, type, data,
                   isSplat);
}
|
|
|
|
|
2019-06-13 13:22:32 -07:00
|
|
|
/// Check the information for a c++ data type, check if this type is valid for
|
|
|
|
/// the current attribute. This method is used to verify specific type
|
|
|
|
/// invariants that the templatized 'getValues' method cannot.
|
|
|
|
static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
|
|
|
|
bool isInt) {
|
|
|
|
// Make sure that the data element size is the same as the type element width.
|
|
|
|
if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth())
|
|
|
|
return false;
|
|
|
|
|
|
|
|
// Check that the element type is valid.
|
|
|
|
return isInt ? type.getElementType().isa<IntegerType>()
|
|
|
|
: type.getElementType().isa<FloatType>();
|
|
|
|
}
|
|
|
|
|
2019-06-07 12:08:36 -07:00
|
|
|
/// Overload of the 'getRaw' method that asserts that the given type is of
|
2019-06-11 16:14:17 -07:00
|
|
|
/// integer type. This method is used to verify type invariants that the
|
|
|
|
/// templatized 'get' method cannot.
|
2019-06-07 12:08:36 -07:00
|
|
|
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
|
|
|
|
ArrayRef<char> data,
|
2019-06-11 16:14:17 -07:00
|
|
|
int64_t dataEltSize,
|
2019-06-07 12:08:36 -07:00
|
|
|
bool isInt) {
|
2019-06-13 13:22:32 -07:00
|
|
|
assert(::isValidIntOrFloat(type, dataEltSize, isInt));
|
2019-06-11 16:14:17 -07:00
|
|
|
|
|
|
|
int64_t numElements = data.size() / dataEltSize;
|
|
|
|
assert(numElements == 1 || numElements == type.getNumElements());
|
|
|
|
return getRaw(type, data, /*isSplat=*/numElements == 1);
|
2019-06-07 12:08:36 -07:00
|
|
|
}
|
|
|
|
|
2019-06-13 13:22:32 -07:00
|
|
|
/// A method used to verify specific type invariants that the templatized 'get'
|
|
|
|
/// method cannot.
|
|
|
|
bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
|
|
|
|
bool isInt) const {
|
|
|
|
return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
|
|
|
|
}
|
|
|
|
|
2019-06-07 12:08:36 -07:00
|
|
|
/// Return the raw storage data held by this attribute.
|
|
|
|
ArrayRef<char> DenseElementsAttr::getRawData() const {
|
|
|
|
return static_cast<ImplType *>(impl)->data;
|
|
|
|
}
|
|
|
|
|
2019-06-11 16:14:17 -07:00
|
|
|
/// Returns the number of raw elements held by this attribute.
|
|
|
|
size_t DenseElementsAttr::rawSize() const {
|
|
|
|
return isSplat() ? 1 : getType().getNumElements();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns if this attribute corresponds to a splat, i.e. if all element
|
|
|
|
/// values are the same.
|
|
|
|
bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
|
|
|
|
|
|
|
|
/// If this attribute corresponds to a splat, then get the splat value.
|
|
|
|
/// Otherwise, return null.
|
|
|
|
Attribute DenseElementsAttr::getSplatValue() const {
|
|
|
|
if (!isSplat())
|
|
|
|
return Attribute();
|
|
|
|
|
|
|
|
auto elementType = getType().getElementType();
|
|
|
|
if (elementType.isa<IntegerType>())
|
2019-06-13 13:22:32 -07:00
|
|
|
return IntegerAttr::get(elementType, *raw_int_begin());
|
2019-06-11 16:14:17 -07:00
|
|
|
if (auto fType = elementType.dyn_cast<FloatType>())
|
|
|
|
return FloatAttr::get(elementType,
|
2019-06-13 13:22:32 -07:00
|
|
|
APFloat(fType.getFloatSemantics(), *raw_int_begin()));
|
2019-06-11 16:14:17 -07:00
|
|
|
llvm_unreachable("unexpected element type");
|
|
|
|
}
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-01-19 20:54:09 -08:00
|
|
|
/// Return the value at the given index. If index does not refer to a valid
|
|
|
|
/// element, then a null attribute is returned.
|
|
|
|
Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|
|
|
auto type = getType();
|
|
|
|
|
|
|
|
// Verify that the rank of the indices matches the held type.
|
|
|
|
auto rank = type.getRank();
|
2019-05-31 17:18:59 -07:00
|
|
|
if (rank != static_cast<int64_t>(index.size()))
|
2019-01-19 20:54:09 -08:00
|
|
|
return Attribute();
|
|
|
|
|
|
|
|
// Verify that all of the indices are within the shape dimensions.
|
|
|
|
auto shape = type.getShape();
|
|
|
|
for (unsigned i = 0; i != rank; ++i)
|
2019-05-03 19:48:57 -07:00
|
|
|
if (shape[i] <= static_cast<int64_t>(index[i]))
|
2019-01-19 20:54:09 -08:00
|
|
|
return Attribute();
|
|
|
|
|
2019-06-11 16:14:17 -07:00
|
|
|
// If this is a splat, return the splat value directly.
|
|
|
|
if (isSplat())
|
|
|
|
return getSplatValue();
|
|
|
|
|
2019-01-19 20:54:09 -08:00
|
|
|
// Reduce the provided multidimensional index into a 1D index.
|
|
|
|
uint64_t valueIndex = 0;
|
|
|
|
uint64_t dimMultiplier = 1;
|
2019-05-03 19:48:57 -07:00
|
|
|
for (int i = rank - 1; i >= 0; --i) {
|
2019-01-19 20:54:09 -08:00
|
|
|
valueIndex += index[i] * dimMultiplier;
|
|
|
|
dimMultiplier *= shape[i];
|
|
|
|
}
|
|
|
|
|
|
|
|
// Return the element stored at the 1D index.
|
|
|
|
auto elementType = getType().getElementType();
|
2019-04-05 16:11:24 -07:00
|
|
|
size_t bitWidth = getDenseElementBitwidth(elementType);
|
2019-06-11 16:14:17 -07:00
|
|
|
size_t storageWidth = getDenseElementStorageWidth(bitWidth);
|
2019-01-19 20:54:09 -08:00
|
|
|
APInt rawValueData =
|
2019-06-11 16:14:17 -07:00
|
|
|
readBits(getRawData().data(), valueIndex * storageWidth, bitWidth);
|
2019-01-19 20:54:09 -08:00
|
|
|
|
|
|
|
// Convert the raw value data to an attribute value.
|
2019-06-06 16:15:42 -07:00
|
|
|
if (elementType.isa<IntegerType>())
|
2019-01-19 20:54:09 -08:00
|
|
|
return IntegerAttr::get(elementType, rawValueData);
|
2019-06-06 16:15:42 -07:00
|
|
|
if (auto fType = elementType.dyn_cast<FloatType>())
|
|
|
|
return FloatAttr::get(elementType,
|
|
|
|
APFloat(fType.getFloatSemantics(), rawValueData));
|
|
|
|
llvm_unreachable("unexpected element type");
|
2019-01-19 20:54:09 -08:00
|
|
|
}
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Appends the stored element values of this attribute to 'values'. Each value
/// is wrapped as an IntegerAttr or FloatAttr according to the element type.
/// Reserves capacity for rawSize() entries up front.
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
  values.reserve(rawSize());

  Type eltType = getType().getElementType();
  if (eltType.isa<IntegerType>()) {
    for (auto apInt : getIntValues())
      values.push_back(IntegerAttr::get(eltType, apInt));
  } else if (eltType.isa<FloatType>()) {
    for (auto apFloat : getFloatValues())
      values.push_back(FloatAttr::get(eltType, apFloat));
  } else {
    llvm_unreachable("unexpected element type");
  }
}
|
|
|
|
|
2019-06-13 13:22:32 -07:00
|
|
|
/// Return the held element values as a range of APInts. The element type of
|
|
|
|
/// this attribute must be of integer type.
|
|
|
|
auto DenseElementsAttr::getIntValues() const
|
|
|
|
-> llvm::iterator_range<IntElementIterator> {
|
|
|
|
assert(getType().getElementType().isa<IntegerType>() &&
|
|
|
|
"expected integer type");
|
|
|
|
return {raw_int_begin(), raw_int_end()};
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return the held element values as a range of APFloat. The element type of
|
|
|
|
/// this attribute must be of float type.
|
|
|
|
auto DenseElementsAttr::getFloatValues() const
|
|
|
|
-> llvm::iterator_range<FloatElementIterator> {
|
|
|
|
auto elementType = getType().getElementType().cast<FloatType>();
|
|
|
|
assert(elementType.isa<FloatType>() && "expected float type");
|
|
|
|
const auto &elementSemantics = elementType.getFloatSemantics();
|
|
|
|
return {FloatElementIterator(elementSemantics, raw_int_begin()),
|
|
|
|
FloatElementIterator(elementSemantics, raw_int_end())};
|
|
|
|
}
|
|
|
|
|
2019-06-07 09:57:29 -07:00
|
|
|
/// Return a new DenseElementsAttr that has the same data as the current
|
|
|
|
/// attribute, but has been reshaped to 'newType'. The new type must have the
|
|
|
|
/// same total number of elements as well as element type.
|
|
|
|
DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
|
|
|
|
ShapedType curType = getType();
|
|
|
|
if (curType == newType)
|
|
|
|
return *this;
|
|
|
|
|
|
|
|
(void)curType;
|
|
|
|
assert(newType.getElementType() == curType.getElementType() &&
|
|
|
|
"expected the same element type");
|
|
|
|
assert(newType.getNumElements() == curType.getNumElements() &&
|
|
|
|
"expected the same number of elements");
|
2019-06-11 16:14:17 -07:00
|
|
|
return getRaw(newType, getRawData(), isSplat());
|
2019-06-07 09:57:29 -07:00
|
|
|
}
|
|
|
|
|
2019-05-23 11:22:13 -07:00
|
|
|
/// Applies 'mapping' to each element of this integer-typed attribute and
/// returns a new attribute with element type 'newElementType'. Forwards to
/// DenseIntElementsAttr::mapValues; this attribute must be castable to
/// DenseIntElementsAttr.
DenseElementsAttr DenseElementsAttr::mapValues(
    Type newElementType,
    llvm::function_ref<APInt(const APInt &)> mapping) const {
  auto intAttr = cast<DenseIntElementsAttr>();
  return intAttr.mapValues(newElementType, mapping);
}
|
|
|
|
|
|
|
|
/// Applies 'mapping' to each element of this float-typed attribute and returns
/// a new attribute with element type 'newElementType'. Forwards to
/// DenseFPElementsAttr::mapValues; this attribute must be castable to
/// DenseFPElementsAttr.
DenseElementsAttr DenseElementsAttr::mapValues(
    Type newElementType,
    llvm::function_ref<APInt(const APFloat &)> mapping) const {
  auto fpAttr = cast<DenseFPElementsAttr>();
  return fpAttr.mapValues(newElementType, mapping);
}
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-18 16:41:00 -07:00
|
|
|
// DenseFPElementsAttr
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-12-18 05:25:17 -08:00
|
|
|
|
2019-05-31 09:24:48 -07:00
|
|
|
template <typename Fn, typename Attr>
|
|
|
|
static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
|
|
|
|
Type newElementType,
|
|
|
|
llvm::SmallVectorImpl<char> &data) {
|
2019-05-23 11:22:13 -07:00
|
|
|
size_t bitWidth = getDenseElementBitwidth(newElementType);
|
2019-06-06 15:55:17 -07:00
|
|
|
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
|
2019-05-23 11:22:13 -07:00
|
|
|
|
|
|
|
ShapedType newArrayType;
|
|
|
|
if (inType.isa<RankedTensorType>())
|
|
|
|
newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
|
|
|
|
else if (inType.isa<UnrankedTensorType>())
|
|
|
|
newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
|
|
|
|
else if (inType.isa<VectorType>())
|
|
|
|
newArrayType = VectorType::get(inType.getShape(), newElementType);
|
|
|
|
else
|
|
|
|
assert(newArrayType && "Unhandled tensor type");
|
|
|
|
|
2019-06-13 13:22:32 -07:00
|
|
|
data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * attr.rawSize());
|
2019-05-23 11:22:13 -07:00
|
|
|
|
|
|
|
uint64_t elementIdx = 0;
|
|
|
|
for (auto value : attr) {
|
|
|
|
auto newInt = mapping(value);
|
|
|
|
assert(newInt.getBitWidth() == bitWidth);
|
2019-06-07 12:08:36 -07:00
|
|
|
writeBits(data.data(), elementIdx * storageBitWidth, newInt);
|
2019-05-23 11:22:13 -07:00
|
|
|
++elementIdx;
|
|
|
|
}
|
|
|
|
|
|
|
|
return newArrayType;
|
|
|
|
}
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
/// Returns a new attribute whose elements are 'mapping' applied to each float
/// element of this attribute, using 'newElementType' as the element type. The
/// result keeps this attribute's splat flag.
DenseElementsAttr DenseFPElementsAttr::mapValues(
    Type newElementType,
    llvm::function_ref<APInt(const APFloat &)> mapping) const {
  llvm::SmallVector<char, 8> rawData;
  ShapedType resultType =
      mappingHelper(mapping, *this, getType(), newElementType, rawData);
  return getRaw(resultType, rawData, isSplat());
}
|
|
|
|
|
2019-06-06 16:15:42 -07:00
|
|
|
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
2019-06-18 16:41:00 -07:00
|
|
|
bool DenseFPElementsAttr::classof(Attribute attr) {
|
2019-06-06 16:15:42 -07:00
|
|
|
return attr.isa<DenseElementsAttr>() &&
|
2019-06-18 16:41:00 -07:00
|
|
|
attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
2019-06-06 16:15:42 -07:00
|
|
|
}
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-18 16:41:00 -07:00
|
|
|
// DenseIntElementsAttr
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-06-18 16:41:00 -07:00
|
|
|
/// Returns a new attribute whose elements are 'mapping' applied to each
/// integer element of this attribute, using 'newElementType' as the element
/// type. The result keeps this attribute's splat flag.
DenseElementsAttr DenseIntElementsAttr::mapValues(
    Type newElementType,
    llvm::function_ref<APInt(const APInt &)> mapping) const {
  llvm::SmallVector<char, 8> rawData;
  ShapedType resultType =
      mappingHelper(mapping, *this, getType(), newElementType, rawData);
  return getRaw(resultType, rawData, isSplat());
}
|
|
|
|
|
2019-06-06 16:15:42 -07:00
|
|
|
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
2019-06-18 16:41:00 -07:00
|
|
|
bool DenseIntElementsAttr::classof(Attribute attr) {
|
2019-06-06 16:15:42 -07:00
|
|
|
return attr.isa<DenseElementsAttr>() &&
|
2019-06-18 16:41:00 -07:00
|
|
|
attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
|
2019-06-06 16:15:42 -07:00
|
|
|
}
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// OpaqueElementsAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-05-16 00:12:45 -07:00
|
|
|
/// Creates (or retrieves the uniqued instance of) an opaque elements attribute
/// of shaped type 'type' that carries 'bytes' and is owned by 'dialect'.
OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
                                           StringRef bytes) {
  assert(TensorType::isValidElementType(type.getElementType()) &&
         "Input element type should be a valid tensor element type");
  auto *context = type.getContext();
  return Base::get(context, StandardAttributes::OpaqueElements, type, dialect,
                   bytes);
}
|
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
/// Returns the opaque byte payload held by this attribute.
StringRef OpaqueElementsAttr::getValue() const {
  auto *storage = getImpl();
  return storage->bytes;
}
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-02-25 08:21:41 -08:00
|
|
|
/// Return the value at the given index. If index does not refer to a valid
|
|
|
|
/// element, then a null attribute is returned.
|
|
|
|
Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|
|
|
if (Dialect *dialect = getDialect())
|
|
|
|
return dialect->extractElementHook(*this, index);
|
|
|
|
return Attribute();
|
|
|
|
}
|
|
|
|
|
2019-05-08 22:25:15 -07:00
|
|
|
/// Returns the dialect that owns this attribute, as recorded in its storage
/// (may be null).
Dialect *OpaqueElementsAttr::getDialect() const {
  auto *storage = getImpl();
  return storage->dialect;
}
|
2019-02-11 22:51:34 -08:00
|
|
|
|
|
|
|
/// Decodes this opaque attribute into 'result' by delegating to the owning
/// dialect's decodeHook and returning the hook's result. When there is no
/// owning dialect, nothing is decoded and true is returned.
/// NOTE(review): 'true' presumably signals failure here (LLVM convention);
/// confirm against Dialect::decodeHook.
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
  Dialect *owner = getDialect();
  if (owner == nullptr)
    return true;
  return owner->decodeHook(*this, result);
}
|
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// SparseElementsAttr
|
|
|
|
//===----------------------------------------------------------------------===//
|
2018-10-25 15:46:10 -07:00
|
|
|
|
2019-05-16 00:12:45 -07:00
|
|
|
/// Creates (or retrieves the uniqued instance of) a sparse elements attribute
/// of type 'type'. 'indices' holds the 64-bit coordinates of the explicitly
/// stored elements and 'values' holds their values. 'type' must be a ranked
/// tensor or vector type with a static shape.
SparseElementsAttr SparseElementsAttr::get(ShapedType type,
                                           DenseElementsAttr indices,
                                           DenseElementsAttr values) {
  assert(indices.getType().getElementType().isInteger(64) &&
         "expected sparse indices to be 64-bit integer values");
  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
         "type must be ranked tensor or vector");
  assert(type.hasStaticShape() && "type must have static shape");

  auto intIndices = indices.cast<DenseIntElementsAttr>();
  auto *context = type.getContext();
  return Base::get(context, StandardAttributes::SparseElements, type,
                   intIndices, values);
}
|
|
|
|
|
2018-10-25 15:46:10 -07:00
|
|
|
/// Returns the coordinates of the explicitly stored elements, as recorded in
/// this attribute's storage.
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
  auto *storage = getImpl();
  return storage->indices;
}
|
|
|
|
|
|
|
|
/// Returns the values of the explicitly stored elements, as recorded in this
/// attribute's storage.
DenseElementsAttr SparseElementsAttr::getValues() const {
  auto *storage = getImpl();
  return storage->values;
}
|
2019-01-19 20:54:09 -08:00
|
|
|
|
|
|
|
/// Return the value of the element at the given index.
|
|
|
|
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|
|
|
auto type = getType();
|
|
|
|
|
|
|
|
// Verify that the rank of the indices matches the held type.
|
2019-05-03 19:48:57 -07:00
|
|
|
size_t rank = type.getRank();
|
2019-01-19 20:54:09 -08:00
|
|
|
if (rank != index.size())
|
|
|
|
return Attribute();
|
|
|
|
|
2019-06-11 16:14:17 -07:00
|
|
|
/// Return an attribute corresponding to '0' for the element type.
|
|
|
|
auto getZeroAttr = [=]() -> Attribute {
|
|
|
|
auto eltType = type.getElementType();
|
|
|
|
if (eltType.isa<FloatType>())
|
|
|
|
return FloatAttr::get(eltType, 0);
|
|
|
|
assert(eltType.isa<IntegerType>() && "unexpected element type");
|
|
|
|
return IntegerAttr::get(eltType, 0);
|
|
|
|
};
|
|
|
|
|
2019-01-19 20:54:09 -08:00
|
|
|
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
|
|
|
|
// as a 1-D index array.
|
|
|
|
auto sparseIndices = getIndices();
|
2019-06-13 13:22:32 -07:00
|
|
|
ArrayRef<uint64_t> sparseIndexValues = sparseIndices.getValues<uint64_t>();
|
2019-01-19 20:54:09 -08:00
|
|
|
|
2019-06-11 16:14:17 -07:00
|
|
|
// Check to see if the indices are a splat.
|
|
|
|
if (sparseIndices.isSplat()) {
|
|
|
|
// If the index is also not a splat of the index value, we know that the
|
|
|
|
// value is zero.
|
2019-06-13 13:22:32 -07:00
|
|
|
auto splatIndex = sparseIndexValues.front();
|
2019-06-11 16:14:17 -07:00
|
|
|
if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
|
|
|
|
return getZeroAttr();
|
|
|
|
|
|
|
|
// If the indices are a splat, we also expect the values to be a splat.
|
|
|
|
assert(getValues().isSplat() && "expected splat values");
|
|
|
|
return getValues().getSplatValue();
|
|
|
|
}
|
|
|
|
|
2019-01-19 20:54:09 -08:00
|
|
|
// Build a mapping between known indices and the offset of the stored element.
|
|
|
|
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
|
2019-01-23 14:39:45 -08:00
|
|
|
auto numSparseIndices = sparseIndices.getType().getDimSize(0);
|
2019-01-19 20:54:09 -08:00
|
|
|
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
|
2019-06-13 13:22:32 -07:00
|
|
|
mappedIndices.try_emplace({&sparseIndexValues[i * rank], rank}, i);
|
2019-01-19 20:54:09 -08:00
|
|
|
|
|
|
|
// Look for the provided index key within the mapped indices. If the provided
|
|
|
|
// index is not found, then return a zero attribute.
|
|
|
|
auto it = mappedIndices.find(index);
|
2019-06-11 16:14:17 -07:00
|
|
|
if (it == mappedIndices.end())
|
|
|
|
return getZeroAttr();
|
2019-01-19 20:54:09 -08:00
|
|
|
|
|
|
|
// Otherwise, return the held sparse value element.
|
|
|
|
return getValues().getValue(it->second);
|
|
|
|
}
|
2019-02-26 18:01:46 -08:00
|
|
|
|
2019-04-05 16:11:24 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// NamedAttributeList
|
|
|
|
//===----------------------------------------------------------------------===//
|
2019-02-26 18:01:46 -08:00
|
|
|
|
2019-05-06 12:40:43 -07:00
|
|
|
/// Constructs a list holding 'attributes' by delegating to setAttrs.
NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
  this->setAttrs(attributes);
}
|
|
|
|
|
2019-06-01 08:37:50 -07:00
|
|
|
/// Returns the held attributes, or an empty list when no dictionary is held.
ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
  if (!attrs)
    return llvm::None;
  return attrs.getValue();
}
|
|
|
|
|
2019-02-26 18:01:46 -08:00
|
|
|
/// Replace the held attributes with ones provided in 'newAttrs'.
|
2019-05-06 12:40:43 -07:00
|
|
|
void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
|
2019-02-26 18:01:46 -08:00
|
|
|
// Don't create an attribute list if there are no attributes.
|
2019-05-31 09:24:48 -07:00
|
|
|
if (attributes.empty())
|
2019-02-26 18:01:46 -08:00
|
|
|
attrs = nullptr;
|
2019-05-31 09:24:48 -07:00
|
|
|
else
|
|
|
|
attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
|
2019-02-26 18:01:46 -08:00
|
|
|
}
|
|
|
|
|
2019-05-31 19:52:18 -07:00
|
|
|
/// Return the specified attribute if present, null otherwise.
|
|
|
|
Attribute NamedAttributeList::get(StringRef name) const {
|
|
|
|
return attrs ? attrs.get(name) : nullptr;
|
|
|
|
}
|
|
|
|
|
2019-02-26 18:01:46 -08:00
|
|
|
/// Return the specified attribute if present, null otherwise.
|
|
|
|
Attribute NamedAttributeList::get(Identifier name) const {
|
2019-05-31 09:24:48 -07:00
|
|
|
return attrs ? attrs.get(name) : nullptr;
|
2019-02-26 18:01:46 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// If the an attribute exists with the specified name, change it to the new
|
|
|
|
/// value. Otherwise, add a new attribute with the specified name/value.
|
2019-05-06 12:40:43 -07:00
|
|
|
void NamedAttributeList::set(Identifier name, Attribute value) {
|
2019-02-26 18:01:46 -08:00
|
|
|
assert(value && "attributes may never be null");
|
|
|
|
|
|
|
|
// If we already have this attribute, replace it.
|
|
|
|
auto origAttrs = getAttrs();
|
|
|
|
SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
|
|
|
|
for (auto &elt : newAttrs)
|
|
|
|
if (elt.first == name) {
|
|
|
|
elt.second = value;
|
2019-05-31 09:24:48 -07:00
|
|
|
attrs = DictionaryAttr::get(newAttrs, value.getContext());
|
2019-02-26 18:01:46 -08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise, add it.
|
|
|
|
newAttrs.push_back({name, value});
|
2019-05-31 09:24:48 -07:00
|
|
|
attrs = DictionaryAttr::get(newAttrs, value.getContext());
|
2019-02-26 18:01:46 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Remove the attribute with the specified name if it exists. The return
|
|
|
|
/// value indicates whether the attribute was present or not.
|
2019-05-06 12:40:43 -07:00
|
|
|
auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
|
2019-02-26 18:01:46 -08:00
|
|
|
auto origAttrs = getAttrs();
|
|
|
|
for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
|
|
|
|
if (origAttrs[i].first == name) {
|
2019-05-06 12:40:43 -07:00
|
|
|
// Handle the simple case of removing the only attribute in the list.
|
|
|
|
if (e == 1) {
|
|
|
|
attrs = nullptr;
|
|
|
|
return RemoveResult::Removed;
|
|
|
|
}
|
|
|
|
|
2019-02-26 18:01:46 -08:00
|
|
|
SmallVector<NamedAttribute, 8> newAttrs;
|
|
|
|
newAttrs.reserve(origAttrs.size() - 1);
|
|
|
|
newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
|
|
|
|
newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
|
2019-05-31 09:24:48 -07:00
|
|
|
attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext());
|
2019-02-26 18:01:46 -08:00
|
|
|
return RemoveResult::Removed;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return RemoveResult::NotFound;
|
|
|
|
}
|