Add new indexed_accessor_range_base and indexed_accessor_range classes that simplify defining index-able ranges.

Many ranges want similar functionality from a range type(e.g. slice/drop_front/operator[]/etc.), so these classes provide a generic implementation that may be used by many different types of ranges. This removes some code duplication, and also empowers many of the existing range types in MLIR(e.g. result type ranges, operand ranges, ElementsAttr ranges, etc.). This change only updates RegionRange and ValueRange, more ranges will be updated in followup commits.

PiperOrigin-RevId: 284615679
This commit is contained in:
River Riddle 2019-12-09 12:55:05 -08:00 committed by A. Unique TensorFlower
parent 56da74476c
commit 7be6a40ab9
8 changed files with 207 additions and 142 deletions

View File

@ -641,12 +641,12 @@ protected:
/// Return the current index for this iterator, adjusted for the case of a
/// splat.
ptrdiff_t getDataIndex() const {
bool isSplat = this->object.getInt();
bool isSplat = this->base.getInt();
return isSplat ? 0 : this->index;
}
/// Return the data object pointer.
const char *getData() const { return this->object.getPointer(); }
/// Return the data base pointer.
const char *getData() const { return this->base.getPointer(); }
};
} // namespace detail

View File

@ -460,9 +460,9 @@ public:
Block *>(object, index) {}
SuccessorIterator(const SuccessorIterator &other)
: SuccessorIterator(other.object, other.index) {}
: SuccessorIterator(other.base, other.index) {}
Block *operator*() const { return this->object->getSuccessor(this->index); }
Block *operator*() const { return this->base->getSuccessor(this->index); }
/// Get the successor number in the terminator.
unsigned getSuccessorIndex() const { return this->index; }

View File

@ -668,7 +668,7 @@ public:
: indexed_accessor_iterator<OperandIterator, Operation *, Value *,
Value *, Value *>(object, index) {}
Value *operator*() const { return this->object->getOperand(this->index); }
Value *operator*() const { return this->base->getOperand(this->index); }
};
/// This class implements the operand type iterators for the Operation
@ -721,11 +721,11 @@ class ResultIterator final
Value *, Value *> {
public:
/// Initializes the result iterator to the specified index.
ResultIterator(Operation *object, unsigned index)
ResultIterator(Operation *base, unsigned index)
: indexed_accessor_iterator<ResultIterator, Operation *, Value *, Value *,
Value *>(object, index) {}
Value *>(base, index) {}
Value *operator*() const { return this->object->getResult(this->index); }
Value *operator*() const { return this->base->getResult(this->index); }
};
/// This class implements the result type iterators for the Operation
@ -799,15 +799,19 @@ inline auto Operation::getResultTypes() -> result_type_range {
/// SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
class ValueRange {
class ValueRange
: public detail::indexed_accessor_range_base<
ValueRange,
llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>, Value *,
Value *, Value *> {
/// The type representing the owner of this range. This is either a list of
/// values, operands, or results.
using OwnerT = llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>;
public:
ValueRange(const ValueRange &) = default;
ValueRange(ValueRange &&) = default;
ValueRange &operator=(const ValueRange &) = default;
using detail::indexed_accessor_range_base<
ValueRange, OwnerT, Value *, Value *,
Value *>::indexed_accessor_range_base;
template <typename Arg,
typename = typename std::enable_if_t<
@ -822,46 +826,15 @@ public:
ValueRange(iterator_range<OperandIterator> values);
ValueRange(iterator_range<ResultIterator> values);
/// An iterator element of this range.
class Iterator : public indexed_accessor_iterator<Iterator, OwnerT, Value *,
Value *, Value *> {
public:
Value *operator*() const;
private:
Iterator(OwnerT owner, unsigned curIndex);
/// Allow access to the constructor.
friend ValueRange;
};
Iterator begin() const { return Iterator(owner, 0); }
Iterator end() const { return Iterator(owner, count); }
Value *operator[](unsigned index) const {
assert(index < size() && "invalid index for value range");
return *std::next(begin(), index);
}
/// Return the size of this range.
size_t size() const { return count; }
/// Return if the range is empty.
bool empty() const { return size() == 0; }
/// Drop the first N elements, and keep M elements.
ValueRange slice(unsigned n, unsigned m) const;
/// Drop the first n elements.
ValueRange drop_front(unsigned n = 1) const;
/// Drop the last n elements.
ValueRange drop_back(unsigned n = 1) const;
private:
ValueRange(OwnerT owner, unsigned count) : owner(owner), count(count) {}
/// See `detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
/// See `detail::indexed_accessor_range_base` for details.
static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
/// The object that owns the provided range of values.
OwnerT owner;
/// The size from the owning range.
unsigned count;
/// Allow access to `offset_base` and `dereference_iterator`.
friend detail::indexed_accessor_range_base<ValueRange, OwnerT, Value *,
Value *, Value *>;
};
} // end namespace mlir

View File

@ -165,14 +165,19 @@ private:
/// SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
class RegionRange {
class RegionRange
: public detail::indexed_accessor_range_base<
RegionRange,
llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>,
Region *, Region *, Region *> {
/// The type representing the owner of this range. This is either a list of
/// values, operands, or results.
using OwnerT = llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>;
public:
RegionRange(const RegionRange &) = default;
RegionRange(RegionRange &&) = default;
using detail::indexed_accessor_range_base<
RegionRange, OwnerT, Region *, Region *,
Region *>::indexed_accessor_range_base;
RegionRange(MutableArrayRef<Region> regions = llvm::None);
@ -184,33 +189,15 @@ public:
}
RegionRange(ArrayRef<std::unique_ptr<Region>> regions);
/// An iterator element of this range.
class Iterator : public indexed_accessor_iterator<Iterator, OwnerT, Region *,
Region *, Region *> {
public:
Region *operator*() const;
private:
Iterator(OwnerT owner, unsigned curIndex);
/// Allow access to the constructor.
friend RegionRange;
};
Iterator begin() const { return Iterator(owner, 0); }
Iterator end() const { return Iterator(owner, count); }
Region *operator[](unsigned index) const {
assert(index < size() && "invalid index for region range");
return *std::next(begin(), index);
}
/// Return the size of this range.
size_t size() const { return count; }
/// Return if the range is empty.
bool empty() const { return size() == 0; }
private:
/// The object that owns the provided range of regions.
OwnerT owner;
/// The size from the owning range.
unsigned count;
/// See `detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
/// See `detail::indexed_accessor_range_base` for details.
static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
/// Allow access to `offset_base` and `dereference_iterator`.
friend detail::indexed_accessor_range_base<RegionRange, OwnerT, Region *,
Region *, Region *>;
};
} // end namespace mlir

View File

@ -147,9 +147,9 @@ using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
// Extra additions to <iterator>
//===----------------------------------------------------------------------===//
/// A utility class used to implement an iterator that contains some object and
/// an index. The iterator moves the index but keeps the object constant.
template <typename DerivedT, typename ObjectType, typename T,
/// A utility class used to implement an iterator that contains some base object
/// and an index. The iterator moves the index but keeps the base constant.
template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_iterator
: public llvm::iterator_facade_base<DerivedT,
@ -157,14 +157,14 @@ class indexed_accessor_iterator
std::ptrdiff_t, PointerT, ReferenceT> {
public:
ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const {
assert(object == rhs.object && "incompatible iterators");
assert(base == rhs.base && "incompatible iterators");
return index - rhs.index;
}
bool operator==(const indexed_accessor_iterator &rhs) const {
return object == rhs.object && index == rhs.index;
return base == rhs.base && index == rhs.index;
}
bool operator<(const indexed_accessor_iterator &rhs) const {
assert(object == rhs.object && "incompatible iterators");
assert(base == rhs.base && "incompatible iterators");
return index < rhs.index;
}
@ -180,16 +180,134 @@ public:
/// Returns the current index of the iterator.
ptrdiff_t getIndex() const { return index; }
/// Returns the current object of the iterator.
const ObjectType &getObject() const { return object; }
/// Returns the current base of the iterator.
const BaseT &getBase() const { return base; }
protected:
indexed_accessor_iterator(ObjectType object, ptrdiff_t index)
: object(object), index(index) {}
ObjectType object;
indexed_accessor_iterator(BaseT base, ptrdiff_t index)
: base(base), index(index) {}
BaseT base;
ptrdiff_t index;
};
namespace detail {
/// The class represents the base of a range of indexed_accessor_iterators. It
/// provides support for many different range functionalities, e.g.
/// drop_front/slice/etc.. Derived range classes must implement the following
/// static methods:
/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
/// - Derefence an iterator pointing to the base object at the given index.
/// * BaseT offset_base(const BaseT &base, ptrdiff_t index)
/// - Return a new base that is offset from the provide base by 'index'
/// elements.
template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_range_base {
public:
/// An iterator element of this range.
class iterator : public indexed_accessor_iterator<iterator, BaseT, T,
PointerT, ReferenceT> {
public:
// Index into this iterator, invoking a static method on the derived type.
ReferenceT operator*() const {
return DerivedT::dereference_iterator(this->getBase(), this->getIndex());
}
private:
iterator(BaseT owner, ptrdiff_t curIndex)
: indexed_accessor_iterator<iterator, BaseT, T, PointerT, ReferenceT>(
owner, curIndex) {}
/// Allow access to the constructor.
friend indexed_accessor_range_base<DerivedT, BaseT, T, PointerT,
ReferenceT>;
};
iterator begin() const { return iterator(base, 0); }
iterator end() const { return iterator(base, count); }
ReferenceT operator[](unsigned index) const {
assert(index < size() && "invalid index for value range");
return *std::next(begin(), index);
}
/// Return the size of this range.
size_t size() const { return count; }
/// Return if the range is empty.
bool empty() const { return size() == 0; }
/// Drop the first N elements, and keep M elements.
DerivedT slice(unsigned n, unsigned m) const {
assert(n + m <= size() && "invalid size specifiers");
return DerivedT(DerivedT::offset_base(base, n), m);
}
/// Drop the first n elements.
DerivedT drop_front(unsigned n = 1) const {
assert(size() >= n && "Dropping more elements than exist");
return slice(n, size() - n);
}
/// Drop the last n elements.
DerivedT drop_back(unsigned n = 1) const {
assert(size() >= n && "Dropping more elements than exist");
return DerivedT(base, size() - n);
}
protected:
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
: base(base), count(count) {}
indexed_accessor_range_base(const indexed_accessor_range_base &) = default;
indexed_accessor_range_base(indexed_accessor_range_base &&) = default;
indexed_accessor_range_base &
operator=(const indexed_accessor_range_base &) = default;
/// The base that owns the provided range of values.
BaseT base;
/// The size from the owning range.
ptrdiff_t count;
};
} // end namespace detail
/// This class provides an implementation of a range of
/// indexed_accessor_iterators where the base is not indexable. Ranges with
/// bases that are offsetable should derive from indexed_accessor_range_base
/// instead. Derived range classes are expected to implement the following
/// static method:
/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
/// - Derefence an iterator pointing to a parent base at the given index.
template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_range
: public detail::indexed_accessor_range_base<
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
protected:
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
: detail::indexed_accessor_range_base<
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
std::make_pair(base, startIndex), count) {}
private:
/// See `detail::indexed_accessor_range_base` for details.
static std::pair<BaseT, ptrdiff_t>
offset_base(const std::pair<BaseT, ptrdiff_t> &base, ptrdiff_t index) {
// We encode the internal base as a pair of the derived base and a start
// index into the derived base.
return std::make_pair(base.first, base.second + index);
}
/// See `detail::indexed_accessor_range_base` for details.
static ReferenceT
dereference_iterator(const std::pair<BaseT, ptrdiff_t> &base,
ptrdiff_t index) {
return DerivedT::dereference_iterator(base.first, base.second + index);
}
/// Allow access to `offset_base` and `dereference_iterator`.
friend detail::indexed_accessor_range_base<
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>;
};
/// Given a container of pairs, return a range over the second elements.
template <typename ContainerTy> auto make_second_range(ContainerTy &&c) {
return llvm::map_range(

View File

@ -527,7 +527,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
/// Accesses the Attribute value at this iterator position.
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = getFromOpaquePointer(object).cast<DenseElementsAttr>();
auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
Type eltTy = owner.getType().getElementType();
if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
if (intEltTy.getWidth() == 1)

View File

@ -750,60 +750,41 @@ Operation *Operation::clone() {
//===----------------------------------------------------------------------===//
ValueRange::ValueRange(ArrayRef<Value *> values)
: owner(values.data()), count(values.size()) {}
: ValueRange(values.data(), values.size()) {}
ValueRange::ValueRange(llvm::iterator_range<OperandIterator> values)
: count(llvm::size(values)) {
if (count != 0) {
: ValueRange(nullptr, llvm::size(values)) {
if (!empty()) {
auto begin = values.begin();
owner = &begin.getObject()->getOpOperand(begin.getIndex());
base = &begin.getBase()->getOpOperand(begin.getIndex());
}
}
ValueRange::ValueRange(llvm::iterator_range<ResultIterator> values)
: count(llvm::size(values)) {
if (count != 0) {
: ValueRange(nullptr, llvm::size(values)) {
if (!empty()) {
auto begin = values.begin();
owner = &begin.getObject()->getOpResult(begin.getIndex());
base = &begin.getBase()->getOpResult(begin.getIndex());
}
}
/// Drop the first N elements, and keep M elements.
ValueRange ValueRange::slice(unsigned n, unsigned m) const {
assert(n + m <= size() && "Invalid specifier");
OwnerT newOwner;
/// See `detail::indexed_accessor_range_base` for details.
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
newOwner = operand + n;
else if (OpResult *result = owner.dyn_cast<OpResult *>())
newOwner = result + n;
else
newOwner = owner.get<Value *const *>() + n;
return ValueRange(newOwner, m);
return operand + index;
if (OpResult *result = owner.dyn_cast<OpResult *>())
return result + index;
return owner.get<Value *const *>() + index;
}
/// Drop the first n elements.
ValueRange ValueRange::drop_front(unsigned n) const {
assert(size() >= n && "Dropping more elements than exist");
return slice(n, size() - n);
}
/// Drop the last n elements.
ValueRange ValueRange::drop_back(unsigned n) const {
assert(size() >= n && "Dropping more elements than exist");
return ValueRange(owner, size() - n);
}
ValueRange::Iterator::Iterator(OwnerT owner, unsigned curIndex)
: indexed_accessor_iterator<Iterator, OwnerT, Value *, Value *, Value *>(
owner, curIndex) {}
Value *ValueRange::Iterator::operator*() const {
/// See `detail::indexed_accessor_range_base` for details.
Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
// Operands access the held value via 'get'.
if (OpOperand *operand = object.dyn_cast<OpOperand *>())
if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
return operand[index].get();
// An OpResult is a value, so we can return it directly.
if (OpResult *result = object.dyn_cast<OpResult *>())
if (OpResult *result = owner.dyn_cast<OpResult *>())
return &result[index];
// Otherwise, this is a raw value array so just index directly.
return object.get<Value *const *>()[index];
return owner.get<Value *const *>()[index];
}
//===----------------------------------------------------------------------===//

View File

@ -217,17 +217,23 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList(
//===----------------------------------------------------------------------===//
// RegionRange
//===----------------------------------------------------------------------===//
RegionRange::RegionRange(MutableArrayRef<Region> regions)
: owner(regions.data()), count(regions.size()) {}
RegionRange::RegionRange(ArrayRef<std::unique_ptr<Region>> regions)
: owner(regions.data()), count(regions.size()) {}
RegionRange::Iterator::Iterator(OwnerT owner, unsigned curIndex)
: indexed_accessor_iterator<Iterator, OwnerT, Region *, Region *, Region *>(
owner, curIndex) {}
Region *RegionRange::Iterator::operator*() const {
if (const std::unique_ptr<Region> *operand =
object.dyn_cast<const std::unique_ptr<Region> *>())
return operand[index].get();
return &object.get<Region *>()[index];
RegionRange::RegionRange(MutableArrayRef<Region> regions)
: RegionRange(regions.data(), regions.size()) {}
RegionRange::RegionRange(ArrayRef<std::unique_ptr<Region>> regions)
: RegionRange(regions.data(), regions.size()) {}
/// See `detail::indexed_accessor_range_base` for details.
RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
if (auto *operand = owner.dyn_cast<const std::unique_ptr<Region> *>())
return operand + index;
return &owner.get<Region *>()[index];
}
/// See `detail::indexed_accessor_range_base` for details.
Region *RegionRange::dereference_iterator(const OwnerT &owner,
ptrdiff_t index) {
if (auto *operand = owner.dyn_cast<const std::unique_ptr<Region> *>())
return operand[index].get();
return &owner.get<Region *>()[index];
}