Sean Silva cab8dda90f [mlir] Start splitting the tensor dialect out of std.
This starts by moving `std.extract_element` to `tensor.extract` (this
mirrors the naming of `vector.extract`).

Curiously, `std.extract_element` supposedly works on vectors as well,
and this patch removes that functionality. I would tend to do that in a
separate patch, but I couldn't find any downstream users relying on
this, and the fact that we have `vector.extract` made it seem safe
enough to lump in here.

This also sets up the `tensor` dialect as a dependency of the `std`
dialect, as some ops that currently live in `std` depend on
`tensor.extract` via their canonicalization patterns.

Part of RFC: https://llvm.discourse.group/t/rfc-split-the-tensor-dialect-from-std/2347/2

Differential Revision: https://reviews.llvm.org/D92991
2020-12-11 13:50:55 -08:00

61 lines
2.3 KiB
C++

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::tensor;
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
/// Verifies a `tensor.extract` op: when the source tensor has a ranked type,
/// the number of index operands must equal that rank. Unranked tensors are
/// accepted with any number of indices.
static LogicalResult verify(ExtractOp op) {
  auto rankedType = op.tensor().getType().dyn_cast<RankedTensorType>();
  if (!rankedType)
    return success();

  // One index operand is expected per dimension of the ranked tensor.
  int64_t numIndices = static_cast<int64_t>(op.indices().size());
  if (rankedType.getRank() == numIndices)
    return success();

  // NOTE(review): the message still uses the op's former name
  // (`extract_element`); kept as-is since diagnostic tests may match on it.
  return op.emitOpError("incorrect number of indices for extract_element");
}
/// Folds `tensor.extract` of a constant tensor into the extracted constant.
///
/// `operands.front()` holds the constant value of the tensor operand, or null
/// when it is not a known constant. The remaining entries hold the index
/// operands. Returns a null result (no fold) when the tensor is not constant,
/// when an index is not a constant integer, or when the indices are not valid
/// for the elements attribute.
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  Attribute tensorAttr = operands.front();
  if (!tensorAttr)
    return {};

  // All elements of a splat are identical, so the indices do not matter.
  if (auto splat = tensorAttr.dyn_cast<SplatElementsAttr>())
    return splat.getSplatValue();

  // Gather the indices; bail out unless every one is a constant integer.
  SmallVector<uint64_t, 8> position;
  for (Attribute idxAttr : llvm::drop_begin(operands, 1)) {
    auto constIdx = idxAttr.dyn_cast_or_null<IntegerAttr>();
    if (!constIdx)
      return {};
    position.push_back(constIdx.getInt());
  }

  // Only an elements attribute with an in-bounds position can be queried.
  auto elements = tensorAttr.dyn_cast<ElementsAttr>();
  if (!elements || !elements.isValidIndex(position))
    return {};
  return elements.getValue(position);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"