From d30c0221cf5aa36c079b7cc0d36fb89f7b32149b Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 16 Jun 2022 20:01:54 -0700 Subject: [PATCH] [mlir] Split MLProgram global load and store to Graph variants * Split ops into X_graph variants as discussed; * Remove tokens from non-Graph region variants and rely on side-effect modelling there while removing side-effect modelling from Graph variants and relying on explicit ordering there; * Make tokens required to be produced by Graph variants - but kept explicit token type specification given previous discussion on this potentially being configurable in future; This results in duplicating some code. I considered adding helper functions but decided against adding an abstraction there early given size of duplication and creating accidental coupling. Differential Revision: https://reviews.llvm.org/D127813 --- .../mlir/Dialect/MLProgram/IR/MLProgramOps.td | 106 ++++++++++++++++-- .../lib/Dialect/MLProgram/IR/MLProgramOps.cpp | 65 ++++++++++- mlir/test/Dialect/MLProgram/invalid.mlir | 14 +++ mlir/test/Dialect/MLProgram/ops.mlir | 8 +- 4 files changed, 174 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td index d820573b200a..69b1eab379b3 100644 --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td @@ -171,7 +171,8 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [ advanced cases. This op is side effecting and may not be valid to use in graph regions - without additional consideration to evaluation order constraints. + without additional consideration to evaluation order constraints. See + `global_load_graph` for op which allows for explicit ordering constraints. Example: @@ -181,16 +182,14 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [ }]; let arguments = (ins - Arg:$global, - Variadic:$consumeTokens + Arg:$global ); let results = (outs - AnyType:$result, - Optional:$produceToken + AnyType:$result ); let assemblyFormat = [{ - $global `` custom($consumeTokens, type($produceToken)) `:` type($result) attr-dict + $global `:` type($result) attr-dict }]; let extraClassDeclaration = [{ @@ -238,6 +237,52 @@ def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [ }]; } +//===----------------------------------------------------------------------===// +// GlobalLoadGraphOp +//===----------------------------------------------------------------------===// + +def MLProgram_GlobalLoadGraphOp : MLProgram_Op<"global_load_graph", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Direct load of a mutable value from a global in Graph region"; + let description = [{ + Performs a non-atomic, non-volatile, non-synchronized load from a global + that may be mutable. + + It is fully expected that these constraints are not suitable for all + situations, and alternative ops should be defined and used for more advanced + cases. + + This op is side effecting and may not be valid to use in graph regions + without additional consideration to evaluation order constraints. + + Example: + + ```mlir + %0, %cstr = ml_program.global_load_graph @foobar + ordering (%token -> !ml_program.token) : tensor + ``` + }]; + + let arguments = (ins + Arg:$global, + Variadic:$consumeTokens + ); + let results = (outs + AnyType:$result, + MLProgram_TokenType:$produceToken + ); + + let assemblyFormat = [{ + $global `` custom($consumeTokens, type($produceToken)) `:` type($result) attr-dict + }]; + + let extraClassDeclaration = [{ + /// Gets the corresponding GlobalOp (or nullptr). + GlobalOp getGlobalOp(SymbolTableCollection &symbolTable); + }]; +} + //===----------------------------------------------------------------------===// // GlobalStoreOp //===----------------------------------------------------------------------===// @@ -254,23 +299,66 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [ all situations, and alternative ops should be defined and used for more advanced cases. + This op is side effecting and may not be valid to use in graph regions + without additional consideration to evaluation order constraints. See + `global_store_graph` for op which allows for explicit ordering constraints. + + Example: + + ```mlir + ml_program.global_store @foobar = %0 : tensor + ``` + }]; + + let arguments = (ins + Arg:$global, + AnyType:$value + ); + + let assemblyFormat = [{ + $global `=` $value `:` type($value) attr-dict + }]; + + let extraClassDeclaration = [{ + /// Gets the corresponding GlobalOp (or nullptr). + GlobalOp getGlobalOp(SymbolTableCollection &symbolTable); + }]; +} + +//===----------------------------------------------------------------------===// +// GlobalStoreGraphOp +//===----------------------------------------------------------------------===// + +def MLProgram_GlobalStoreGraphOp : MLProgram_Op<"global_store_graph", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Direct store of a value into a mutable global"; + let description = [{ + Performs a non-atomic, non-volatile, non-synchronized store to a mutable + global. + + It is fully expected that these constraints are not suitable for + all situations, and alternative ops should be defined and used for more + advanced cases. + This op is side effecting and may not be valid to use in graph regions without additional consideration to evaluation order constraints. Example: ```mlir - ml_program.global_store @foobar = %0 : tensor + %token = ml_program.global_store @foobar = %0 : tensor + ordering (%in_token -> !ml_program.token) : tensor ``` }]; let arguments = (ins - Arg:$global, + Arg:$global, AnyType:$value, Variadic:$consumeTokens ); let results = (outs - Optional:$produceToken + MLProgram_TokenType:$produceToken ); let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp index 9411ea17b816..2f1e4b93a6ac 100644 --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -18,12 +18,11 @@ using namespace mlir::ml_program; //===----------------------------------------------------------------------===// /// Parse and print an ordering clause for a variadic of consuming tokens -/// and an optional producing token. +/// and an producing token. /// /// Syntax: /// ordering(%0, %1 -> !ml_program.token) /// ordering(() -> !ml_program.token) -/// ordering(%0, %1) /// /// If both the consuming and producing token are not present on the op, then /// the clause prints nothing. @@ -46,10 +45,11 @@ static ParseResult parseTokenOrdering( return failure(); } - // Parse optional producer token. - if (succeeded(parser.parseOptionalArrow())) - if (failed(parser.parseType(produceTokenType))) - return failure(); + // Parse producer token. + if (failed(parser.parseArrow())) + return failure(); + if (failed(parser.parseType(produceTokenType))) + return failure(); if (failed(parser.parseRParen())) return failure(); @@ -220,6 +220,30 @@ GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// GlobalLoadGraphOp +//===----------------------------------------------------------------------===// + +GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { + return symbolTable.lookupNearestSymbolFrom( + getOperation()->getParentOp(), getGlobalAttr()); +} + +LogicalResult +GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + GlobalOp referrent = getGlobalOp(symbolTable); + if (!referrent) + return emitOpError() << "undefined global: " << getGlobal(); + + if (referrent.getType() != getResult().getType()) { + return emitOpError() << "cannot load from global typed " + << referrent.getType() << " as " + << getResult().getType(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // GlobalStoreOp //===----------------------------------------------------------------------===// @@ -249,6 +273,35 @@ GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// GlobalStoreGraphOp +//===----------------------------------------------------------------------===// + +GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { + return symbolTable.lookupNearestSymbolFrom( + getOperation()->getParentOp(), getGlobalAttr()); +} + +LogicalResult +GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + GlobalOp referrent = getGlobalOp(symbolTable); + if (!referrent) + return emitOpError() << "undefined global: " << getGlobal(); + + if (!referrent.getIsMutable()) { + return emitOpError() << "cannot store to an immutable global " + << getGlobal(); + } + + if (referrent.getType() != getValue().getType()) { + return emitOpError() << "cannot store to a global typed " + << referrent.getType() << " from " + << getValue().getType(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // SubgraphOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir index e193c6d58f11..79725a9bfbe1 100644 --- a/mlir/test/Dialect/MLProgram/invalid.mlir +++ b/mlir/test/Dialect/MLProgram/invalid.mlir @@ -96,3 +96,17 @@ ml_program.func @store_immutable(%arg0: i64) { ml_program.global_store @var = %arg0 : i64 ml_program.return } + +// ----- + +ml_program.global private mutable @global_mutable_undef : tensor +ml_program.subgraph @global_load_store_tokens() -> (tensor, !ml_program.token) { + %token1 = ml_program.token + %0, %token2 = ml_program.global_load_graph @global_mutable_undef + ordering(() -> !ml_program.token) : tensor + %token3 = ml_program.global_store_graph @global_mutable_undef = %0 + // expected-error @+1 {{expected '->'}} + ordering(%token1, %token2) : tensor + + ml_program.output %0, %token3 : tensor, !ml_program.token +} diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir index ca2d72afb6c6..9a48497a3efc 100644 --- a/mlir/test/Dialect/MLProgram/ops.mlir +++ b/mlir/test/Dialect/MLProgram/ops.mlir @@ -45,12 +45,12 @@ ml_program.func @global_load_store() { // CHECK-LABEL: @global_load_store_tokens ml_program.subgraph @global_load_store_tokens() -> (tensor, !ml_program.token) { %token1 = ml_program.token - %0, %token2 = ml_program.global_load @global_mutable_undef + %0, %token2 = ml_program.global_load_graph @global_mutable_undef ordering(() -> !ml_program.token) : tensor - %token3 = ml_program.global_store @global_mutable_undef = %0 + %token3 = ml_program.global_store_graph @global_mutable_undef = %0 ordering(%token1, %token2 -> !ml_program.token) : tensor - ml_program.global_store @global_mutable_undef = %0 - ordering(%token3) : tensor + %token4 = ml_program.global_store_graph @global_mutable_undef = %0 + ordering(%token3 -> !ml_program.token) : tensor ml_program.output %0, %token3 : tensor, !ml_program.token }