#!/usr/bin/env python3 """A script to generate FileCheck statements for mlir unit tests. This script is a utility to add FileCheck patterns to an mlir file. NOTE: The input .mlir is expected to be the output from the parser, not a stripped down variant. Example usage: $ generate-test-checks.py foo.mlir $ mlir-opt foo.mlir -transformation | generate-test-checks.py $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i $ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @' The script will heuristically generate CHECK/CHECK-LABEL commands for each line within the file. By default this script will also try to insert string substitution blocks for all SSA value names. If --source file is specified, the script will attempt to insert the generated CHECKs to the source file by looking for line positions matched by --source_delim_regex. The script is designed to make adding checks to a test case fast, it is *not* designed to be authoritative about what constitutes a good test! """ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import argparse import os # Used to advertise this file's name ("autogenerated_note"). import re import sys ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by " ADVERT_END = """ // The script is designed to make adding checks to // a test case fast, it is *not* designed to be authoritative // about what constitutes a good test! The CHECK should be // minimized and named to reflect the test intent. """ # Regex command to match an SSA identifier. 
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# One result name on the left-hand side of an assignment. It is optionally
# followed by a result-group size, as in "%0:2 = ...". SSA_RE_STR is a bare
# alternation, so it is wrapped in a non-capturing group. Without the group,
# the leading "%" would bind only to the numeric alternative, and named
# results such as "%res" would never match.
_SSA_RESULT_STR = r"%(?:" + SSA_RE_STR + r")(?::[0-9]+)?"

# Regex matching the left-hand side of an assignment, e.g. "%0, %res =".
SSA_RESULTS_STR = (
    r"\s*(" + _SSA_RESULT_STR + r")(\s*,\s*(" + _SSA_RESULT_STR + r"))*\s*="
)
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regex matching attributes. The capturing group matters: ATTR_RE.split()
# relies on it to keep the matched attribute names in its result.
ATTR_RE_STR = r"(#[a-zA-Z._-][a-zA-Z0-9._-]*)"
ATTR_RE = re.compile(ATTR_RE_STR)

# Regex matching the left-hand side of an attribute definition
ATTR_DEF_RE_STR = r"\s*" + ATTR_RE_STR + r"\s*="
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)


class VariableNamer:
    """Generate and manage FileCheck substitution names for SSA values.

    Names are stored in a stack of scopes. main() pushes a scope for every
    line ending in '{' and pops one for every line starting with '}', so a
    value defined inside a region is forgotten once that region is closed.
    """

    def __init__(self, variable_names):
        """Create a namer.

        Args:
            variable_names: comma-separated list of names to hand out in
                order. Empty entries fall back to "VAL_<n>" names.
        """
        # Stack of {ssa_name: substitution_name} dicts, innermost last.
        self.scopes = []
        # Counter used to build the default "VAL_<n>" names.
        self.name_counter = 0
        # Number of variable names to still generate in parent scope
        self.generate_in_parent_scope_left = 0
        # Parse variable names; they are consumed front to back.
        self.variable_names = [name.upper() for name in variable_names.split(",")]
        # Substitution names handed out since the last clear_names().
        self.used_variable_names = set()

    def generate_in_parent_scope(self, n):
        """Place the next `n` generated names in the parent scope.

        This is used for the results of an op that opens a region: the
        results must stay visible after the region's scope is popped.
        """
        self.generate_in_parent_scope_left = n

    def generate_name(self, source_variable_name):
        """Create, record and return a substitution name for an SSA name.

        Raises:
            RuntimeError: if the chosen name was already used since the
                last clear_names().
        """
        # Compute variable name
        variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ""
        if variable_name == "":
            variable_name = "VAL_" + str(self.name_counter)
            self.name_counter += 1

        # Scope where variable name is saved
        scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            scope = len(self.scopes) - 2
        assert scope >= 0

        # Save variable
        if variable_name in self.used_variable_names:
            raise RuntimeError(variable_name + ": duplicate variable name")
        self.scopes[scope][source_variable_name] = variable_name
        self.used_variable_names.add(variable_name)

        return variable_name

    def push_name_scope(self):
        """Push a new, empty variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the innermost variable name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the level of nesting (number of pushed scopes)."""
        return len(self.scopes)

    def clear_names(self):
        """Reset the default-name counter and the set of used names."""
        self.name_counter = 0
        self.used_variable_names = set()


class AttributeNamer:
    """Generate FileCheck substitution names for attribute aliases (#name)."""

    def __init__(self, attribute_names):
        """Create a namer.

        Args:
            attribute_names: comma-separated list of names to hand out in
                order. Empty entries fall back to "ATTR_<n>" names.
        """
        self.name_counter = 0
        self.attribute_names = [name.upper() for name in attribute_names.split(",")]
        # Maps source attribute names (e.g. "#map") to substitution names.
        self.map = {}
        self.used_attribute_names = set()

    def generate_name(self, source_attribute_name):
        """Create, record and return a substitution name for an attribute.

        The returned name starts with '$', which makes it a FileCheck global
        variable.

        Raises:
            RuntimeError: if the chosen name was already used.
        """
        # Compute FileCheck name
        attribute_name = self.attribute_names.pop(0) if len(self.attribute_names) > 0 else ""
        if attribute_name == "":
            attribute_name = "ATTR_" + str(self.name_counter)
            self.name_counter += 1

        # Prepend global symbol
        attribute_name = "$" + attribute_name

        # Save attribute
        if attribute_name in self.used_attribute_names:
            raise RuntimeError(attribute_name + ": duplicate attribute name")
        self.map[source_attribute_name] = attribute_name
        self.used_attribute_names.add(attribute_name)
        return attribute_name

    def get_name(self, source_attribute_name, default="?"):
        """Return the substitution name saved for `source_attribute_name`.

        If no name has been generated for that attribute, `default` is
        returned. It defaults to '?' for backward compatibility.
        """
        return self.map.get(source_attribute_name, default)


def get_num_ssa_results(input_line):
    """Return the number of SSA result names assigned at the start of a line.

    Lines look like "%0, %1 = ...", "%res = ..." or "%0:2 = ...". A result
    group such as "%0:2" counts once, because it introduces a single name.
    The function returns 0 if the line has no results.
    """
    m = SSA_RESULTS_RE.match(input_line)
    return m.group().count("%") if m else 0


def process_line(line_chunks, variable_namer, strict_name_re=False):
    """Rewrite SSA uses and definitions into FileCheck substitution blocks.

    Args:
        line_chunks: the pieces of a line that followed each '%'. Each piece
            starts with an SSA name (possibly empty) followed by other text.
        variable_namer: VariableNamer holding the currently known names.
        strict_name_re: when True, new definitions capture SSA_RE_STR instead
            of the greedy '.*'.

    Returns:
        The rewritten text, right-stripped, with a trailing newline.
    """
    output_line = ""

    # Process the rest that contained an SSA value name.
    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        ssa_name = m.group(0) if m is not None else ""

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        # If one exists, then output the existing name.
        if variable is not None:
            output_line += "%[[" + variable + "]]"
        else:
            # Otherwise, generate a new variable.
            variable = variable_namer.generate_name(ssa_name)
            if strict_name_re:
                # Use stricter regexp for the variable name, if requested.
                # Greedy matching may cause issues with the generic '.*'
                # regexp when the checks are split across several
                # lines (e.g. for CHECK-SAME).
                output_line += "%[[" + variable + ":" + SSA_RE_STR + "]]"
            else:
                output_line += "%[[" + variable + ":.*]]"

        # Append the non named group.
        output_line += chunk[len(ssa_name):]

    return output_line.rstrip() + "\n"


def process_source_lines(source_lines, note, args):
    """Split the --source file into segments at each delimiter line.

    The source file doesn't have to be .mlir. Lines belonging to a
    previously generated note are dropped, as are lines containing
    `args.check_prefix`. Blank lines before the first kept line are also
    dropped, so that rerunning with -i does not pile up the note's blank
    lines.

    Args:
        source_lines: lines of the source file without trailing newlines.
        note: the (possibly multi-line) autogenerated note.
        args: parsed arguments; reads `source_delim_regex` and `check_prefix`.

    Returns:
        A list of segments (lists of newline-terminated lines). Every segment
        after the first starts with a line matching --source_delim_regex.
    """
    source_split_re = re.compile(args.source_delim_regex)
    # The note spans several lines but the source is compared one line at a
    # time, so match against the individual non-blank lines of the note.
    note_lines = {note_line for note_line in note.splitlines() if note_line.strip()}

    source_segments = [[]]
    for line in source_lines:
        # Remove previous note.
        if line in note_lines:
            continue
        # Remove previous CHECK lines.
        if args.check_prefix in line:
            continue
        # Drop leading blank lines (e.g. the ones left over from the note).
        if not line.strip() and source_segments == [[]]:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + "\n")
    return source_segments


def process_attribute_definition(line, attribute_namer, output, check_prefix="CHECK"):
    """Emit a check line capturing an attribute alias definition.

    If `line` defines an attribute alias (e.g. "#map = affine_map<...>"), a
    fresh substitution name is generated and the line
    "// <check_prefix>: #[[NAME:.+]] = <rest of definition>" is written to
    `output`. Other lines are ignored.
    """
    m = ATTR_DEF_RE.match(line)
    if m:
        attribute_name = attribute_namer.generate_name(m.group(1))
        line = (
            "// " + check_prefix + ": #[[" + attribute_name + ":.+]] ="
            + line[len(m.group(0)):] + "\n"
        )
        output.write(line)


def process_attribute_references(line, attribute_namer):
    """Replace uses of known attribute aliases with FileCheck substitutions.

    Attribute-like tokens without a generated name (for example dialect
    attributes such as "#gpu.address_space") are left untouched rather than
    being turned into an invalid "#[[?]]" pattern.
    """
    output_line = ""
    for component in ATTR_RE.split(line):
        m = ATTR_RE.match(component)
        name = attribute_namer.get_name(m.group(1), None) if m else None
        if name is not None:
            output_line += "#[[" + name + "]]" + component[len(m.group()):]
        else:
            output_line += component
    return output_line


def preprocess_line(line):
    """Escape character sequences that FileCheck would misinterpret."""
    # Replace any double brackets, '[[' with escaped replacements. '['
    # corresponds to variable names in FileCheck.
    output_line = line.replace("[[", "{{\\[\\[}}")

    # Replace any single brackets that are followed by an SSA identifier, the
    # identifier will be replace by a variable; Creating the same situation as
    # above.
    output_line = output_line.replace("[%", "{{\\[}}%")

    return output_line


def main():
    """Command-line entry point.

    All output is collected in memory and written to its destination only at
    the end. This way a failing run with -i never truncates the --source
    file.

    Raises:
        RuntimeError: if the number of generated CHECK segments does not
            match the number of --source segments, or on duplicate names.
    """
    import io  # Local import: used only for the in-memory output buffer.

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--check-prefix", default="CHECK", help="Prefix to use from check file."
    )
    parser.add_argument(
        "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
    )
    parser.add_argument(
        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument("-i", "--inplace", action="store_true", default=False)
    parser.add_argument(
        "--variable_names",
        type=str,
        default="",
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')",
    )
    parser.add_argument(
        "--attribute_names",
        type=str,
        default="",
        help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')",
    )

    args = parser.parse_args()
    if args.inplace and not args.source:
        parser.error("-i/--inplace requires --source")
    if args.inplace and args.output is not None:
        parser.error("-i/--inplace cannot be combined with -o/--output")

    # Read the given input file.
    input_lines = [line.rstrip() for line in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    source_segments = None
    if args.source:
        with open(args.source, "r") as source_file:
            source_lines = [line.rstrip() for line in source_file]
        source_segments = process_source_lines(source_lines, autogenerated_note, args)

    # Everything is buffered here and copied to the destination at the end.
    output = io.StringIO()
    output_segments = [[]]

    # Namers
    variable_namer = VariableNamer(args.variable_names)
    attribute_namer = AttributeNamer(args.attribute_names)

    # Process lines
    for input_line in input_lines:
        if not input_line:
            continue

        # Check if this is an attribute definition and process it
        process_attribute_definition(
            input_line, attribute_namer, output, args.check_prefix
        )

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        lstripped_input_line = input_line.lstrip()
        is_block = lstripped_input_line[0] == "^"
        if is_block:
            input_line = input_line.rsplit("//", 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == "}":
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == "{":
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

            # Result SSA values must still be pushed to parent scope
            num_ssa_results = get_num_ssa_results(input_line)
            variable_namer.generate_in_parent_scope(num_ssa_results)

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        if len(output_segments[-1]) == 0:
            variable_namer.clear_names()

        # Preprocess the input to remove any sequences that may be problematic with
        # FileCheck.
        input_line = preprocess_line(input_line)

        # Process uses of attributes in this line
        input_line = process_attribute_references(input_line, attribute_namer)

        # Split the line at the each SSA value name.
        ssa_split = input_line.split("%")

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = "// " + args.check_prefix + ": "
            # Pad to align with the 'LABEL' statements.
            output_line += " " * len("-LABEL")

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)

        else:
            # Output the first line chunk that does not contain an SSA name for the
            # label.
            output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += "// " + args.check_prefix + "-SAME: "

                # Pad to align with the original position in the line.
                output_line += " " * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line(
                    [argument], variable_namer, strict_name_re=True
                )

        # Append the output line.
        output_segments[-1].append(output_line)

    output.write(autogenerated_note + "\n")

    # Assemble the output.
    if source_segments:
        if len(output_segments) != len(source_segments):
            raise RuntimeError(
                "generated " + str(len(output_segments)) + " CHECK segments but "
                "the source file has " + str(len(source_segments)) + " segments; "
                "check --source_delim_regex and --starts_from_scope"
            )
        for check_segment, source_segment in zip(output_segments, source_segments):
            for line in check_segment:
                output.write(line)
            for line in source_segment:
                output.write(line)
    else:
        for segment in output_segments:
            output.write("\n")
            for output_line in segment:
                output.write(output_line)
        output.write("\n")

    # Only now touch the destination, after everything above has succeeded.
    result = output.getvalue()
    if args.inplace:
        with open(args.source, "w") as destination:
            destination.write(result)
    elif args.output is None:
        sys.stdout.write(result)
    else:
        args.output.write(result)
        args.output.close()


if __name__ == "__main__":
    main()