[mlir] Introduce C API for PDL dialect types

This change introduces C API helper functions to work with PDL types.
Modification closely follow the format of the https://reviews.llvm.org/D116546.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D117221
This commit is contained in:
Denys Shabalin 2022-01-13 11:33:42 +01:00
parent edcac733dc
commit a8a2ee6331
10 changed files with 530 additions and 2 deletions

View File

@ -0,0 +1,72 @@
//===-- mlir-c/Dialect/PDL.h - C API for PDL Dialect --------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_C_DIALECT_PDL_H
#define MLIR_C_DIALECT_PDL_H
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PDL, pdl);
//===---------------------------------------------------------------------===//
// PDLType
//===---------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLType(MlirType type);
//===---------------------------------------------------------------------===//
// AttributeType
//===---------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLAttributeType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLOperationType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
// RangeType
//===---------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
//===---------------------------------------------------------------------===//
// TypeType
//===---------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLTypeType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
// ValueType
//===---------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLValueType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx);
#ifdef __cplusplus
}
#endif
#endif // MLIR_C_DIALECT_QUANT_H

View File

@ -1,4 +1,4 @@
//===-- mlir-c/Dialect/LLVM.h - C API for LLVM --------------------*- C -*-===//
//===-- mlir-c/Dialect/Quant.h - C API for LLVM -------------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.

View File

@ -106,3 +106,12 @@ add_mlir_upstream_c_api_library(MLIRCAPIQuant
MLIRCAPIIR
MLIRQuant
)
add_mlir_upstream_c_api_library(MLIRCAPIPDL
PDL.cpp
PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRPDL
)

View File

@ -0,0 +1,85 @@
//===- PDL.cpp - C Interface for PDL dialect ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/PDL.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
using namespace mlir;
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PDL, pdl, pdl::PDLDialect)
//===---------------------------------------------------------------------===//
// PDLType
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLType(MlirType type) {
return unwrap(type).isa<pdl::PDLType>();
}
//===---------------------------------------------------------------------===//
// AttributeType
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLAttributeType(MlirType type) {
return unwrap(type).isa<pdl::AttributeType>();
}
MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
return wrap(pdl::AttributeType::get(unwrap(ctx)));
}
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLOperationType(MlirType type) {
return unwrap(type).isa<pdl::OperationType>();
}
MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
return wrap(pdl::OperationType::get(unwrap(ctx)));
}
//===---------------------------------------------------------------------===//
// RangeType
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLRangeType(MlirType type) {
return unwrap(type).isa<pdl::RangeType>();
}
MlirType mlirPDLRangeTypeGet(MlirType elementType) {
return wrap(pdl::RangeType::get(unwrap(elementType)));
}
//===---------------------------------------------------------------------===//
// TypeType
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLTypeType(MlirType type) {
return unwrap(type).isa<pdl::TypeType>();
}
MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
return wrap(pdl::TypeType::get(unwrap(ctx)));
}
//===---------------------------------------------------------------------===//
// ValueType
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLValueType(MlirType type) {
return unwrap(type).isa<pdl::ValueType>();
}
MlirType mlirPDLValueTypeGet(MlirContext ctx) {
return wrap(pdl::ValueType::get(unwrap(ctx)));
}

View File

@ -1,4 +1,4 @@
//===- LLVM.cpp - C Interface for Quant dialect ---------------------------===//
//===- Quant.cpp - C Interface for Quant dialect --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -66,3 +66,11 @@ _add_capi_test_executable(mlir-capi-quant-test
MLIRCAPIRegistration
MLIRCAPIQuant
)
_add_capi_test_executable(mlir-capi-pdl-test
pdl.c
LINK_LIBS PRIVATE
MLIRCAPIIR
MLIRCAPIRegistration
MLIRCAPIPDL
)

332
mlir/test/CAPI/pdl.c Normal file
View File

@ -0,0 +1,332 @@
//===- pdl.c - Test of PDL dialect C API ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// RUN: mlir-capi-pdl-test 2>&1 | FileCheck %s
#include "mlir-c/Dialect/PDL.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"
#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
// CHECK-LABEL: testAttributeType
void testAttributeType(MlirContext ctx) {
fprintf(stderr, "testAttributeType\n");
MlirType parsedType = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!pdl.attribute"));
MlirType constructedType = mlirPDLAttributeTypeGet(ctx);
assert(!mlirTypeIsNull(parsedType) && "couldn't parse PDLAttributeType");
assert(!mlirTypeIsNull(constructedType) && "couldn't construct PDLAttributeType");
// CHECK: parsedType isa PDLType: 1
fprintf(stderr, "parsedType isa PDLType: %d\n",
mlirTypeIsAPDLType(parsedType));
// CHECK: parsedType isa PDLAttributeType: 1
fprintf(stderr, "parsedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(parsedType));
// CHECK: parsedType isa PDLOperationType: 0
fprintf(stderr, "parsedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(parsedType));
// CHECK: parsedType isa PDLRangeType: 0
fprintf(stderr, "parsedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(parsedType));
// CHECK: parsedType isa PDLTypeType: 0
fprintf(stderr, "parsedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(parsedType));
// CHECK: parsedType isa PDLValueType: 0
fprintf(stderr, "parsedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(parsedType));
// CHECK: constructedType isa PDLType: 1
fprintf(stderr, "constructedType isa PDLType: %d\n",
mlirTypeIsAPDLType(constructedType));
// CHECK: constructedType isa PDLAttributeType: 1
fprintf(stderr, "constructedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(constructedType));
// CHECK: constructedType isa PDLOperationType: 0
fprintf(stderr, "constructedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(constructedType));
// CHECK: constructedType isa PDLRangeType: 0
fprintf(stderr, "constructedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(constructedType));
// CHECK: constructedType isa PDLTypeType: 0
fprintf(stderr, "constructedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(constructedType));
// CHECK: constructedType isa PDLValueType: 0
fprintf(stderr, "constructedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(constructedType));
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
// CHECK: !pdl.attribute
mlirTypeDump(parsedType);
// CHECK: !pdl.attribute
mlirTypeDump(constructedType);
fprintf(stderr, "\n\n");
}
// CHECK-LABEL: testOperationType
void testOperationType(MlirContext ctx) {
fprintf(stderr, "testOperationType\n");
MlirType parsedType = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!pdl.operation"));
MlirType constructedType = mlirPDLOperationTypeGet(ctx);
assert(!mlirTypeIsNull(parsedType) && "couldn't parse PDLAttributeType");
assert(!mlirTypeIsNull(constructedType) && "couldn't construct PDLAttributeType");
// CHECK: parsedType isa PDLType: 1
fprintf(stderr, "parsedType isa PDLType: %d\n",
mlirTypeIsAPDLType(parsedType));
// CHECK: parsedType isa PDLAttributeType: 0
fprintf(stderr, "parsedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(parsedType));
// CHECK: parsedType isa PDLOperationType: 1
fprintf(stderr, "parsedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(parsedType));
// CHECK: parsedType isa PDLRangeType: 0
fprintf(stderr, "parsedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(parsedType));
// CHECK: parsedType isa PDLTypeType: 0
fprintf(stderr, "parsedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(parsedType));
// CHECK: parsedType isa PDLValueType: 0
fprintf(stderr, "parsedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(parsedType));
// CHECK: constructedType isa PDLType: 1
fprintf(stderr, "constructedType isa PDLType: %d\n",
mlirTypeIsAPDLType(constructedType));
// CHECK: constructedType isa PDLAttributeType: 0
fprintf(stderr, "constructedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(constructedType));
// CHECK: constructedType isa PDLOperationType: 1
fprintf(stderr, "constructedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(constructedType));
// CHECK: constructedType isa PDLRangeType: 0
fprintf(stderr, "constructedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(constructedType));
// CHECK: constructedType isa PDLTypeType: 0
fprintf(stderr, "constructedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(constructedType));
// CHECK: constructedType isa PDLValueType: 0
fprintf(stderr, "constructedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(constructedType));
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
// CHECK: !pdl.operation
mlirTypeDump(parsedType);
// CHECK: !pdl.operation
mlirTypeDump(constructedType);
fprintf(stderr, "\n\n");
}
// CHECK-LABEL: testRangeType
void testRangeType(MlirContext ctx) {
fprintf(stderr, "testRangeType\n");
MlirType typeType = mlirPDLTypeTypeGet(ctx);
MlirType parsedType = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!pdl.range<type>"));
MlirType constructedType = mlirPDLRangeTypeGet(typeType);
assert(!mlirTypeIsNull(typeType) && "couldn't get PDLTypeType");
assert(!mlirTypeIsNull(parsedType) && "couldn't parse PDLAttributeType");
assert(!mlirTypeIsNull(constructedType) && "couldn't construct PDLAttributeType");
// CHECK: parsedType isa PDLType: 1
fprintf(stderr, "parsedType isa PDLType: %d\n",
mlirTypeIsAPDLType(parsedType));
// CHECK: parsedType isa PDLAttributeType: 0
fprintf(stderr, "parsedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(parsedType));
// CHECK: parsedType isa PDLOperationType: 0
fprintf(stderr, "parsedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(parsedType));
// CHECK: parsedType isa PDLRangeType: 1
fprintf(stderr, "parsedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(parsedType));
// CHECK: parsedType isa PDLTypeType: 0
fprintf(stderr, "parsedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(parsedType));
// CHECK: parsedType isa PDLValueType: 0
fprintf(stderr, "parsedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(parsedType));
// CHECK: constructedType isa PDLType: 1
fprintf(stderr, "constructedType isa PDLType: %d\n",
mlirTypeIsAPDLType(constructedType));
// CHECK: constructedType isa PDLAttributeType: 0
fprintf(stderr, "constructedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(constructedType));
// CHECK: constructedType isa PDLOperationType: 0
fprintf(stderr, "constructedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(constructedType));
// CHECK: constructedType isa PDLRangeType: 1
fprintf(stderr, "constructedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(constructedType));
// CHECK: constructedType isa PDLTypeType: 0
fprintf(stderr, "constructedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(constructedType));
// CHECK: constructedType isa PDLValueType: 0
fprintf(stderr, "constructedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(constructedType));
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
// CHECK: !pdl.range<type>
mlirTypeDump(parsedType);
// CHECK: !pdl.range<type>
mlirTypeDump(constructedType);
fprintf(stderr, "\n\n");
}
// CHECK-LABEL: testTypeType
void testTypeType(MlirContext ctx) {
fprintf(stderr, "testTypeType\n");
MlirType parsedType = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!pdl.type"));
MlirType constructedType = mlirPDLTypeTypeGet(ctx);
assert(!mlirTypeIsNull(parsedType) && "couldn't parse PDLAttributeType");
assert(!mlirTypeIsNull(constructedType) && "couldn't construct PDLAttributeType");
// CHECK: parsedType isa PDLType: 1
fprintf(stderr, "parsedType isa PDLType: %d\n",
mlirTypeIsAPDLType(parsedType));
// CHECK: parsedType isa PDLAttributeType: 0
fprintf(stderr, "parsedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(parsedType));
// CHECK: parsedType isa PDLOperationType: 0
fprintf(stderr, "parsedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(parsedType));
// CHECK: parsedType isa PDLRangeType: 0
fprintf(stderr, "parsedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(parsedType));
// CHECK: parsedType isa PDLTypeType: 1
fprintf(stderr, "parsedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(parsedType));
// CHECK: parsedType isa PDLValueType: 0
fprintf(stderr, "parsedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(parsedType));
// CHECK: constructedType isa PDLType: 1
fprintf(stderr, "constructedType isa PDLType: %d\n",
mlirTypeIsAPDLType(constructedType));
// CHECK: constructedType isa PDLAttributeType: 0
fprintf(stderr, "constructedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(constructedType));
// CHECK: constructedType isa PDLOperationType: 0
fprintf(stderr, "constructedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(constructedType));
// CHECK: constructedType isa PDLRangeType: 0
fprintf(stderr, "constructedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(constructedType));
// CHECK: constructedType isa PDLTypeType: 1
fprintf(stderr, "constructedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(constructedType));
// CHECK: constructedType isa PDLValueType: 0
fprintf(stderr, "constructedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(constructedType));
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
// CHECK: !pdl.type
mlirTypeDump(parsedType);
// CHECK: !pdl.type
mlirTypeDump(constructedType);
fprintf(stderr, "\n\n");
}
// CHECK-LABEL: testValueType
void testValueType(MlirContext ctx) {
fprintf(stderr, "testValueType\n");
MlirType parsedType = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!pdl.value"));
MlirType constructedType = mlirPDLValueTypeGet(ctx);
assert(!mlirTypeIsNull(parsedType) && "couldn't parse PDLAttributeType");
assert(!mlirTypeIsNull(constructedType) && "couldn't construct PDLAttributeType");
// CHECK: parsedType isa PDLType: 1
fprintf(stderr, "parsedType isa PDLType: %d\n",
mlirTypeIsAPDLType(parsedType));
// CHECK: parsedType isa PDLAttributeType: 0
fprintf(stderr, "parsedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(parsedType));
// CHECK: parsedType isa PDLOperationType: 0
fprintf(stderr, "parsedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(parsedType));
// CHECK: parsedType isa PDLRangeType: 0
fprintf(stderr, "parsedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(parsedType));
// CHECK: parsedType isa PDLTypeType: 0
fprintf(stderr, "parsedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(parsedType));
// CHECK: parsedType isa PDLValueType: 1
fprintf(stderr, "parsedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(parsedType));
// CHECK: constructedType isa PDLType: 1
fprintf(stderr, "constructedType isa PDLType: %d\n",
mlirTypeIsAPDLType(constructedType));
// CHECK: constructedType isa PDLAttributeType: 0
fprintf(stderr, "constructedType isa PDLAttributeType: %d\n",
mlirTypeIsAPDLAttributeType(constructedType));
// CHECK: constructedType isa PDLOperationType: 0
fprintf(stderr, "constructedType isa PDLOperationType: %d\n",
mlirTypeIsAPDLOperationType(constructedType));
// CHECK: constructedType isa PDLRangeType: 0
fprintf(stderr, "constructedType isa PDLRangeType: %d\n",
mlirTypeIsAPDLRangeType(constructedType));
// CHECK: constructedType isa PDLTypeType: 0
fprintf(stderr, "constructedType isa PDLTypeType: %d\n",
mlirTypeIsAPDLTypeType(constructedType));
// CHECK: constructedType isa PDLValueType: 1
fprintf(stderr, "constructedType isa PDLValueType: %d\n",
mlirTypeIsAPDLValueType(constructedType));
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
// CHECK: !pdl.value
mlirTypeDump(parsedType);
// CHECK: !pdl.value
mlirTypeDump(constructedType);
fprintf(stderr, "\n\n");
}
int main() {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__pdl__(), ctx);
testAttributeType(ctx);
testOperationType(ctx);
testRangeType(ctx);
testTypeType(ctx);
testValueType(ctx);
return EXIT_SUCCESS;
}

View File

@ -75,6 +75,7 @@ set(MLIR_TEST_DEPENDS
mlir-capi-pass-test
mlir-capi-sparse-tensor-test
mlir-capi-quant-test
mlir-capi-pdl-test
mlir-cpu-runner
mlir-linalg-ods-yaml-gen
mlir-lsp-server

View File

@ -68,6 +68,7 @@ tools = [
'mlir-capi-pass-test',
'mlir-capi-sparse-tensor-test',
'mlir-capi-quant-test',
'mlir-capi-pdl-test',
'mlir-cpu-runner',
'mlir-linalg-ods-yaml-gen',
'mlir-reduce',

View File

@ -514,6 +514,26 @@ mlir_c_api_cc_library(
],
)
mlir_c_api_cc_library(
name = "CAPIPDL",
srcs = [
"lib/CAPI/Dialect/PDL.cpp",
],
hdrs = [
"include/mlir-c/Dialect/PDL.h",
],
header_deps = [
":CAPIIRHeaders",
],
includes = ["include"],
deps = [
":CAPIIR",
":PDLDialect",
":PDLOpsIncGen",
":PDLTypesIncGen",
],
)
mlir_c_api_cc_library(
name = "CAPIConversion",
srcs = ["lib/CAPI/Conversion/Passes.cpp"],