# Source code for microprobe.passes.register

# Copyright 2011-2021 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""":mod:`microprobe.passes.register` module

"""

# Futures
from __future__ import absolute_import, print_function

# Built-in modules
import itertools
import random
import re
from typing import Union

# Third party modules

# Own modules
import microprobe.passes
import microprobe.utils.distrib
from microprobe.code.address import Address, MemoryValue
from microprobe.exceptions import \
    MicroprobeCodeGenerationError, MicroprobeValueError
from microprobe.utils.logger import get_logger
from microprobe.utils.misc import OrderedDict

# Constants
# Module logger (used by every pass below for debug/warning/critical output).
LOG = get_logger(__name__)
# Public API of this module: the register allocation passes it defines.
__all__ = [
    'DefaultRegisterAllocationPass', 'CycleMinimalAllocationPass',
    'FixRegistersPass', 'NoHazardsAllocationPass', 'RandomAllocationPass'
]

# Functions


# Classes
class DefaultRegisterAllocationPass(microprobe.passes.Pass):
    """DefaultRegisterAllocationPass pass.

    Sets every operand that is still unset: immediates get the configured
    value (or a random one), relative-address operands get their first
    legal value, and register operands get a non-reserved register chosen
    either to honour the requested dependency distance or by usage.
    """

    def __init__(self, minimize=False, value=None,
                 dd: Union[int, float] = 0, relax=False):
        """

        :param minimize: No longer supported; a true value raises
            :class:`NotImplementedError`. (Default value = False)
        :param value: Value for immediate operands. ``None`` selects random
            values, ``"zero"`` selects 0, and a callable is called
            (repeatedly, while the result is still callable) for every
            immediate. (Default value = None)
        :param dd: Dependency distance. An ``int`` constant, a ``float``
            average (see :func:`microprobe.utils.distrib.discrete_average`),
            or a callable. It is evaluated once per instruction.
            (Default value = 0)
        :param relax: Do not fail when an instruction sets a reserved
            register it is not allowed to set. (Default value = False)
        :raises NotImplementedError: if *minimize* is true
        :raises MicroprobeValueError: if *dd* has an unsupported type
        """
        super(DefaultRegisterAllocationPass, self).__init__()
        self._min = minimize
        if self._min:
            raise NotImplementedError("Feature changed need to be"
                                      " reimplemented")
        if isinstance(dd, int):
            self._dd = lambda: dd  # Dependency distance
        elif isinstance(dd, float):
            self._dd = microprobe.utils.distrib.discrete_average(dd)
        elif callable(dd):
            self._dd = dd
        else:
            raise MicroprobeValueError("Invalid parameter")
        self._relax = relax
        if value is not None:
            self._immediate = value
            if value == "zero":
                self._immediate = 0
        else:
            self._immediate = "random"
        self._description = "Default register allocation (required always to" \
            " set any remaining operand). minimize=%s, " \
            "value=%s, dd=%s" % (minimize, value, dd)

    @staticmethod
    def _candidate_registers(operand, rregs):
        """Return ``(regs, valid_values)`` for a register operand.

        ``regs`` are all the values accepted by the operand type.
        ``valid_values`` are the values whose accessed registers do not
        overlap *rregs* (a single accepted value is always kept). If nothing
        is left and the operand is read-only, the reserved values are
        allowed again.
        """
        regs = list(operand.type.values())
        if len(regs) == 1:
            # There is no other option
            valid_values = regs
        else:
            valid_values = [
                reg for reg in regs
                if len(rregs.intersection(operand.type.access(reg))) == 0
            ]
        if (len(valid_values) == 0 and operand.is_input and
                not operand.is_output):
            # Try to use reserved values if the operand is read only
            valid_values = list(operand.type.values())
        return regs, valid_values

    @staticmethod
    def _check_candidates(instr, operand, regs, valid_values, rregs,
                          lastused):
        """Raise if no register can be allocated to *operand*.

        :raises MicroprobeCodeGenerationError: if *valid_values* is empty
            or no register of the operand's type is being tracked
        """
        if len(valid_values) == 0:
            LOG.critical("Instruction: %s", instr)
            LOG.critical("Operand: %s", operand)
            LOG.critical("Possible all values: %s",
                         sorted(operand.type.values()))
            LOG.critical("Possible values: %s", sorted(regs))
            LOG.critical("Reserved values: %s", sorted(rregs))
            raise MicroprobeCodeGenerationError(
                "Unable to find proper operand"
                " values for register allocation.")
        if regs[0].type not in lastused:
            raise MicroprobeCodeGenerationError(
                "Unable to find proper operand"
                " values for register allocation."
                " No registers of type '%s' "
                "available." % regs[0].type)

    def __call__(self, building_block, target):
        """Allocate every operand of *building_block* that is still unset.

        :param building_block: building block whose instructions are updated
        :param target: target providing the register set
        :raises MicroprobeCodeGenerationError: if a register operand has no
            usable value, or if an instruction sets a reserved register it
            is not allowed to set (unless ``relax`` was requested)
        """
        allregs = target.registers
        lastdefined = {}
        lastused = {}
        rregs = set(building_block.context.reserved_registers)

        # TODO: All this pass has to be reimplemented
        for reg in allregs.values():
            if reg in rregs:
                # TODO: rregs can be used but only as input, now we discard
                # them for input and output
                continue
            if reg.type not in lastused:
                lastused[reg.type] = OrderedDict()
                lastdefined[reg.type] = OrderedDict()
            lastdefined[reg.type][reg] = 0
            lastused[reg.type][reg] = 1

        idx = 1
        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:

                dependency_ok = False
                used = []
                defined = []
                distance = self._dd()

                for operand in instr.operands():

                    if operand.value is not None:
                        # operand already set
                        LOG.debug("Operand already set")

                    elif operand.type.immediate:
                        LOG.debug("Operand immediate")
                        if self._immediate != "random":
                            try:
                                svalue = self._immediate
                                while callable(svalue):
                                    svalue = svalue()
                                operand.set_value(svalue)
                            except MicroprobeValueError:
                                LOG.warning(
                                    "Operand '%s' in instruction '%s' "
                                    "not modeled properly. Tried to "
                                    "'%s' value ...", operand.type,
                                    instr.name, svalue)
                                value = list(operand.type.values())[0]
                                operand.set_value(value)
                                LOG.warning("Operand set to: '%s'", value)
                        else:
                            operand.set_value(operand.type.random_value())

                    elif operand.type.address_relative:
                        LOG.warning(
                            "Operand '%s' in instruction '%s' "
                            "not modeled properly", operand.type,
                            instr.name)
                        operand.set_value(list(operand.type.values())[0])

                    else:
                        LOG.debug("Operand is register")

                        if operand.is_input and distance > 0 and \
                                not dependency_ok and idx > distance:
                            LOG.debug("Setting dependency distance")
                            regs, valid_values = self._candidate_registers(
                                operand, rregs)
                            # Same sanity checks as the usage-based path:
                            # without them a missing register type failed
                            # with a bare KeyError below.
                            self._check_candidates(instr, operand, regs,
                                                   valid_values, rregs,
                                                   lastused)
                            LOG.debug("Current idx: %d", idx)
                            LOG.debug("Requested idx: %d", idx - distance)
                            reg = microprobe.utils.distrib.sort_by_distance(
                                valid_values, lastdefined[regs[0].type],
                                lastused[regs[0].type], idx - distance,
                                instr, idx)
                            LOG.debug("%s selected", reg)
                            if reg in lastdefined[regs[0].type]:
                                LOG.debug("Last defined: %s",
                                          lastdefined[regs[0].type][reg])
                            if reg in lastused[regs[0].type]:
                                LOG.debug("Last used: %s",
                                          lastused[regs[0].type][reg])
                            dependency_ok = True
                        else:
                            LOG.debug("No need for distance")
                            regs, valid_values = self._candidate_registers(
                                operand, rregs)
                            self._check_candidates(instr, operand, regs,
                                                   valid_values, rregs,
                                                   lastused)
                            reg = microprobe.utils.distrib.sort_by_usage(
                                valid_values, lastused[regs[0].type],
                                lastdefined[regs[0].type])

                        operand.set_value(reg)

                        # ``minimize`` can only be falsy here (a true value
                        # raises in __init__), so always track usage; the
                        # former ``is False`` test skipped it for e.g. None.
                        if not self._min:
                            for reg in operand.uses():
                                LOG.debug("Updating usage of %s", reg)
                                LOG.debug(list(lastused[reg.type].keys()))
                                # Delete then re-insert to move the register
                                # to the end of the (ordered) usage table
                                if reg in lastused[reg.type]:
                                    del lastused[reg.type][reg]
                                lastused[reg.type][reg] = idx
                                used.append(reg)
                                LOG.debug(list(lastused[reg.type].keys()))

                            if operand.is_output:
                                for reg in operand.sets():
                                    LOG.debug("Updating definition of %s",
                                              reg)
                                    if reg in lastdefined[reg.type]:
                                        del lastdefined[reg.type][reg]
                                    lastdefined[reg.type][reg] = idx
                                    defined.append(reg)
                                    LOG.debug("Updating usage of %s", reg)
                                    LOG.debug(
                                        list(lastused[reg.type].keys()))
                                    if reg in lastused[reg.type]:
                                        del lastused[reg.type][reg]
                                    lastused[reg.type][reg] = idx
                                    used.append(reg)
                                    LOG.debug(
                                        list(lastused[reg.type].keys()))

                fail = rregs.intersection(set(instr.sets()))
                if len(fail) > 0 and not self._relax:
                    for reg in fail:
                        if not instr.allows(reg):
                            raise MicroprobeCodeGenerationError(
                                "Instruction '%s' sets a reserved"
                                " register but not allowed before:"
                                "Register: %s. Reserved: %s" %
                                (instr.name, reg.name,
                                 sorted([rreg.name for rreg in rregs])))

                # Account for registers not handled above (e.g. operands
                # that were already set before this pass)
                for reg in instr.uses():
                    if reg not in used and reg not in rregs:
                        LOG.debug("Updating usage of %s", reg)
                        lastused[reg.type][reg] = idx
                for reg in instr.sets():
                    if reg not in defined and reg not in rregs:
                        LOG.debug("Updating definition of %s", reg)
                        if reg in lastdefined[reg.type]:
                            del lastdefined[reg.type][reg]
                        lastdefined[reg.type][reg] = idx

                LOG.debug("Instruction: %s", instr)
                LOG.debug("Last used rank:")
                for rtype in lastused:
                    for reg in lastused[rtype]:
                        if lastused[rtype][reg] > 0 and \
                                reg.name.startswith("GPR"):
                            LOG.debug("%s: %d", reg.name,
                                      lastused[rtype][reg])
                LOG.debug("Last defined rank:")
                for rtype in lastdefined:
                    for reg in lastdefined[rtype]:
                        if lastdefined[rtype][reg] > 0 and \
                                reg.name.startswith("GPR"):
                            LOG.debug("%s: %d", reg.name,
                                      lastdefined[rtype][reg])
                idx = idx + 1
class CycleMinimalAllocationPass(microprobe.passes.Pass):
    """CycleMinimalAllocationPass pass.

    Allocates register operands in groups of ``size`` consecutive
    instructions, limiting how many times the same register may be read,
    written or read-and-written within a group.
    """

    def __init__(self, size, reads, writes, value=None):
        """

        :param size: number of instructions per allocation group
        :param reads: max reads in a group to the same register
        :param writes: max writes in a group to the same register
        :param value: value for immediate operands; ``None`` selects random
            values (Default value = None)
        """
        super(CycleMinimalAllocationPass, self).__init__()
        self._size = size  # groups size
        self._reads = reads  # max reads in a group to the same register
        self._writes = writes  # max writes in a group to the same register
        # read-write operands are limited by the stricter of both bounds
        self._rdwr = min(reads, writes)
        if value is not None:
            self._immediate = value
        else:
            self._immediate = "random"
        self._description = "Cycle minimal register allocation. size=%s, " \
            "reads=%s, writes=%s, value=%s" % (size, reads, writes, value)

    def __call__(self, building_block, target):
        """Allocate the operands of *building_block* group by group.

        :param building_block: building block whose instructions are updated
        :param target: target providing the register set
        :return: an empty list
        :raises MicroprobeCodeGenerationError: if a register operand is
            neither an input nor an output
        """
        # Per register type, usage counters of the registers that are still
        # candidates for write (lastdefined), read (lastreaded) and
        # read-write (lastdefread) operands in the current group. Once a
        # register reaches its limit it is dropped from all three tables.
        lastdefined = {}
        lastreaded = {}
        lastdefread = {}

        def reg_number(reg):
            """Sort key: first number in the register's text (-1 if none)."""
            match = re.search(r'\d+', str(reg))
            return int(match.group()) if match else -1

        # The register set does not change during the pass: sort it once.
        allregs = sorted(target.registers.values(), key=reg_number)

        def reset_dictionaries():
            """Start a new group with every register counter at zero."""
            for table in (lastdefined, lastreaded, lastdefread):
                table.clear()
            reserved = building_block.context.reserved_registers
            for reg in allregs:
                if reg.type not in lastreaded:
                    lastreaded[reg.type] = OrderedDict()
                if reg.type not in lastdefined:
                    lastdefined[reg.type] = OrderedDict()
                if reg.type not in lastdefread:
                    lastdefread[reg.type] = OrderedDict()
                lastreaded[reg.type][reg] = 0
                # Reserved registers may be read but never written
                if reg not in reserved:
                    lastdefined[reg.type][reg] = 0
                    lastdefread[reg.type][reg] = 0

        def discard(table, reg):
            """Drop *reg* from *table* if it is still there."""
            regs = table[reg.type]
            if reg in regs:
                del regs[reg]

        def remove_dict_key(reg):
            """Drop *reg* from all the candidate tables."""
            for table in (lastdefined, lastreaded, lastdefread):
                discard(table, reg)

        def update_dictionaries(instr):
            """Account for the register operands already set in *instr*."""
            for reg, (operand, ins_input, output) in zip(
                    instr.operand_fields(),
                    instr.architecture_type.operands()):
                if reg is None or operand.immediate():
                    continue
                if ins_input and output:
                    counts, limit = lastdefread[reg.type], self._rdwr
                    others = (lastdefined, lastreaded)
                elif output:
                    counts, limit = lastdefined[reg.type], self._writes
                    others = (lastdefread, lastreaded)
                elif ins_input:
                    counts, limit = lastreaded[reg.type], self._reads
                    others = (lastdefined, lastdefread)
                else:
                    continue
                # The register may already be gone from *counts* (it is
                # reserved, or it was dropped earlier in this group); only
                # live candidates are counted. Previously read-write and
                # write operands raised KeyError in that case.
                if reg in counts:
                    counts[reg] += 1
                for table in others:
                    discard(table, reg)
                if counts.get(reg) == limit:
                    remove_dict_key(reg)

        def pick(table, others, rtype, allowed, position):
            """Return the *position*-th most used register in *table* that
            is in *allowed*, and drop it from the *others* tables."""
            counts = table[rtype]
            ranked = sorted(
                (elem for elem in counts.keys() if elem in allowed),
                key=lambda elem: counts[elem], reverse=True)
            reg = ranked[position]
            for other in others:
                if reg in other[rtype]:
                    del other[rtype][reg]
            return reg

        def update_group(group):
            """Set the operands of every instruction in *group*."""
            for instr in group:
                instroperands = []
                next_io = 0
                next_in = 0
                next_out = 0
                for operand, ins_input, output in instr.operands():
                    if operand.immediate():
                        if self._immediate == "random":
                            instroperands.append(
                                operand.assembly(operand.random_value()))
                        elif operand.check(self._immediate, safe=True):
                            instroperands.append(self._immediate)
                        else:
                            instroperands.append(operand.max)
                        continue

                    allowed = list(operand.values())
                    rtype = allowed[0].type
                    if ins_input and output:
                        reg = pick(lastdefread, (lastreaded, lastdefined),
                                   rtype, allowed, next_io)
                        next_io += 1
                    elif ins_input:
                        reg = pick(lastreaded, (lastdefread, lastdefined),
                                   rtype, allowed, next_in)
                        next_in += 1
                    elif output:
                        reg = pick(lastdefined, (lastdefread, lastreaded),
                                   rtype, allowed, next_out)
                        next_out += 1
                    else:
                        # Was ``assert False``, which vanishes under -O
                        raise MicroprobeCodeGenerationError(
                            "Register operand is neither input nor output")
                    instroperands.append(reg)
                instr.set_operands(instroperands)
                instr.check()
                update_dictionaries(instr)

        idx = 0
        group = []
        reset_dictionaries()
        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:
                update_dictionaries(instr)
                group.append(instr)
                if idx % self._size == (self._size - 1):
                    update_group(group)
                    group = []
                    reset_dictionaries()
                idx = idx + 1

        if group:
            update_group(group)

        return []
class NoHazardsAllocationPass(microprobe.passes.Pass):
    """Avoid all possible data hazards:

    - read after write (RAW), a true dependency
    - write after read (WAR), an anti-dependency
    - write after write (WAW), an output dependency

    Registers are split into three pools (read-only, write-only and
    read-write operands), which the pass tries to keep disjoint. Each
    register operand is allocated from the pool matching its access.
    """

    def __init__(self):
        """Initialize the pass and its description."""
        super(NoHazardsAllocationPass, self).__init__()
        # The third hazard is WAW (write after write); it previously read
        # "WAR" twice.
        self._description = "Perform register allocation avoiding RAW, WAR "\
                            "and WAW dependency hazards"

    def __call__(self, building_block, target):
        """Allocate the unset register operands of *building_block*.

        Phase 1 collects the registers used by operands that are already
        fixed (including the ``fini`` instructions) and unreserves the
        reserved registers that no fixed operand uses. Phase 2 puts at
        least one register per unset operand into the matching pool.
        Phase 3 allocates each operand from its pool, moving the chosen
        registers to the end of the pool.

        :param building_block: building block to allocate
        :param target: not used by this pass
        :raises MicroprobeCodeGenerationError: if an unset register operand
            is neither input nor output
        """
        rregs = set(building_block.context.reserved_registers)
        inputs = set()
        inputoutputs = set()
        outputs = set()
        usedrregs = set()

        LOG.debug("-" * 80)
        LOG.debug("BEGIN: Initial State")
        LOG.debug("Inputs: %s", inputs)
        LOG.debug("InputsOutputs: %s", inputoutputs)
        LOG.debug("Outputs: %s", outputs)
        LOG.debug("Reserved %s", rregs)
        LOG.debug("Used Reserved %s", usedrregs)
        LOG.debug("-" * 80)

        # Phase 1: registers of already fixed operands
        for bbl in building_block.cfg.bbls:
            # Build a new list: ``allinstrs = bbl.instrs; allinstrs += ...``
            # extended the basic block's own list in place whenever
            # ``instrs`` exposes the underlying list.
            allinstrs = list(bbl.instrs)
            if bbl == building_block.cfg.bbls[-1]:
                allinstrs += list(building_block.fini)
            for instr in allinstrs:
                for operand in instr.operands():
                    if operand.type.immediate:
                        continue
                    if isinstance(operand.value, Address):
                        continue
                    values = list(operand.type.values())
                    if operand.value is None and len(values) > 1:
                        continue
                    if len(values) == 1:
                        operand.set_value(values[0])

                    # NOTE(review): these tests check operand.value rather
                    # than reg; confirm that is intended.
                    skip = False
                    for reg in operand.uses():
                        if operand.value in rregs:
                            skip = True
                            usedrregs.add(reg)
                            break
                    for reg in operand.sets():
                        if operand.value in rregs:
                            skip = True
                            usedrregs.add(reg)
                            break
                    if skip:
                        continue

                    if operand.is_input and operand.is_output:
                        inputoutputs.update(operand.sets())
                        inputoutputs.update(operand.uses())
                    elif operand.is_input:
                        inputs.update(operand.uses())
                    elif operand.is_output:
                        outputs.update(operand.sets())

        LOG.debug("-" * 80)
        LOG.debug("REGUSTER USAGE Before allocating")
        LOG.debug("Inputs: %s", inputs)
        LOG.debug("InputsOutputs: %s", inputoutputs)
        LOG.debug("Outputs: %s", outputs)
        LOG.debug("Reserved %s", rregs)
        LOG.debug("Used Reserved %s", usedrregs)
        LOG.debug("-" * 80)

        #
        # All reserved registers not used in the main loop
        # can be unreserved
        #
        rregs = rregs.intersection(usedrregs)
        LOG.debug("New Reserved %s", rregs)
        # Iterate over a snapshot: removing entries from the collection
        # being iterated can skip elements.
        for reg in list(building_block.context.reserved_registers):
            if reg not in rregs:
                building_block.context.remove_reserved_registers([reg])

        # Phase 2: minimum allocation, make everybody happy with enough
        # operand values
        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:
                for operand in instr.operands():
                    if operand.type.immediate:
                        continue
                    if operand.value is not None:
                        continue

                    values = {
                        reg for reg in set(operand.type.values())
                        if len(rregs.intersection(
                            operand.type.access(reg))) == 0
                    }

                    if operand.is_input and operand.is_output:
                        # Read-write operands may share the read-write pool
                        values.difference_update(inputs)
                        values.difference_update(outputs)
                        # NOTE(review): unlike the branches below there is
                        # no ``break`` here, so every conflict-free value
                        # is added to the pool. Confirm this is intended.
                        for value in values:
                            valid_val = True
                            to_add = []
                            for val in operand.type.access(value):
                                to_add.append(val)
                                if val in inputs or val in outputs:
                                    valid_val = False
                            if valid_val:
                                inputoutputs.update(set(to_add))
                    elif operand.is_output:
                        values.difference_update(inputoutputs)
                        values.difference_update(inputs)
                        values.difference_update(outputs)
                        for value in values:
                            valid_val = True
                            to_add = []
                            for val in operand.type.access(value):
                                to_add.append(val)
                                if (val in inputoutputs or val in inputs or
                                        val in outputs):
                                    valid_val = False
                            if valid_val:
                                outputs.update(set(to_add))
                                break
                    elif operand.is_input:
                        # Read-only operands may share registers (reads
                        # after reads are not hazards)
                        values.difference_update(inputoutputs)
                        values.difference_update(outputs)
                        for value in values:
                            valid_val = True
                            to_add = []
                            for val in operand.type.access(value):
                                to_add.append(val)
                                if val in inputoutputs or val in outputs:
                                    valid_val = False
                            if valid_val:
                                inputs.update(set(to_add))
                                break
                    else:
                        raise MicroprobeCodeGenerationError(
                            "Unknown operand type")

        # From here on the pools are ordered queues
        inputs = sorted(inputs)
        inputoutputs = sorted(inputoutputs)
        outputs = sorted(outputs)

        LOG.debug("-" * 80)
        LOG.debug("BEFORE ALLOCATING")
        LOG.debug("Inputs: %s", inputs)
        LOG.debug("InputsOutputs: %s", inputoutputs)
        LOG.debug("Outputs: %s", outputs)
        LOG.debug("Reserved %s", rregs)
        LOG.debug("-" * 80)

        # Phase 3: allocation
        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:
                LOG.debug("#" * 80)
                LOG.debug("BEGIN INSTRUCTION %s", instr.name)
                LOG.debug("Inputs: %s", inputs)
                LOG.debug("InputsOutputs: %s", inputoutputs)
                LOG.debug("Outputs: %s", outputs)
                LOG.debug("Reserved %s", rregs)

                for operand in instr.operands():
                    LOG.debug("-" * 80)
                    LOG.debug("BEGIN OPERAND")
                    LOG.debug(operand)
                    LOG.debug("Inputs: %s", inputs)
                    LOG.debug("InputsOutputs: %s", inputoutputs)
                    LOG.debug("Outputs: %s", outputs)
                    LOG.debug("Reserved %s", rregs)
                    LOG.debug("BEGIN OPERAND")

                    if operand.type.immediate:
                        LOG.debug("IMMEDIATE")
                        continue

                    if isinstance(operand.value, Address):
                        LOG.debug("ADDRESS")
                        continue

                    if operand.type.constant:
                        LOG.debug("CONSTANT")
                        value = list(operand.type.values())[0]
                        operand.set_value(value)
                        # Update queues
                        for value in operand.type.access(operand.value):
                            if value in inputoutputs:
                                inputoutputs.remove(value)
                                inputoutputs.append(value)
                            elif value in inputs:
                                inputs.remove(value)
                                if operand.is_output:
                                    inputoutputs.append(value)
                                else:
                                    inputs.append(value)
                            elif value in outputs:
                                outputs.remove(value)
                                if operand.is_input:
                                    inputoutputs.append(value)
                                else:
                                    outputs.append(value)
                        continue

                    if operand.is_input and operand.is_output:
                        queue = inputoutputs
                    elif operand.is_input:
                        queue = inputs
                    elif operand.is_output:
                        queue = outputs

                    if len(queue) == 0:
                        queue = inputoutputs

                    if operand.value is not None:
                        LOG.debug("ALREADY SET")
                        LOG.debug("Set value: %s", operand.value)
                        # Move the registers to the end of their queue
                        for value in operand.type.access(operand.value):
                            if value in rregs:
                                continue
                            if value in queue:
                                queue.remove(value)
                                queue.append(value)
                            elif value in inputs:
                                inputs.remove(value)
                                inputs.append(value)
                            elif value in outputs:
                                outputs.remove(value)
                                outputs.append(value)
                            elif value in inputoutputs:
                                inputoutputs.remove(value)
                                inputoutputs.append(value)
                        continue

                    allowed = list(operand.type.values())
                    # get the operand type
                    regt = allowed[0].type
                    LOG.debug("NOT SET")
                    LOG.debug("TYPE: %s", regt)

                    regs = [
                        reg for reg in queue
                        if reg.type == regt and reg in allowed and
                        reg not in rregs
                    ]

                    # if values are not found, fall back to the input output
                    # list
                    if len(regs) == 0:
                        LOG.debug("SWITCH!")
                        LOG.debug(queue)
                        LOG.debug(inputoutputs)
                        queue = inputoutputs
                        regs = [
                            reg for reg in queue
                            if reg.type == regt and reg in allowed
                        ]

                    # Check if all possible values are reserved, if so, just
                    # pick one and add a comment
                    if set(allowed).issubset(rregs) and len(regs) == 0:
                        operand.set_value(allowed[0])
                        LOG.debug("Operand forced to: %s", allowed[0])
                        instr.add_comment(
                            "Operand '%s' using reserved value" % operand)
                        for reg in operand.sets():
                            if not instr.allows(reg):
                                instr.add_allow_register([reg])
                        continue

                    if len(regs) == 0:
                        LOG.debug("Operand forced to: %s", allowed[0])
                        operand.set_value(allowed[0])
                        instr.add_comment(
                            "Operand '%s' using first value" % operand)
                        for reg in operand.sets():
                            if not instr.allows(reg):
                                instr.add_allow_register([reg])
                        continue

                    LOG.debug("Operand set to: %s", regs[0])
                    operand.set_value(regs[0])

                    # The chosen registers leave every pool and go to the
                    # end of the pool used for this operand
                    for value in operand.type.access(operand.value):
                        if value in inputoutputs:
                            inputoutputs.remove(value)
                        if value in inputs:
                            inputs.remove(value)
                        if value in outputs:
                            outputs.remove(value)
                        queue.append(value)

                LOG.debug("*" * 80)
class FixRegistersPass(microprobe.passes.Pass):
    """FixRegistersPass pass.

    Unsets the value of every operand that writes one of the forbidden
    written registers or reads one of the forbidden read registers
    (presumably so that a later allocation pass picks another value).
    """

    def __init__(self, forbid_writes=None, forbid_reads=None):
        """

        :param forbid_writes: registers (names or register objects) that
            must not be written (Default value = None)
        :param forbid_reads: registers (names or register objects) that
            must not be read (Default value = None)
        """
        super(FixRegistersPass, self).__init__()
        self._description = "Fix readed/written register"
        self._writes = forbid_writes
        self._reads = forbid_reads

    @staticmethod
    def _resolve(target, regs):
        """Map register names in *regs* to register objects of *target*.

        Non-string items are kept unchanged. Returns ``None`` if *regs* is
        ``None``.

        :raises KeyError: if a name is not a register of *target*
        """
        if regs is None:
            return None
        return [
            target.registers[reg] if isinstance(reg, str) else reg
            for reg in regs
        ]

    def __call__(self, building_block, target):
        """Unset the operands that access forbidden registers.

        :param building_block: building block whose operands are checked
        :param target: target used to resolve register names
        """
        # Resolve into locals. Overwriting ``self._writes`` (as before)
        # made a second invocation look register objects up by name.
        writes = self._resolve(target, self._writes)
        reads = self._resolve(target, self._reads)

        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:
                for operand in instr.operands():
                    if writes is not None and \
                            any(reg in writes for reg in operand.sets()):
                        operand.unset_value()
                        continue
                    # Reads must be checked against the registers the
                    # operand *uses* (this used ``sets()`` by mistake)
                    if reads is not None and \
                            any(reg in reads for reg in operand.uses()):
                        operand.unset_value()
class RandomAllocationPass(microprobe.passes.Pass):
    """RandomAllocationPass pass.

    Overwrites every operand of every instruction with a random value
    drawn from the operand type.
    """

    def __init__(self):
        """Initialize the pass and its description."""
        super().__init__()
        self._description = "Random Allocation of operands"

    def __call__(self, building_block, dummy_target):
        """Give every operand of *building_block* a random value.

        :param building_block: building block whose operands are randomized
        :param dummy_target: unused
        """
        # TODO: This is a POWERPC hack (remove in future): operands of type
        # "BO_Values" are re-drawn while they hold 17, 19 or 21.
        rejected_bo = (17, 19, 21)

        instructions = (
            instr
            for bbl in building_block.cfg.bbls
            for instr in bbl.instrs
        )
        for instr in instructions:
            for operand in instr.operands():
                operand.set_value(operand.type.random_value())
                if operand.type.name != "BO_Values":
                    continue
                while operand.value in rejected_bo:
                    operand.set_value(operand.type.random_value())