diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index eb046bc74229..8f07e43f847a 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -86,62 +86,100 @@ public: } }; -/// Printer hook for custom directive in assemblyFormat. +/// Printer hooks for custom directive in assemblyFormat. /// /// custom($values, $integers) /// custom($values, $integers, type($values)) /// -/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS -/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value -/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes` -/// is non-empty, it is expected to contain as many elements as `values` -/// indicating their types. This allows idiomatic printing of mixed value and -/// integer attributes in a list. E.g. -/// `[%arg0 : index, 7, 42, %arg42 : i32]`. +/// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS type +/// `I64ArrayAttr`. Print a list where each element is either: +/// 1. the static integer value in `integers`, if it's not `kDynamic` or, +/// 2. the next value in `values`, otherwise. +/// +/// If `valueTypes` is provided, the corresponding type of each dynamic value is +/// printed. Otherwise, the type is not printed. Each type must match the type +/// of the corresponding value in `values`. `valueTypes` is redundant for +/// printing as we can retrieve the types from the actual `values`. However, +/// `valueTypes` is needed for parsing and we must keep the API symmetric for +/// parsing and printing. The type for integer elements is `i64` by default and +/// never printed. +/// +/// Integer indices can also be scalable in the context of scalable vectors, +/// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in +/// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's +/// a scalable index. If `scalableFlags` is empty then assume that all indices +/// are non-scalable. +/// +/// Examples: +/// +/// * Input: `integers = [kDynamic, 7, 42, kDynamic]`, +/// `values = [%arg0, %arg42]` and +/// `valueTypes = [index, index]` +/// prints: +/// `[%arg0 : index, 7, 42, %arg42 : i32]` +/// +/// * Input: `integers = [kDynamic, 7, 42, kDynamic]`, +/// `values = [%arg0, %arg42]` and +/// `valueTypes = []` +/// prints: +/// `[%arg0, 7, 42, %arg42]` +/// +/// * Input: `integers = [2, 4, 8]`, +/// `values = []` and +/// `scalableFlags = [false, true, false]` +/// prints: +/// `[2, [4], 8]` /// -/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. -/// This notation is similar to how scalable dims are marked when defining -/// Vectors. For each value in `integers`, the corresponding `bool` in -/// `scalables` encodes whether it's a scalable index. If `scalableVals` is -/// empty then assume that all indices are non-scalable. void printDynamicIndexList( OpAsmPrinter &printer, Operation *op, OperandRange values, - ArrayRef integers, ArrayRef scalables, + ArrayRef integers, ArrayRef scalableFlags, TypeRange valueTypes = TypeRange(), AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); inline void printDynamicIndexList( OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, TypeRange valueTypes = TypeRange(), AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { - return printDynamicIndexList(printer, op, values, integers, {}, valueTypes, - delimiter); + return printDynamicIndexList(printer, op, values, integers, + /*scalableFlags=*/{}, valueTypes, delimiter); } -/// Parser hook for custom directive in assemblyFormat. +/// Parser hooks for custom directive in assemblyFormat. /// /// custom($values, $integers) /// custom($values, $integers, type($values)) /// /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS -/// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer -/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where -/// `kDynamic` encodes the position of SSA values. Add the parsed SSA values -/// to `values` in-order. If `valueTypes` is non-null, fill it with types -/// corresponding to values; otherwise the caller must handle the types. +/// type `I64ArrayAttr`. Parse a mixed list where each element is either a +/// static integer or an SSA value. Fill `integers` with the integer ArrayAttr, +/// where `kDynamic` encodes the position of SSA values. Add the parsed SSA +/// values to `values` in-order. /// -/// E.g. after parsing "[%arg0 : index, 7, 42, %arg42 : i32]": -/// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42, -/// `kDynamic`]" -/// 2. `ssa` is filled with "[%arg0, %arg1]". +/// If `valueTypes` is provided, fill it with the types corresponding to each +/// value in `values`. Otherwise, the caller must handle the types and parsing +/// will fail if the type of the value is found (e.g., `[%arg0 : index, 3, %arg1 +/// : index]`). +/// +/// Integer indices can also be scalable in the context of scalable vectors, +/// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in +/// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's +/// a scalable index. +/// +/// Examples: +/// +/// * After parsing "[%arg0 : index, 7, 42, %arg42 : i32]": +/// 1. `result` is filled with `[kDynamic, 7, 42, kDynamic]` +/// 2. `values` is filled with "[%arg0, %arg1]". +/// 3. `scalableFlags` is filled with `[false, true, false]`. +/// +/// * After parsing `[2, [4], 8]`: +/// 1. `result` is filled with `[2, 4, 8]` +/// 2. `values` is empty. +/// 3. `scalableFlags` is filled with `[false, true, false]`. /// -/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. -/// This notation is similar to how scalable dims are marked when defining -/// Vectors. For each value in `integers`, the corresponding `bool` in -/// `scalableVals` encodes whether it's a scalable index. ParseResult parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, + DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl *valueTypes = nullptr, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); inline ParseResult parseDynamicIndexList( @@ -149,8 +187,8 @@ inline ParseResult parseDynamicIndexList( SmallVectorImpl &values, DenseI64ArrayAttr &integers, SmallVectorImpl *valueTypes = nullptr, AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { - DenseBoolArrayAttr scalableVals = {}; - return parseDynamicIndexList(parser, values, integers, scalableVals, + DenseBoolArrayAttr scalableFlags; + return parseDynamicIndexList(parser, values, integers, scalableFlags, valueTypes, delimiter); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp index ca33636336bf..57b5cce7bb13 100644 --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -113,7 +113,8 @@ static char getRightDelimiter(AsmParser::Delimiter delimiter) { void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, - ArrayRef scalables, TypeRange valueTypes, + ArrayRef scalableFlags, + TypeRange valueTypes, AsmParser::Delimiter delimiter) { char leftDelimiter = getLeftDelimiter(delimiter); char rightDelimiter = getRightDelimiter(delimiter); @@ -126,7 +127,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, unsigned dynamicValIdx = 0; unsigned scalableIndexIdx = 0; llvm::interleaveComma(integers, printer, [&](int64_t integer) { - if (!scalables.empty() && scalables[scalableIndexIdx]) + if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) printer << "["; if (ShapedType::isDynamic(integer)) { printer << values[dynamicValIdx]; @@ -136,7 +137,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, } else { printer << integer; } - if (!scalables.empty() && scalables[scalableIndexIdx]) + if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) printer << "]"; scalableIndexIdx++; @@ -148,7 +149,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, + DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl *valueTypes, AsmParser::Delimiter delimiter) { SmallVector integerVals; @@ -183,7 +184,7 @@ ParseResult mlir::parseDynamicIndexList( return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); - scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); + scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); return success(); }