Start building SDBM infrastructure

Striped difference-bound matrix expressions are a subset of affine expressions
    supporting low-complexity algorithms that can be useful for loop
    transformations.  This introduces the basic data data structures for building
    such expressions and unique'ing them in a MLIRContext.

--

PiperOrigin-RevId: 245380206
This commit is contained in:
Alex Zinenko 2019-04-26 01:05:24 -07:00 committed by Mehdi Amini
parent 65ccb8cfd5
commit 24d0f60d31
6 changed files with 940 additions and 0 deletions

View File

@ -0,0 +1,353 @@
//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// A striped difference-bound matrix (SDBM) expression is a constant expression,
// an identifier, a binary expression with constant RHS and +, stripe operators
// or a difference expression between two identifiers.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_SDBMEXPR_H
#define MLIR_IR_SDBMEXPR_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
namespace mlir {
class MLIRContext;
enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };
namespace detail {
struct SDBMExprStorage;
struct SDBMBinaryExprStorage;
struct SDBMDiffExprStorage;
struct SDBMPositiveExprStorage;
struct SDBMConstantExprStorage;
struct SDBMNegExprStorage;
} // namespace detail
/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
/// expression for the SDBM framework. SDBM expressions are a subset of affine
/// expressions supporting low-complexity algorithms for the operations used in
/// loop transformations. In particular, are supported:
/// - constant expressions;
/// - single variables (dimensions and symbols) with +1 or -1 coefficient;
/// - stripe expressions: "x # C", where "x" is a single variable or another
/// stripe expression, "#" is the stripe operator, and "C" is a constant
/// expression; "#" is defined as x - x mod C.
/// - sum expressions between single variable/stripe expressions and constant
/// expressions;
/// - difference expressions between single variable/stripe expressions.
/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing
/// and operating on SDBM expressions. For example, it requires the LHS of a
/// sum expression to be a single variable or a stripe expression. These
/// restrictions are intended to force the caller to perform the necessary
/// simplifications to stay within the SDBM domain, because SDBM expressions do
/// not combine in more cases than they do. This choice may be reconsidered in
/// the future.
///
/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
/// an MLIRContext, and should be used by-value. They are uniqued in the
/// MLIRContext and immortal.
class SDBMExpr {
public:
using ImplType = detail::SDBMExprStorage;
SDBMExpr() : impl(nullptr) {}
/* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}
/// SDBM expressions are thin wrappers around a unique'ed immutable pointer,
/// which makes them trivially assignable and trivially copyable.
SDBMExpr(const SDBMExpr &) = default;
SDBMExpr &operator=(const SDBMExpr &) = default;
/// SDBM expressions can be compared straight-forwardly.
bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
bool operator!=(const SDBMExpr &other) const { return !(*this == other); }
/// SDBM expressions are convertible to `bool`: null expressions are converted
/// to false, non-null expressions are converted to true.
explicit operator bool() const { return impl != nullptr; }
bool operator!() const { return !static_cast<bool>(*this); }
/// Prints the SDBM expression.
void print(raw_ostream &os) const;
void dump() const;
/// LLVM-style casts.
template <typename U> bool isa() const { return U::isClassFor(*this); }
template <typename U> U dyn_cast() const {
if (!isa<U>())
return {};
return U(const_cast<SDBMExpr *>(this)->impl);
}
template <typename U> U cast() const {
assert(isa<U>() && "cast to incorrect subtype");
return U(const_cast<SDBMExpr *>(this)->impl);
}
/// Support for LLVM hashing.
::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }
/// Returns the kind of the SDBM expression.
SDBMExprKind getKind() const;
/// Returns the MLIR context in which this expression lives.
MLIRContext *getContext() const;
protected:
ImplType *impl;
};
/// SDBM constant expression, wraps a 64-bit integer.
class SDBMConstantExpr : public SDBMExpr {
public:
using ImplType = detail::SDBMConstantExprStorage;
using SDBMExpr::SDBMExpr;
/// Obtain or create a constant expression unique'ed in the given context.
static SDBMConstantExpr get(MLIRContext *context, int64_t value);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Constant;
}
int64_t getValue() const;
};
/// SDBM varying expression can be one of:
/// - input variable expression;
/// - stripe expression;
/// - negation (product with -1) of either of the above.
/// - sum of a varying and a constant expression
/// - difference between varying expressions
class SDBMVaryingExpr : public SDBMExpr {
public:
using ImplType = detail::SDBMExprStorage;
using SDBMExpr::SDBMExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId ||
expr.getKind() == SDBMExprKind::Neg ||
expr.getKind() == SDBMExprKind::Stripe ||
expr.getKind() == SDBMExprKind::Add ||
expr.getKind() == SDBMExprKind::Diff;
}
};
/// SDBM positive variable expression can be one of:
/// - single variable expression;
/// - stripe expression.
class SDBMPositiveExpr : public SDBMVaryingExpr {
public:
using SDBMVaryingExpr::SDBMVaryingExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId ||
expr.getKind() == SDBMExprKind::Stripe;
}
};
/// SDBM sum expression. LHS is a varying expression and RHS is always a
/// constant expression.
class SDBMSumExpr : public SDBMVaryingExpr {
public:
using ImplType = detail::SDBMBinaryExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a sum expression unique'ed in the given context.
static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
static bool isClassFor(const SDBMExpr &expr) {
SDBMExprKind kind = expr.getKind();
return kind == SDBMExprKind::Add;
}
SDBMVaryingExpr getLHS() const;
SDBMConstantExpr getRHS() const;
};
/// SDBM difference expression. Both LHS and RHS are positive variable
/// expressions.
class SDBMDiffExpr : public SDBMVaryingExpr {
public:
using ImplType = detail::SDBMDiffExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a difference expression unique'ed in the given context.
static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Diff;
}
SDBMPositiveExpr getLHS() const;
SDBMPositiveExpr getRHS() const;
};
/// SDBM stripe expression "x # C" where "x" is a positive variable expression,
/// "C" is a constant expression and "#" is the stripe operator defined as:
/// x # C = x - x mod C.
class SDBMStripeExpr : public SDBMPositiveExpr {
public:
using ImplType = detail::SDBMBinaryExprStorage;
using SDBMPositiveExpr::SDBMPositiveExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Stripe;
}
static SDBMStripeExpr get(SDBMPositiveExpr var,
SDBMConstantExpr stripeFactor);
SDBMPositiveExpr getVar() const;
SDBMConstantExpr getStripeFactor() const;
};
/// SDBM "input" variable expression can be either a dimension identifier or
/// a symbol identifier. When used to define SDBM functions, dimensions are
/// interpreted as function arguments while symbols are treated as unknown but
/// constant values, hence the name.
class SDBMInputExpr : public SDBMPositiveExpr {
public:
using ImplType = detail::SDBMPositiveExprStorage;
using SDBMPositiveExpr::SDBMPositiveExpr;
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId ||
expr.getKind() == SDBMExprKind::SymbolId;
}
unsigned getPosition() const;
};
/// SDBM dimension expression. Dimensions correspond to function arguments
/// when defining functions using SDBM expressions.
class SDBMDimExpr : public SDBMInputExpr {
public:
using ImplType = detail::SDBMPositiveExprStorage;
using SDBMInputExpr::SDBMInputExpr;
/// Obtain or create a dimension expression unique'ed in the given context.
static SDBMDimExpr get(MLIRContext *context, unsigned position);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::DimId;
}
};
/// SDBM symbol expression. Symbols correspond to symbolic constants when
/// defining functions using SDBM expressions.
class SDBMSymbolExpr : public SDBMInputExpr {
public:
using ImplType = detail::SDBMPositiveExprStorage;
using SDBMInputExpr::SDBMInputExpr;
/// Obtain or create a symbol expression unique'ed in the given context.
static SDBMSymbolExpr get(MLIRContext *context, unsigned position);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::SymbolId;
}
};
/// Negation of an SDBM variable expression. Equivalent to multiplying the
/// expression with -1 (SDBM does not support other coefficients that 1 and -1).
class SDBMNegExpr : public SDBMVaryingExpr {
public:
using ImplType = detail::SDBMNegExprStorage;
using SDBMVaryingExpr::SDBMVaryingExpr;
/// Obtain or create a negation expression unique'ed in the given context.
static SDBMNegExpr get(SDBMPositiveExpr var);
static bool isClassFor(const SDBMExpr &expr) {
return expr.getKind() == SDBMExprKind::Neg;
}
SDBMPositiveExpr getVar() const;
};
} // end namespace mlir
namespace llvm {
// SDBMVaryingExpr hash just like pointers.
template <> struct DenseMapInfo<mlir::SDBMVaryingExpr> {
static mlir::SDBMVaryingExpr getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::SDBMVaryingExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static mlir::SDBMVaryingExpr getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::SDBMVaryingExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::SDBMVaryingExpr expr) {
return expr.hash_value();
}
static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) {
return lhs == rhs;
}
};
// SDBMPositiveExpr hash just like pointers.
template <> struct DenseMapInfo<mlir::SDBMPositiveExpr> {
static mlir::SDBMPositiveExpr getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::SDBMPositiveExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static mlir::SDBMPositiveExpr getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::SDBMPositiveExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::SDBMPositiveExpr expr) {
return expr.hash_value();
}
static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) {
return lhs == rhs;
}
};
// SDBMConstantExpr hash just like pointers.
template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
static mlir::SDBMConstantExpr getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::SDBMConstantExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static mlir::SDBMConstantExpr getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::SDBMConstantExpr(
static_cast<mlir::SDBMExpr::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
return expr.hash_value();
}
static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) {
return lhs == rhs;
}
};
} // namespace llvm
#endif // MLIR_IR_SDBMEXPR_H

View File

@ -21,6 +21,7 @@
#include "AttributeDetail.h"
#include "IntegerSetDetail.h"
#include "LocationDetail.h"
#include "SDBMExprDetail.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@ -461,6 +462,26 @@ public:
// Uniqui'ing of AffineConstantExprStorage using constant value as key.
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
//===--------------------------------------------------------------------===//
// SDBM uniquing
//===--------------------------------------------------------------------===//
llvm::BumpPtrAllocator SDBMAllocator;
llvm::sys::SmartRWMutex<true> SDBMMutex;
DenseMap<std::tuple<SDBMVaryingExpr, SDBMConstantExpr>,
SDBMBinaryExprStorage *>
SDBMSumExprs;
DenseMap<std::tuple<SDBMPositiveExpr, SDBMConstantExpr>,
SDBMBinaryExprStorage *>
SDBMStripeExprs;
DenseMap<std::tuple<SDBMPositiveExpr, SDBMPositiveExpr>,
SDBMDiffExprStorage *>
SDBMDiffExprs;
std::vector<SDBMPositiveExprStorage *> SDBMDimExprs;
std::vector<SDBMPositiveExprStorage *> SDBMSymbolExprs;
DenseMap<SDBMPositiveExpr, SDBMNegExprStorage *> SDBMNegExprs;
DenseMap<int64_t, SDBMConstantExprStorage *> SDBMConstExprs;
//===--------------------------------------------------------------------===//
// Type uniquing
//===--------------------------------------------------------------------===//
@ -843,6 +864,103 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
});
}
//===----------------------------------------------------------------------===//
// SDBMExpr uniquing
//===----------------------------------------------------------------------===//
SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
assert(lhs && "expected SDBM variable expression");
assert(rhs && "expected SDBM constant");
MLIRContextImpl &impl = lhs.getContext()->getImpl();
// If LHS of a sum is another sum, fold the constant RHS parts.
if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
lhs = lhsSum.getLHS();
rhs = SDBMConstantExpr::get(rhs.getContext(),
rhs.getValue() + lhsSum.getRHS().getValue());
}
auto key = std::make_tuple(lhs, rhs);
return safeGetOrCreate(
impl.SDBMSumExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] {
auto *mem = impl.SDBMAllocator.Allocate<SDBMBinaryExprStorage>();
return new (mem) SDBMBinaryExprStorage(SDBMExprKind::Add,
lhs.getContext(), lhs, rhs);
});
}
SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
assert(lhs && "expected SDBM dimension");
assert(rhs && "expected SDBM dimension");
MLIRContextImpl &impl = lhs.getContext()->getImpl();
auto key = std::make_tuple(lhs, rhs);
return safeGetOrCreate(
impl.SDBMDiffExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] {
auto *mem = impl.SDBMAllocator.Allocate<SDBMDiffExprStorage>();
return new (mem) SDBMDiffExprStorage(lhs.getContext(), lhs, rhs);
});
}
SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
SDBMConstantExpr stripeFactor) {
assert(var && "expected SDBM variable expression");
assert(stripeFactor && "expected non-null stripe factor");
assert(stripeFactor.getValue() > 0 && "non-positive stripe factor");
MLIRContextImpl &impl = var.getContext()->getImpl();
auto key = std::make_tuple(var, stripeFactor);
return safeGetOrCreate(
impl.SDBMStripeExprs, key, impl.SDBMMutex, [&impl, var, stripeFactor] {
auto *mem = impl.SDBMAllocator.Allocate<SDBMBinaryExprStorage>();
return new (mem) SDBMBinaryExprStorage(
SDBMExprKind::Stripe, var.getContext(), var, stripeFactor);
});
}
SDBMDimExpr SDBMDimExpr::get(MLIRContext *context, unsigned position) {
assert(context && "expected non-null context");
MLIRContextImpl &impl = context->getImpl();
return safeGetOrCreate(
impl.SDBMDimExprs, position, impl.SDBMMutex, [&impl, context, position] {
auto *mem = impl.SDBMAllocator.Allocate<SDBMPositiveExprStorage>();
return new (mem)
SDBMPositiveExprStorage(SDBMExprKind::DimId, context, position);
});
}
SDBMSymbolExpr SDBMSymbolExpr::get(MLIRContext *context, unsigned position) {
assert(context && "expected non-null context");
MLIRContextImpl &impl = context->getImpl();
return safeGetOrCreate(
impl.SDBMSymbolExprs, position, impl.SDBMMutex,
[&impl, context, position] {
auto *mem = impl.SDBMAllocator.Allocate<SDBMPositiveExprStorage>();
return new (mem)
SDBMPositiveExprStorage(SDBMExprKind::SymbolId, context, position);
});
}
SDBMConstantExpr SDBMConstantExpr::get(MLIRContext *context, int64_t value) {
assert(context && "expected non-null context");
MLIRContextImpl &impl = context->getImpl();
return safeGetOrCreate(
impl.SDBMConstExprs, value, impl.SDBMMutex, [&impl, context, value] {
auto *mem = impl.SDBMAllocator.Allocate<SDBMConstantExprStorage>();
return new (mem) SDBMConstantExprStorage(context, value);
});
}
SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
assert(var && "expected non-null SDBM variable expression");
MLIRContextImpl &impl = var.getContext()->getImpl();
return safeGetOrCreate(impl.SDBMNegExprs, var, impl.SDBMMutex, [&impl, var] {
auto *mem = impl.SDBMAllocator.Allocate<SDBMNegExprStorage>();
return new (mem) SDBMNegExprStorage(var);
});
}
//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//

202
mlir/lib/IR/SDBMExpr.cpp Normal file
View File

@ -0,0 +1,202 @@
//===- SDBMExpr.h - MLIR SDBM Expression implementation -------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// A striped difference-bound matrix (SDBM) expression is a constant expression,
// an identifier, a binary expression with constant RHS and +, stripe operators
// or a difference expression between two identifiers.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/SDBMExpr.h"
#include "SDBMExprDetail.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// SDBMExpr
//===----------------------------------------------------------------------===//
SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
MLIRContext *SDBMExpr::getContext() const { return impl->getContext(); }
template <typename Derived> class SDBMVisitor {
public:
/// Visit the given SDBM expression, dispatching to kind-specific functions.
void visit(SDBMExpr expr) {
auto *derived = static_cast<Derived *>(this);
switch (expr.getKind()) {
case SDBMExprKind::Add:
case SDBMExprKind::Diff:
case SDBMExprKind::DimId:
case SDBMExprKind::SymbolId:
case SDBMExprKind::Neg:
case SDBMExprKind::Stripe:
return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
case SDBMExprKind::Constant:
return derived->visitConstant(expr.cast<SDBMConstantExpr>());
}
}
protected:
/// Default visitors do nothing.
void visitSum(SDBMSumExpr) {}
void visitDiff(SDBMDiffExpr) {}
void visitStripe(SDBMStripeExpr) {}
void visitDim(SDBMDimExpr) {}
void visitSymbol(SDBMSymbolExpr) {}
void visitNeg(SDBMNegExpr) {}
void visitConstant(SDBMConstantExpr) {}
/// Default implementation of visitPositive dispatches to the special
/// functions for stripes and other variables. Concrete visitors can override
/// it.
void visitPositive(SDBMPositiveExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::Stripe)
derived->visitStripe(expr.cast<SDBMStripeExpr>());
else
derived->visitInput(expr.cast<SDBMInputExpr>());
}
/// Default implementation of visitInput dispatches to the special
/// functions for dimensions or symbols. Concrete visitors can override it to
/// visit all variables instead.
void visitInput(SDBMInputExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::DimId)
derived->visitDim(expr.cast<SDBMDimExpr>());
else
derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
}
/// Default implementation of visitVarying dispatches to the special
/// functions for variables and negations thereof. Concerete visitors can
/// override it to visit all variables and negations isntead.
void visitVarying(SDBMVaryingExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (auto var = expr.dyn_cast<SDBMPositiveExpr>())
derived->visitPositive(var);
else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
derived->visitNeg(neg);
else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
derived->visitSum(sum);
else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
derived->visitDiff(diff);
llvm_unreachable("unhandled subtype of varying SDBM expression");
}
};
void SDBMExpr::print(raw_ostream &os) const {
struct Printer : public SDBMVisitor<Printer> {
Printer(raw_ostream &ostream) : prn(ostream) {}
void visitSum(SDBMSumExpr expr) {
visitVarying(expr.getLHS());
prn << " + ";
visitConstant(expr.getRHS());
}
void visitDiff(SDBMDiffExpr expr) {
visitPositive(expr.getLHS());
prn << " - ";
visitPositive(expr.getRHS());
}
void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
void visitStripe(SDBMStripeExpr expr) {
visitPositive(expr.getVar());
prn << " # ";
visitConstant(expr.getStripeFactor());
}
void visitNeg(SDBMNegExpr expr) {
prn << '-';
visitPositive(expr.getVar());
}
void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
raw_ostream &prn;
};
Printer printer(os);
printer.visit(*this);
}
void SDBMExpr::dump() const { print(llvm::errs()); }
//===----------------------------------------------------------------------===//
// SDBMSumExpr
//===----------------------------------------------------------------------===//
SDBMVaryingExpr SDBMSumExpr::getLHS() const {
return static_cast<ImplType *>(impl)->lhs;
}
SDBMConstantExpr SDBMSumExpr::getRHS() const {
return static_cast<ImplType *>(impl)->rhs;
}
//===----------------------------------------------------------------------===//
// SDBMDiffExpr
//===----------------------------------------------------------------------===//
SDBMPositiveExpr SDBMDiffExpr::getLHS() const {
return static_cast<ImplType *>(impl)->lhs;
}
SDBMPositiveExpr SDBMDiffExpr::getRHS() const {
return static_cast<ImplType *>(impl)->rhs;
}
//===----------------------------------------------------------------------===//
// SDBMStripeExpr
//===----------------------------------------------------------------------===//
SDBMPositiveExpr SDBMStripeExpr::getVar() const {
if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
return lhs.cast<SDBMPositiveExpr>();
return {};
}
SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
return static_cast<ImplType *>(impl)->rhs;
}
//===----------------------------------------------------------------------===//
// SDBMInputExpr
//===----------------------------------------------------------------------===//
unsigned SDBMInputExpr::getPosition() const {
return static_cast<ImplType *>(impl)->position;
}
//===----------------------------------------------------------------------===//
// SDBMConstantExpr
//===----------------------------------------------------------------------===//
int64_t SDBMConstantExpr::getValue() const {
return static_cast<ImplType *>(impl)->constant;
}
//===----------------------------------------------------------------------===//
// SDBMNegExpr
//===----------------------------------------------------------------------===//
SDBMPositiveExpr SDBMNegExpr::getVar() const {
return static_cast<ImplType *>(impl)->dim;
}

View File

@ -0,0 +1,84 @@
//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This holds implementation details of SDBMExpr, in particular underlying
// storage types. MLIRContext.cpp needs to know the storage layout for
// allocation and unique'ing purposes.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_SDBMEXPRDETAIL_H
#define MLIR_IR_SDBMEXPRDETAIL_H
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SDBMExpr.h"
#include "llvm/ADT/PointerIntPair.h"
namespace mlir {
namespace detail {
struct SDBMExprStorage {
SDBMExprStorage(SDBMExprKind kind, MLIRContext *context)
: contextAndKind(context, kind) {}
SDBMExprKind getKind() { return contextAndKind.getInt(); }
MLIRContext *getContext() { return contextAndKind.getPointer(); }
// This needs to know the layout of MLIRContext so the relevant file is
// included.
llvm::PointerIntPair<MLIRContext *, 3, SDBMExprKind> contextAndKind;
};
struct SDBMBinaryExprStorage : public SDBMExprStorage {
SDBMBinaryExprStorage(SDBMExprKind kind, MLIRContext *context,
SDBMVaryingExpr left, SDBMConstantExpr right)
: SDBMExprStorage(kind, context), lhs(left), rhs(right) {}
SDBMVaryingExpr lhs;
SDBMConstantExpr rhs;
};
struct SDBMDiffExprStorage : public SDBMExprStorage {
SDBMDiffExprStorage(MLIRContext *context, SDBMPositiveExpr left,
SDBMPositiveExpr right)
: SDBMExprStorage(SDBMExprKind::Diff, context), lhs(left), rhs(right) {}
SDBMPositiveExpr lhs;
SDBMPositiveExpr rhs;
};
struct SDBMConstantExprStorage : public SDBMExprStorage {
SDBMConstantExprStorage(MLIRContext *context, int64_t value)
: SDBMExprStorage(SDBMExprKind::Constant, context), constant(value) {}
int64_t constant;
};
struct SDBMPositiveExprStorage : public SDBMExprStorage {
SDBMPositiveExprStorage(SDBMExprKind kind, MLIRContext *context, unsigned pos)
: SDBMExprStorage(kind, context), position(pos) {}
unsigned position;
};
struct SDBMNegExprStorage : public SDBMExprStorage {
SDBMNegExprStorage(SDBMPositiveExpr expr)
: SDBMExprStorage(SDBMExprKind::Neg, expr.getContext()), dim(expr) {}
SDBMPositiveExpr dim;
};
} // end namespace detail
} // end namespace mlir
#endif // MLIR_IR_SDBMEXPRDETAIL_H

View File

@ -1,6 +1,7 @@
add_mlir_unittest(MLIRIRTests
DialectTest.cpp
OperationSupportTest.cpp
SDBMTest.cpp
)
target_link_libraries(MLIRIRTests
PRIVATE

View File

@ -0,0 +1,182 @@
//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SDBMExpr.h"
#include "gtest/gtest.h"
using namespace mlir;
static MLIRContext *ctx() {
static thread_local MLIRContext context;
return &context;
}
namespace {
TEST(SDBMExpr, Constant) {
// We can create consants and query them.
auto expr = SDBMConstantExpr::get(ctx(), 42);
EXPECT_EQ(expr.getValue(), 42);
// Two separately created constants with identical values are trivially equal.
auto expr2 = SDBMConstantExpr::get(ctx(), 42);
EXPECT_EQ(expr, expr2);
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMConstantExpr>());
}
TEST(SDBMExpr, Dim) {
// We can create dimension expressions and query them.
auto expr = SDBMDimExpr::get(ctx(), 0);
EXPECT_EQ(expr.getPosition(), 0);
// Two separately created dimensions with the same position are trivially
// equal.
auto expr2 = SDBMDimExpr::get(ctx(), 0);
EXPECT_EQ(expr, expr2);
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMDimExpr>());
EXPECT_TRUE(generic.isa<SDBMInputExpr>());
EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
// Dimensions are not Symbols.
auto symbol = SDBMSymbolExpr::get(ctx(), 0);
EXPECT_NE(expr, symbol);
EXPECT_FALSE(expr.isa<SDBMSymbolExpr>());
}
TEST(SDBMExpr, Symbol) {
// We can create symbol expressions and query them.
auto expr = SDBMSymbolExpr::get(ctx(), 0);
EXPECT_EQ(expr.getPosition(), 0);
// Two separately created symbols with the same position are trivially equal.
auto expr2 = SDBMSymbolExpr::get(ctx(), 0);
EXPECT_EQ(expr, expr2);
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
EXPECT_TRUE(generic.isa<SDBMInputExpr>());
EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
// Dimensions are not Symbols.
auto symbol = SDBMDimExpr::get(ctx(), 0);
EXPECT_NE(expr, symbol);
EXPECT_FALSE(expr.isa<SDBMDimExpr>());
}
TEST(SDBMExpr, Stripe) {
auto cst2 = SDBMConstantExpr::get(ctx(), 2);
auto cst0 = SDBMConstantExpr::get(ctx(), 0);
auto var = SDBMSymbolExpr::get(ctx(), 0);
// We can create stripe expressions and query them.
auto expr = SDBMStripeExpr::get(var, cst2);
EXPECT_EQ(expr.getVar(), var);
EXPECT_EQ(expr.getStripeFactor(), cst2);
// Two separately created stripe expressions with the same LHS and RHS are
// trivially equal.
auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(ctx(), 0), cst2);
EXPECT_EQ(expr, expr2);
// Stripes can be nested.
SDBMStripeExpr::get(expr, SDBMConstantExpr::get(ctx(), 4));
// Non-positive stripe factors are not allowed.
EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Neg) {
auto cst2 = SDBMConstantExpr::get(ctx(), 2);
auto var = SDBMSymbolExpr::get(ctx(), 0);
auto stripe = SDBMStripeExpr::get(var, cst2);
// We can create negation expressions and query them.
auto expr = SDBMNegExpr::get(var);
EXPECT_EQ(expr.getVar(), var);
auto expr2 = SDBMNegExpr::get(stripe);
EXPECT_EQ(expr2.getVar(), stripe);
// Neg expressions are trivially comparable.
EXPECT_EQ(expr, SDBMNegExpr::get(var));
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMNegExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Sum) {
auto cst2 = SDBMConstantExpr::get(ctx(), 2);
auto var = SDBMSymbolExpr::get(ctx(), 0);
auto stripe = SDBMStripeExpr::get(var, cst2);
// We can create sum expressions and query them.
auto expr = SDBMSumExpr::get(var, cst2);
EXPECT_EQ(expr.getLHS(), var);
EXPECT_EQ(expr.getRHS(), cst2);
auto expr2 = SDBMSumExpr::get(stripe, cst2);
EXPECT_EQ(expr2.getLHS(), stripe);
EXPECT_EQ(expr2.getRHS(), cst2);
// Sum expressions are trivially comparable.
EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2));
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMSumExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Diff) {
auto cst2 = SDBMConstantExpr::get(ctx(), 2);
auto var = SDBMSymbolExpr::get(ctx(), 0);
auto stripe = SDBMStripeExpr::get(var, cst2);
// We can create sum expressions and query them.
auto expr = SDBMDiffExpr::get(var, stripe);
EXPECT_EQ(expr.getLHS(), var);
EXPECT_EQ(expr.getRHS(), stripe);
auto expr2 = SDBMDiffExpr::get(stripe, var);
EXPECT_EQ(expr2.getLHS(), stripe);
EXPECT_EQ(expr2.getRHS(), var);
// Sum expressions are trivially comparable.
EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe));
// Hierarchy is okay.
auto generic = static_cast<SDBMExpr>(expr);
EXPECT_TRUE(generic.isa<SDBMDiffExpr>());
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
} // end namespace