mirror of
https://github.com/llvm/llvm-project.git
synced 2025-05-09 00:56:05 +00:00

This starts by moving `std.extract_element` to `tensor.extract` (this mirrors the naming of `vector.extract`). Curiously, `std.extract_element` supposedly works on vectors as well, and this patch removes that functionality. I would tend to do that in a separate patch, but I couldn't find any downstream users relying on this, and the fact that we have `vector.extract` made it seem safe enough to lump in here. This also sets up the `tensor` dialect as a dependency of the `std` dialect, as some ops that currently live in `std` depend on `tensor.extract` via their canonicalization patterns.
Part of RFC: https://llvm.discourse.group/t/rfc-split-the-tensor-dialect-from-std/2347/2
Differential Revision: https://reviews.llvm.org/D92991
61 lines
2.3 KiB
C++
61 lines
2.3 KiB
C++
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tensor;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtractOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies a `tensor.extract` op.
///
/// If the tensor operand has a ranked tensor type, the number of index
/// operands must equal that rank. A tensor operand that is not a
/// RankedTensorType is accepted without any index-count check.
static LogicalResult verify(ExtractOp op) {
  auto rankedType = op.tensor().getType().dyn_cast<RankedTensorType>();
  // Without a static rank there is nothing to compare the index count with.
  if (!rankedType)
    return success();

  const auto numIndices = static_cast<int64_t>(op.indices().size());
  if (numIndices == rankedType.getRank())
    return success();

  // NOTE(review): this diagnostic still names the legacy
  // `std.extract_element` op. Diagnostics tests may match on this exact
  // text, so confirm before renaming it.
  return op.emitOpError("incorrect number of indices for extract_element");
}
|
|
|
|
/// Attempts to fold `tensor.extract` when its tensor operand is a constant.
///
/// `operands` holds one constant attribute (or null) per operand: the tensor
/// comes first, followed by one entry per index.
/// - A SplatElementsAttr tensor folds to its splat value whatever the indices
///   are.
/// - Otherwise, every index must be a constant IntegerAttr, and the tensor
///   must be an ElementsAttr whose isValidIndex accepts those indices.
/// If any of these conditions fails, no fold is produced.
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  Attribute source = operands.front();
  // Nothing to fold unless the tensor operand is a known constant.
  if (!source)
    return {};

  // Every element of a splat is identical, so the indices do not matter.
  if (auto splat = source.dyn_cast<SplatElementsAttr>())
    return splat.getSplatValue();

  // Gather the indices. Each one must be a constant integer; a null entry
  // (non-constant operand) or a non-integer attribute aborts the fold.
  SmallVector<uint64_t, 8> coords;
  for (Attribute indexOperand : llvm::drop_begin(operands, 1)) {
    auto indexAttr = indexOperand.dyn_cast_or_null<IntegerAttr>();
    if (!indexAttr)
      return {};
    coords.push_back(indexAttr.getInt());
  }

  // Read the element only when the constant is an ElementsAttr that accepts
  // these indices.
  auto elements = source.dyn_cast<ElementsAttr>();
  if (!elements || !elements.isValidIndex(coords))
    return {};
  return elements.getValue(coords);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
|