2024-05-14 12:54:48 -07:00

457 lines
18 KiB
C++

/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_MOSAIC_DIALECT_TPU_LAYOUT_H_
#define JAXLIB_MOSAIC_DIALECT_TPU_LAYOUT_H_
#include <array>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <numeric>
#include <optional>
#include <ostream>
#include <tuple>
#include "absl/log/check.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir::tpu {
// TODO(apaszke): Optimize this to encode the optional in the value itself
// and use a narrower type.
// A single layout offset.
// An offset is nullopt when the value is replicated along sublanes or lanes.
using LayoutOffset = std::optional<int64_t>;
// A pair of offsets, one per tiled dimension.
// NOTE(review): presumably index 0 is the sublane offset and index 1 the lane
// offset (matching the "sublanes or lanes" wording above) -- confirm.
using LayoutOffsets = std::array<LayoutOffset, 2>;
// Directions within a vector register that can be passed to
// VRegDataBounds::maskVariesAlong.
enum class Direction { kSublanes, kLanes, kSubelements };
// Abstract description of which entries of a single vector register hold
// useful data, together with factories for the corresponding masks.
struct VRegDataBounds {
  // TODO(tlongeri): Should get{Vector, Sublane}Mask take a Location?
  virtual ~VRegDataBounds() = default;

  // Reports whether the mask of useful data varies along `direction`.
  //
  // NOTE(review): this was previously documented as "determines whether all
  // indices along a direction contain useful data", which reads like the
  // negation of what the name and isComplete() suggest; confirm against the
  // implementations.
  virtual bool maskVariesAlong(Direction direction,
                               std::array<int64_t, 2> target_shape) const = 0;

  // Returns true iff the mask does not vary along sublanes, lanes, or
  // subelements. Directions are queried in that order and the query stops at
  // the first direction along which the mask varies.
  bool isComplete(const std::array<int64_t, 2> target_shape) const {
    const bool varies_somewhere =
        maskVariesAlong(Direction::kSublanes, target_shape) ||
        maskVariesAlong(Direction::kLanes, target_shape) ||
        maskVariesAlong(Direction::kSubelements, target_shape);
    return !varies_somewhere;
  }

  // Constructs a vector mask value that is true iff the entry contains useful
  // data.
  //
  // The returned value can be an int32 bitmask too, when the target does not
  // have sufficiently expressive vector masks.
  //
  // Args:
  //   builder: The builder passed to the implementation for creating the mask.
  //   loc: The location passed to the implementation.
  //   generation: The target TPU generation.
  //   target_shape: The vreg shape of the target.
  virtual FailureOr<TypedValue<VectorType>> getVectorMask(
      OpBuilder &builder, Location loc, int generation,
      std::array<int64_t, 2> target_shape) const = 0;

  // Constructs a DenseBoolArrayAttr containing a sublane mask for the vreg.
  //
  // The sublane mask should never have true for sublanes that do not contain
  // useful data, but having an unmasked sublane doesn't imply that all bits
  // in that sublane are used to represent data (relevant for packed layouts).
  virtual DenseBoolArrayAttr getSublaneMask(
      MLIRContext *ctxt, std::array<int64_t, 2> target_shape) const = 0;
};
// Represents a rectangular region of data within a vector register.
//
// This class is very limited in its power and should only be used for 32-bit
// values with native tiling.
//
// Attributes:
//   starts, ends: The bounds of the rectangular data region, one entry per
//     vreg dimension.
//     NOTE(review): presumably half-open [start, end) ranges with index 0
//     addressing sublanes and index 1 addressing lanes -- confirm against the
//     implementation.
// TODO(tlongeri): Can this be removed in favor of the more general
// TiledRectangularVregBounds?
class RectangularVregBounds : public VRegDataBounds {
 public:
  // Stores copies of the given bounds; no validation is performed here.
  RectangularVregBounds(const std::array<int64_t, 2> starts,
                        const std::array<int64_t, 2> ends)
      : starts_{starts}, ends_{ends} {}

  // See base class.
  bool maskVariesAlong(Direction direction,
                       std::array<int64_t, 2> target_shape) const override;

  // See base class.
  FailureOr<TypedValue<VectorType>> getVectorMask(
      OpBuilder &builder, Location loc, int generation,
      std::array<int64_t, 2> target_shape) const override;

  // See base class.
  DenseBoolArrayAttr getSublaneMask(
      MLIRContext *mlir_ctxt,
      std::array<int64_t, 2> target_shape) const override;

 private:
  std::array<int64_t, 2> starts_;
  std::array<int64_t, 2> ends_;
};
// VectorLayout describes a mapping of arbitrarily sized values into vregs.
//
// First, let us consider the simplest case, when implicit_dim is None, bitwidth
// is 32, and tiling matches the vreg shape. Then, the two last dimensions of a
// vector are tiled over sublanes and lanes respectively. If a value is too
// large to fit within a single vreg, then it continues in another vector
// register. For example purposes, we assume that vregs have 4 sublanes and 5
// lanes from now on. A matrix with elements:
//
// a b c d e
// f g h i j
// k l m n o
// p q r s t
//
// laid out with offsets (1, 2) will use four vregs as follows:
//
// vreg 1 vreg 2
// . . . . . . . . . .
// . . a b c d e . . .
// . . f g h i j . . .
// . . k l m n o . . .
//
// vreg 3 vreg 4
// . . p q r s t . . .
// . . . . . . . . . .
// . . . . . . . . . .
// . . . . . . . . . .
//
// The dot character indicates padding. Nothing should be assumed about the
// value of those entries.
//
// If a value with this layout has rank >2, the leading dimensions will be
// unrolled over vregs. That is, the total number of vregs used to represent
// a value is equal to the product of all leading dimension sizes, and the
// number of vregs necessary to lay out the last two dimensions (as in the
// example).
//
// ---
//
// The implicit_dim attribute makes it possible to tile only the last dimension
// of a value, by implicitly inserting a singleton dimension that is tiled over
// sublanes (when implicit_dim is kMinor) or lanes (when implicit_dim is
// kSecondMinor).
//
// When the value has only one dimension, implicit_dim must be specified.
//
// ---
//
// The tiling attribute makes it possible to subdivide a single vector register
// into multiple subtiles that traverse the last dimension of a value. For
// example, consider vregs of shape (4, 5) and an array:
//
// a b c d e f g h i j
// k l m n o p q r s t
//
// If we used a tiling of (4, 5), we would need two vregs to store this value,
// with the lower half of every register containing padding. But, if we use a
// tiling of (2, 5), both tiles fit into a single vreg:
//
// vreg 0
// a b c d e | tile 0
// k l m n o |
// f g h i j | tile 1
// p q r s t |
//
// Tiling is especially useful for compact storage of 1D values. Without it,
// we could use at most one sublane of every vector register. But, with a tiling
// of (1, 128) and implicit_dim being kSecondMinor, we can use all entries in a
// register to store long vectors.
//
// ---
//
// Finally, when the element bitwidth becomes smaller than 32, we use a two
// level tiling scheme, where elements of consecutive rows are packed into
// subelements. In TPU documentation this is often called a compressed layout.
// Note that this puts restrictions on the tile sizes, as they cannot have fewer
// rows than the packing factor (32 / bitwidth).
//
// Attributes:
// bitwidth: The bitwidth of the stored values.
// offsets: The coordinates of the first valid element. If an offset is
// replicated (nullopt), then any offset is valid as the value does not vary
// across sublanes or lanes respectively.
// tiling: The tiling used to lay out values (see the XLA docs). For values of
// bitwidth < 32, an implicit (32 / bitwidth, 1) tiling is appended to the
// one specified as an attribute.
// implicit_dim: If specified, the value has an implicit dim inserted in
// either minormost or second minormost position.
//
// Note: There is a special case when VectorLayout is used for an mlir::Value
// of i1 type. In this case, we use it to represent a vmask, which has a smaller
// bitwidth than a vreg. For these types, the packing() is accurate but the
// bitwidth() is a lie, and the i1 value is replicated for every bit.
// For example, if the vmask is 8 x 128 x 4 bits and packing() == 2, each 4-bit
// register contains two logical bool values which are represented as either b11
// or b00. Its usage is currently limited to MLIR arith.cmp and arith.select ops
// but we might want to split out a separate class if it gets used more widely.
class VectorLayout {
 public:
  // Selects which singleton dimension (if any) is implicitly inserted into the
  // shape of a value before tiling.
  //
  // The numeric values are load-bearing: insertImplicit() and eraseImplicit()
  // use them as distances from the end of the shape vector.
  enum class ImplicitDim {
    kNone = 0,  // To make if (implicit_dim) work.
    // Also want to do dims[dims.size() - xla::to_underlying(implicit_dim)]
    kMinor = 1,
    kSecondMinor = 2,
  };

  // Constructs a layout and CHECK-fails when a target-independent invariant is
  // violated: the bitwidth must be a power of two no larger than 32, and every
  // non-replicated offset must lie in [0, tiling) for its dimension.
  // Target-dependent conditions are checked separately by isValid().
  VectorLayout(const int8_t bitwidth, const LayoutOffsets offsets,
               const std::array<int64_t, 2> tiling,
               const ImplicitDim implicit_dim = ImplicitDim::kNone)
      : offsets_(offsets),
        tiling_(tiling),
        bitwidth_(bitwidth),
        implicit_dim_(implicit_dim) {
    // TODO(b/275751535): Allow more bitwidths.
    CHECK(llvm::has_single_bit<unsigned>(bitwidth_) && bitwidth_ <= 32);
    // Offsets should not exceed the tile size. The data always starts within
    // the first tile of a vreg.
    for (auto [o, t] : llvm::zip(offsets_, tiling_)) {
      CHECK(!o || (0 <= *o && *o < t));
    }
  }

  int8_t bitwidth() const { return bitwidth_; }
  const LayoutOffsets &offsets() const { return offsets_; }
  const std::array<int64_t, 2> &tiling() const { return tiling_; }
  ImplicitDim implicit_dim() const { return implicit_dim_; }
  // The packing factor (32 / bitwidth).
  int packing() const { return 32 / bitwidth_; }
  // The number of minormost dimensions tiled by this layout.
  int layout_rank() const { return 1 + (implicit_dim_ == ImplicitDim::kNone); }

  bool operator==(const VectorLayout &other) const;
  bool operator!=(const VectorLayout &other) const {
    return !(*this == other);
  }

  // How many tiles fit in each vector register.
  //
  // CHECK-fails if the tile is empty or if the vreg capacity is not an exact
  // multiple of the tile size.
  int64_t tilesPerVreg(const std::array<int64_t, 2> target_shape) const {
    const int64_t tile_elems = tiling_[0] * tiling_[1];
    const int64_t vreg_capacity = packing() * target_shape[0] * target_shape[1];
    // Fail with a diagnostic instead of dividing by zero.
    CHECK(tile_elems != 0);
    // Access quot/rem by name: the member order of the std::div result type is
    // unspecified, so structured bindings would not be portable.
    const auto division = std::div(vreg_capacity, tile_elems);
    CHECK_EQ(division.rem, 0);
    return division.quot;
  }

  // Returns target_shape[0] divided by the number of tiles per vreg.
  //
  // CHECK-fails if the division is not exact.
  int64_t sublanesPerTile(const std::array<int64_t, 2> target_shape) const {
    // See tilesPerVreg() for why quot/rem are accessed by name.
    const auto division =
        std::div(target_shape[0], tilesPerVreg(target_shape));
    CHECK_EQ(division.rem, 0);
    return division.quot;
  }

  // Returns the size of a window contained in a single vreg.
  //
  // We never reuse the same vector register to store data of multiple rows,
  // so only the minormost dimension can increase.
  std::array<int64_t, 2> vregSlice(std::array<int64_t, 2> target_shape) const {
    return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]};
  }

  // Inserts `value` into `vec` at the position of the implicit dimension.
  //
  // No-op for kNone. For kMinor the value is appended at the end; for
  // kSecondMinor it is inserted just before the last element. CHECK-fails if
  // `vec` has fewer than layout_rank() elements.
  void insertImplicit(SmallVector<int64_t> &vec, int64_t value) const {
    // Compare as signed values to avoid a mixed-sign comparison.
    CHECK_GE(static_cast<int64_t>(vec.size()), layout_rank());
    switch (implicit_dim_) {
      case ImplicitDim::kNone:
        break;
      case ImplicitDim::kMinor:
      case ImplicitDim::kSecondMinor:
        vec.insert(vec.end() - (static_cast<int64_t>(implicit_dim_) - 1),
                   value);
        break;
    }
  }

  // Removes the entry of `vec` at the position of the implicit dimension.
  //
  // No-op for kNone. For kMinor the last element is erased; for kSecondMinor
  // the second-to-last element is erased. CHECK-fails if `vec` has fewer than
  // two elements.
  void eraseImplicit(SmallVector<int64_t> &vec) const {
    // Compare as signed values to avoid a mixed-sign comparison.
    CHECK_GE(static_cast<int64_t>(vec.size()), 2);
    switch (implicit_dim_) {
      case ImplicitDim::kNone:
        break;
      case ImplicitDim::kMinor:
      case ImplicitDim::kSecondMinor:
        vec.erase(vec.end() - static_cast<int64_t>(implicit_dim_));
        break;
    }
  }

  // Defined out of line.
  // NOTE(review): presumably `shape` with the implicit dimension made explicit
  // (cf. insertImplicit) -- confirm against the implementation.
  SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;
  // Defined out of line.
  // NOTE(review): presumably the counterpart of tileArrayShape that keeps the
  // implicit dimension -- confirm against the implementation.
  SmallVector<int64_t> tileArrayImplicitShape(
      ArrayRef<int64_t> shape, std::array<int64_t, 2> target_shape) const;

  // Returns the shape of ndarray of vregs needed to represent a value.
  //
  // All but the last two dimensions are unrolled over vregs. In the last two
  // dims we need as many vregs as indicated by dividing the point at which
  // the value ends (given by the start offset plus the dim size) divided by
  // the respective vreg capacity in that dim (and a ceiling if non-integral).
  // If a value is replicated, then any offset is valid and we pick 0 to
  // minimize the number of vregs.
  //
  // Args:
  //   shape: The shape of the full vector this layout applies to.
  //   target_shape: The vreg shape of the target.
  SmallVector<int64_t> tileArrayShape(
      ArrayRef<int64_t> shape, std::array<int64_t, 2> target_shape) const;

  // Returns the bounds of the given tile that hold useful data.
  //
  // Arguments:
  //   mlir_ctxt: The MLIR context passed to the implementation.
  //   full_shape: The shape of the full vector this layout applies to.
  //   idxs: The indices into an array of tiles representing the full vector
  //     (see tileArrayShape for bounds) selecting the tile for which the
  //     bounds are queried.
  //   target_shape: The vreg shape of the target.
  //   allow_replicated: If false, no offset is allowed to be replicated. If
  //     true, offsets are allowed to be replicated, but the bounds will span
  //     the full dimension of the tile (i.e. potentially multiple repeats of
  //     the actual data). The array overload takes one flag per offset.
  //
  // Returns:
  //   The bounds of useful data within the tile selected by idxs.
  std::unique_ptr<VRegDataBounds> tileDataBounds(
      MLIRContext *mlir_ctxt, ArrayRef<int64_t> full_shape,
      ArrayRef<int64_t> idxs, std::array<int64_t, 2> target_shape,
      std::array<bool, 2> allow_replicated) const;
  // Convenience overload that applies the same allow_replicated flag to both
  // offsets.
  std::unique_ptr<VRegDataBounds> tileDataBounds(
      MLIRContext *mlir_ctxt, ArrayRef<int64_t> full_shape,
      ArrayRef<int64_t> idxs, std::array<int64_t, 2> target_shape,
      bool allow_replicated = false) const {
    return tileDataBounds(mlir_ctxt, full_shape, idxs, target_shape,
                          {allow_replicated, allow_replicated});
  }

  // True if every vector register has a layout without jumps.
  //
  // By without jumps we mean that traversing vregs over (sub)lanes always leads
  // to a contiguous traversal of the (second) minormost dimension of data. This
  // is only true for 32-bit types, since narrower types use two level tiling.
  bool hasNaturalTopology(const std::array<int64_t, 2> target_shape) const {
    return bitwidth_ == 32 && llvm::equal(tiling_, target_shape) &&
           implicit_dim_ == ImplicitDim::kNone;
  }

  // True if every vector register has a natural "packed" topology.
  //
  // This is equivalent to hasNaturalTopology for 32-bit types, but
  // generalizes it to narrower values with packed layouts too.
  bool hasNativeTiling(std::array<int64_t, 2> target_shape) const;

  // Returns true if the other layout is a special case of this one.
  //
  // In here, other is considered "a special case" when the set of vector
  // register entries that represent a value in that layout is also the set of
  // entries in which this stores the value. This is of course true for layouts
  // that are equivalent, but it does not need to hold both ways. For example,
  // a layout that implies the value does not change along an axis of the vector
  // register is more general than the layout that picks a fixed starting point
  // for the value and does not encode that assumption.
  //
  // The generalization relation is a non-strict partial order. You can think of
  // it as a partial <= on vector layouts, but we don't overload operators since
  // there's no clear way to decide where the bottom and top should be.
  //
  // Args:
  //   other: The layout compared against this.
  //   shape: An optional shape of the vector to which both layouts apply.
  //     If shape.data() == nullptr, then return whether it generalizes across
  //     all shapes.
  //     The generalization relation is larger than usual for some shapes. That
  //     is, if generalizes(other) then also generalizes(other, shape) for any
  //     shape, but that implication does not hold the other way around for some
  //     shapes.
  //   target_shape: The vreg shape of the target.
  bool generalizes(const VectorLayout &other, ArrayRef<int64_t> shape,
                   std::array<int64_t, 2> target_shape) const;

  // Returns true if the two layouts are equivalent.
  //
  // That is, when all potential vector entries where the value can be stored
  // (there might be multiple choices for some layouts!) are equal in both
  // this and other. Implemented as mutual generalization.
  //
  // Args:
  //   other: The layout compared against this.
  //   shape: An optional shape of the vector to which both layouts apply. More
  //     layouts are considered equivalent when the shape is specified. Also see
  //     the comment on the generalizes method.
  //   target_shape: The vreg shape of the target.
  bool equivalentTo(const VectorLayout &other, const ArrayRef<int64_t> shape,
                    const std::array<int64_t, 2> target_shape) const {
    return generalizes(other, shape, target_shape) &&
           other.generalizes(*this, shape, target_shape);
  }

  // Prints this layout to `os`. Defined out of line.
  template <typename Stream>
  void print(Stream &os) const;

  // Defined out of line.
  // NOTE(review): presumably combines `l` and `r` into a single layout for a
  // value of the given shape and returns nullopt when that is impossible --
  // confirm against the implementation.
  static std::optional<VectorLayout> join(const VectorLayout &l,
                                          const VectorLayout &r,
                                          ArrayRef<int64_t> shape);
  // Defined out of line.
  // NOTE(review): presumably parses a layout from the front of `*data`,
  // advancing it, and returns nullopt on failure -- confirm against the
  // implementation.
  static std::optional<VectorLayout> parse(StringRef *data);

  // Check conditions that depend on the target shape. Invariants that are
  // independent of it are checked in the constructor.
  //
  // Returns false (rather than dividing by zero) for an empty tiling or a
  // target with zero lanes.
  bool isValid(const std::array<int64_t, 2> target_shape) const {
    // Tiling should neatly divide the target shape, so that every vector
    // register ends up having the same structure.
    // Also, every tile should occupy a fixed number of sublanes.
    const int64_t tile_elems = tiling_[0] * tiling_[1];
    const int64_t row_capacity = packing() * target_shape[1];
    if (tile_elems == 0 || row_capacity == 0) {
      return false;
    }
    // Access quot/rem by name: the member order of the std::div result type is
    // unspecified, so structured bindings would not be portable.
    const auto division = std::div(tile_elems, row_capacity);
    // With a non-empty tile, a zero remainder implies a non-zero quotient, so
    // the modulo below cannot divide by zero.
    return division.rem == 0 && target_shape[0] % division.quot == 0;
  }

 private:
  // Defined out of line.
  // NOTE(review): presumably (offsets[0], offsets[1], tiling[0], tiling[1],
  // bitwidth, implicit_dim), used for comparison and hashing -- confirm.
  std::tuple<std::optional<int64_t>, std::optional<int64_t>, int64_t, int64_t,
             int8_t, ImplicitDim>
  as_tuple() const;

  friend llvm::hash_code hash_value(const VectorLayout &layout);

  LayoutOffsets offsets_;
  std::array<int64_t, 2> tiling_;
  int8_t bitwidth_;
  ImplicitDim implicit_dim_;
};
// A vector layout that may be absent (std::nullopt).
using Layout = std::optional<VectorLayout>;
// Defined out of line.
// NOTE(review): presumably the empty (nullopt) layout -- confirm against the
// definition.
extern const Layout kNoLayout;
// Printing, hashing and diagnostic support. All of these are only declared
// here; their definitions are not part of this header.
std::ostream &operator<<(std::ostream &os, const Layout &v);
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Layout &v);
llvm::hash_code hash_value(const VectorLayout &layout);
mlir::Diagnostic &operator<<(mlir::Diagnostic &diag, const Layout &v);
std::ostream &operator<<(std::ostream &os, VectorLayout::ImplicitDim dim);
// Parses a Layout using the given MLIR assembly parser.
// NOTE(review): the outer optional presumably signals a parse failure, while
// the inner one distinguishes "no layout" from a concrete VectorLayout --
// confirm against the implementation.
std::optional<Layout> parseLayout(mlir::AsmParser &parser);
} // namespace mlir::tpu
#endif // JAXLIB_MOSAIC_DIALECT_TPU_LAYOUT_H_