Mirror of https://github.com/llvm/llvm-project.git (synced 2025-04-25 09:56:07 +00:00)

Goals:

1. Add syntax and semantics to 'batch_matmul' without changing any of the existing syntax expectations for current usage: 'batch_matmul' is still just 'batch_matmul'.
2. Move the definition of 'batch_matmul' from the Linalg OpDSL to the TableGen ODS infrastructure.

Scope of this patch: expose broadcast and transpose semantics on 'batch_matmul'. These semantics are as follows.

By default, 'linalg.batch_matmul' behaves exactly as before. Broadcast and transpose semantics are applied by specifying the explicit attribute 'indexing_maps', as shown below. This is a list attribute, so if it is specified, the list must include all three maps (LHS, RHS and output). In these maps, d0 is the batch dimension, d1 and d2 are the M and N dimensions of the output, and d3 is the reduction (K) dimension.

Example Transpose:

```
linalg.batch_matmul
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
    ]
    ins(%arg0, %arg1 : memref<2x5x3xf32>, memref<2x5x7xf32>)
    outs(%arg2 : memref<2x3x7xf32>)
```

Example Broadcast:

```
linalg.batch_matmul
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
    ]
    ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
    outs(%arg2 : memref<2x3x7xf32>)
```

Example Broadcast and Transpose:

```
linalg.batch_matmul
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     // broadcast
      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
    ]
    ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
    outs(%arg2 : memref<2x3x7xf32>)
```

RFCs and related PR:

- https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149
- https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863
- https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586
- https://github.com/llvm/llvm-project/pull/115319
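The 'indexing_maps' semantics of 'batch_matmul' can be modelled in a few lines of plain Python. This is a sketch under stated assumptions, not the MLIR implementation: each affine map is written as the tuple of loop dimensions it selects (so `(d0, d3, d1)` becomes `(0, 3, 1)`), which only covers projected permutations like the ones in the examples; element-type casts and tensor vs. memref semantics are ignored; and the helper names (`build`, `zeros`, `shape_of`, `get`, `infer_loop_bounds`, `batch_matmul_ref`) are hypothetical. Loop sizes are inferred from the operand shapes through the maps, and every point `d` of the iteration space accumulates `A[map_a(d)] * B[map_b(d)]` into `C[map_c(d)]`.

```python
from itertools import product

# Loop dimensions: d0 = batch, d1 = M, d2 = N, d3 = K (reduction).
# Hypothetical encoding: an indexing map is the tuple of loop dimensions it
# selects, e.g. (d0, d1, d2, d3) -> (d0, d3, d1) is written (0, 3, 1).
DEFAULT_MAPS = (
    (0, 1, 3),  # A: (d0, d1, d3)
    (0, 3, 2),  # B: (d0, d3, d2)
    (0, 1, 2),  # C: (d0, d1, d2)
)


def build(shape, fn, prefix=()):
    """Build a nested list of the given shape; the element at idx is fn(idx)."""
    if not shape:
        return fn(prefix)
    return [build(shape[1:], fn, prefix + (i,)) for i in range(shape[0])]


def zeros(shape):
    return build(shape, lambda idx: 0)


def shape_of(t):
    """Shape of a non-empty rectangular nested list."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def get(t, idx):
    """Element (or sub-list) of a nested list at index tuple idx."""
    for i in idx:
        t = t[i]
    return t


def infer_loop_bounds(shapes, maps):
    """Derive the sizes of (d0, d1, d2, d3) from operand shapes via the maps."""
    bounds = [None] * 4
    for shape, m in zip(shapes, maps):
        if len(shape) != len(m):
            raise ValueError(f"rank mismatch: shape {shape} vs map {m}")
        for size, d in zip(shape, m):
            if bounds[d] is not None and bounds[d] != size:
                raise ValueError(f"d{d}: size {size} conflicts with {bounds[d]}")
            bounds[d] = size
    if None in bounds:
        raise ValueError("a loop dimension is not used by any operand")
    return tuple(bounds)


def batch_matmul_ref(a, b, c, maps=DEFAULT_MAPS):
    """In place: C[map_c(d)] += A[map_a(d)] * B[map_b(d)] for every loop point d."""
    map_a, map_b, map_c = maps
    bounds = infer_loop_bounds((shape_of(a), shape_of(b), shape_of(c)), maps)
    for d in product(*(range(n) for n in bounds)):
        ia = tuple(d[i] for i in map_a)
        ib = tuple(d[i] for i in map_b)
        ic = tuple(d[i] for i in map_c)
        row = get(c, ic[:-1])  # innermost list of C, updated in place
        row[ic[-1]] += get(a, ia) * get(b, ib)
    return c


# "Example Transpose": A is stored as batch x K x M, i.e. map (d0, d3, d1).
a_t = build((2, 5, 3), lambda i: i[0] + 2 * i[1] - i[2])
b = build((2, 5, 7), lambda i: i[0] * i[1] + i[2])
c_transposed = batch_matmul_ref(
    a_t, b, zeros((2, 3, 7)), ((0, 3, 1), (0, 3, 2), (0, 1, 2)))

# Same result with the default maps on an explicitly transposed copy of A.
a = build((2, 3, 5), lambda i: get(a_t, (i[0], i[2], i[1])))
assert c_transposed == batch_matmul_ref(a, b, zeros((2, 3, 7)))
```

Because the iteration space is derived from the operand shapes through the maps, the model rejects inconsistent operands (for example, a K of 4 in A against a K of 5 in B) with a `ValueError`. The broadcast examples fit the same function unchanged: a map such as `(3,)` simply omits the batch and M dimensions, so the same A element is reused for every value of d0 and d1.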