from typing import Optional, List


# Base class for an instruction. To implement a basic instruction that doesn't
# impact the control-flow, create a new class inheriting from this.
class Instruction:
    # Contains the name of the output register, if any.
    _result: Optional[str]

    # Contains the instruction opcode.
    _opcode: str

    # Contains all the instruction operands, except result and opcode.
    _operands: List[str]

    # Parses one whitespace-separated instruction line. Two shapes are
    # recognized:
    #   <result> = <opcode> [operands...]
    #   <opcode> [operands...]
    # Raises ValueError if the line holds no token, or if nothing follows
    # the '=' sign.
    def __init__(self, line: str):
        self.line = line
        tokens = line.split()
        if not tokens:
            raise ValueError("Cannot parse an instruction from an empty line")

        if len(tokens) > 1 and tokens[1] == "=":
            if len(tokens) < 3:
                raise ValueError(f"Missing opcode after '=' in: {line!r}")
            self._result = tokens[0]
            self._opcode = tokens[2]
            # Slicing past the end yields an empty list, no guard needed.
            self._operands = tokens[3:]
        else:
            self._result = None
            self._opcode = tokens[0]
            self._operands = tokens[1:]

    def __str__(self):
        if self._result is None:
            return f" {self._opcode} {self._operands}"
        return f"{self._result:3} = {self._opcode} {self._operands}"

    # Returns the instruction opcode.
    def opcode(self) -> str:
        return self._opcode

    # Returns the instruction operands.
    def operands(self) -> List[str]:
        return self._operands

    # Returns the instruction output register. Calling this function is
    # only allowed if has_output_register() is true.
    def output_register(self) -> str:
        assert self._result is not None
        return self._result

    # Returns true if this instruction has an output register. False
    # otherwise.
    def has_output_register(self) -> bool:
        return self._result is not None

    # This function is used to initialize state related to this instruction
    # before module execution begins. For example, global Input variables
    # can use this to store the lane ID into the register.
    def static_execution(self, lane):
        pass

    # This function is called every time this instruction is executed by a
    # tangle. This function should not be directly overridden, instead see
    # _impl and _advance_ip.
    def runtime_execution(self, module, lane):
        self._impl(module, lane)
        self._advance_ip(module, lane)

    # This function needs to be overridden if your instruction can be
    # executed. It implements the logic of the instruction.
    # 'Static' instructions like OpConstant should not override this since
    # they are not supposed to be executed at runtime.
    def _impl(self, module, lane):
        raise RuntimeError(f"Unimplemented instruction {self}")

    # By default, IP is incremented to point to the next instruction.
    # If the instruction modifies IP (like OpBranch), this must be overridden.
    def _advance_ip(self, module, lane):
        lane.set_ip(lane.ip() + 1)


# Those are parsed, but never executed.
class OpEntryPoint(Instruction):
    pass


class OpFunction(Instruction):
    pass


class OpFunctionEnd(Instruction):
    pass


class OpLabel(Instruction):
    pass


class OpVariable(Instruction):
    pass


class OpName(Instruction):
    # Returns the second operand with its first and last characters removed
    # (presumably the surrounding double quotes of the name literal).
    def name(self) -> str:
        return self._operands[1][1:-1]

    # Returns the first operand: the register this name is attached to.
    def decoratedRegister(self) -> str:
        return self._operands[0]


# The only decoration we use is the BuiltIn one to initialize the values.
class OpDecorate(Instruction):
    # LinkageAttributes decorations are ignored. Any other decoration must be
    # 'BuiltIn SubgroupLocalInvocationId', in which case the decorated
    # register is initialized with the lane's tid().
    def static_execution(self, lane):
        if self._operands[1] == "LinkageAttributes":
            return

        assert (
            self._operands[1] == "BuiltIn"
            and self._operands[2] == "SubgroupLocalInvocationId"
        ), f"Unsupported decoration: {self._operands[1:]}"
        lane.set_register(self._operands[0], lane.tid())


# Constants
class OpConstant(Instruction):
    # Operand 0 is skipped (presumably the type); operand 1 is parsed as a
    # Python int and stored in the result register.
    def static_execution(self, lane):
        lane.set_register(self._result, int(self._operands[1]))


class OpConstantTrue(OpConstant):
    def static_execution(self, lane):
        lane.set_register(self._result, True)


class OpConstantFalse(OpConstant):
    def static_execution(self, lane):
        lane.set_register(self._result, False)


class OpConstantComposite(OpConstant):
    # Stores a list made of the current values of every operand register
    # after operand 0.
    def static_execution(self, lane):
        members = [lane.get_register(op) for op in self._operands[1:]]
        lane.set_register(self._result, members)


# Control flow instructions
class OpFunctionCall(Instruction):
    def _impl(self, module, lane):
        pass

    # Operand 1 names the callee; the lane is handed the callee entry and the
    # register that should receive the returned value.
    def _advance_ip(self, module, lane):
        entry = module.get_function_entry(self._operands[1])
        lane.do_call(entry, self._result)


class OpReturn(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        lane.do_return(None)


class OpReturnValue(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        lane.do_return(lane.get_register(self._operands[0]))


class OpBranch(Instruction):
    def _impl(self, module, lane):
        pass

    def _advance_ip(self, module, lane):
        lane.set_ip(module.get_bb_entry(self._operands[0]))


class OpBranchConditional(Instruction):
    def _impl(self, module, lane):
        pass

    # Operand 0 is the condition register, operand 1 the label taken when
    # the condition is truthy, operand 2 the label taken otherwise.
    def _advance_ip(self, module, lane):
        condition = lane.get_register(self._operands[0])
        target = self._operands[1] if condition else self._operands[2]
        lane.set_ip(module.get_bb_entry(target))


class OpSwitch(Instruction):
    def _impl(self, module, lane):
        pass

    # Operand 0 is the selector register, operand 1 the default label, and
    # the remaining operands are (literal, label) pairs. The first pair whose
    # literal equals the selector wins; otherwise the default label is used.
    def _advance_ip(self, module, lane):
        value = lane.get_register(self._operands[0])
        default_label = self._operands[1]

        # zip() stops at the shortest slice, so a dangling literal without a
        # label is ignored instead of indexing out of range.
        cases = self._operands[2:]
        for literal, label in zip(cases[0::2], cases[1::2]):
            if value == int(literal):
                lane.set_ip(module.get_bb_entry(label))
                return

        lane.set_ip(module.get_bb_entry(default_label))


class OpUnreachable(Instruction):
    def _impl(self, module, lane):
        raise RuntimeError("This instruction should never be executed.")


# Convergence instructions
class MergeInstruction(Instruction):
    # Returns the first operand: the merge label.
    def merge_location(self):
        return self._operands[0]

    # Returns the second operand when the instruction has at least three
    # operands, None otherwise.
    def continue_location(self):
        return None if len(self._operands) < 3 else self._operands[1]

    def _impl(self, module, lane):
        lane.handle_convergence_header(self)


class OpLoopMerge(MergeInstruction):
    pass


class OpSelectionMerge(MergeInstruction):
    pass


# Other instructions
class OpBitcast(Instruction):
    def _impl(self, module, lane):
        # TODO: find out the type from the defining instruction.
        # This can only work for DXC.
        if self._operands[0] == "%int":
            lane.set_register(self._result, int(lane.get_register(self._operands[1])))
        else:
            raise RuntimeError("Unsupported OpBitcast operand")


class OpAccessChain(Instruction):
    def _impl(self, module, lane):
        # Python dynamic types allows me to simplify. As long as the SPIR-V
        # is legal, this should be fine.
        # Note: composites only need to be indexable here; the composite
        # instructions in this file build them as Python lists.
        value = lane.get_register(self._operands[1])
        for operand in self._operands[2:]:
            # Each index is itself read from a register.
            value = value[lane.get_register(operand)]
        lane.set_register(self._result, value)


class OpCompositeConstruct(Instruction):
    # Stores a list made of the current values of every operand register
    # after operand 0.
    def _impl(self, module, lane):
        members = [lane.get_register(op) for op in self._operands[1:]]
        lane.set_register(self._result, members)


class OpCompositeExtract(Instruction):
    # Unlike OpAccessChain, the indices are literal integers taken directly
    # from the operands.
    def _impl(self, module, lane):
        output = lane.get_register(self._operands[1])
        for op in self._operands[2:]:
            output = output[int(op)]
        lane.set_register(self._result, output)


class OpStore(Instruction):
    # Copies the value of register operand 1 into register operand 0.
    def _impl(self, module, lane):
        lane.set_register(self._operands[0], lane.get_register(self._operands[1]))


class OpLoad(Instruction):
    # Copies the value of register operand 1 into the result register.
    def _impl(self, module, lane):
        lane.set_register(self._result, lane.get_register(self._operands[1]))


class OpIAdd(Instruction):
    def _impl(self, module, lane):
        LHS = lane.get_register(self._operands[1])
        RHS = lane.get_register(self._operands[2])
        lane.set_register(self._result, LHS + RHS)


class OpISub(Instruction):
    def _impl(self, module, lane):
        LHS = lane.get_register(self._operands[1])
        RHS = lane.get_register(self._operands[2])
        lane.set_register(self._result, LHS - RHS)


class OpIMul(Instruction):
    def _impl(self, module, lane):
        LHS = lane.get_register(self._operands[1])
        RHS = lane.get_register(self._operands[2])
        lane.set_register(self._result, LHS * RHS)


class OpLogicalNot(Instruction):
    def _impl(self, module, lane):
        LHS = lane.get_register(self._operands[1])
        lane.set_register(self._result, not LHS)


# Shared implementation of the '<' comparison. The signed and unsigned
# opcodes below both inherit it unchanged, so signedness is not modelled.
class _LessThan(Instruction):
    def _impl(self, module, lane):
        LHS = lane.get_register(self._operands[1])
        RHS = lane.get_register(self._operands[2])
        lane.set_register(self._result, LHS < RHS)


# Shared implementation of the '>' comparison. The signed and unsigned
# opcodes below both inherit it unchanged, so signedness is not modelled.
class _GreaterThan(Instruction):
    def _impl(self, module, lane):
        LHS = lane.get_register(self._operands[1])
        RHS = lane.get_register(self._operands[2])
        lane.set_register(self._result, LHS > RHS)


class OpSLessThan(_LessThan):
    pass


class OpULessThan(_LessThan):
    pass


class OpSGreaterThan(_GreaterThan):
    pass


class OpUGreaterThan(_GreaterThan):
    pass


class OpIEqual(Instruction):
    def _impl(self, module, lane):
        LHS = lane.get_register(self._operands[1])
        RHS = lane.get_register(self._operands[2])
        lane.set_register(self._result, LHS == RHS)


class OpINotEqual(Instruction):
    def _impl(self, module, lane):
        LHS = lane.get_register(self._operands[1])
        RHS = lane.get_register(self._operands[2])
        lane.set_register(self._result, LHS != RHS)


class OpPhi(Instruction):
    # Operands after operand 0 are (value register, predecessor label) pairs.
    # The result takes the value paired with the block the lane came from.
    # Raises RuntimeError if no pair matches the previous block.
    def _impl(self, module, lane):
        previous_bb = lane.get_previous_bb_name()

        # zip() stops at the shortest slice, so a dangling value without a
        # label ends in the RuntimeError below instead of an IndexError.
        incoming = self._operands[1:]
        for value, label in zip(incoming[0::2], incoming[1::2]):
            if label == previous_bb:
                lane.set_register(self._result, lane.get_register(value))
                return

        raise RuntimeError("previousBB not in the OpPhi _operands")


class OpSelect(Instruction):
    # Operand 1 is the condition; operand 2 is picked when it is truthy,
    # operand 3 otherwise.
    def _impl(self, module, lane):
        condition = lane.get_register(self._operands[1])
        value = lane.get_register(self._operands[2 if condition else 3])
        lane.set_register(self._result, value)


# Wave intrinsics
class OpGroupNonUniformBroadcastFirst(Instruction):
    def _impl(self, module, lane):
        # NOTE(review): 3 is presumably the SPIR-V Subgroup execution scope,
        # the only one handled here -- confirm against the specification.
        assert (
            lane.get_register(self._operands[1]) == 3
        ), "Unsupported execution scope"

        # Only the first active lane publishes its value.
        if lane.is_first_active_lane():
            lane.broadcast_register(self._result, lane.get_register(self._operands[2]))


class OpGroupNonUniformElect(Instruction):
    def _impl(self, module, lane):
        lane.set_register(self._result, lane.is_first_active_lane())