# Copyright 2011-2021 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""":mod:`microprobe.passes.register` module
"""
# Futures
from __future__ import absolute_import, print_function
# Built-in modules
import itertools
import random
import re
from typing import Union
# Third party modules
# Own modules
import microprobe.passes
import microprobe.utils.distrib
from microprobe.code.address import Address, MemoryValue
from microprobe.exceptions import \
MicroprobeCodeGenerationError, MicroprobeValueError
from microprobe.utils.logger import get_logger
from microprobe.utils.misc import OrderedDict
# Constants
# Module-wide logger, named after this module.
LOG = get_logger(__name__)

# Public API: the register allocation passes exported by this module.
__all__ = [
    'DefaultRegisterAllocationPass',
    'CycleMinimalAllocationPass',
    'FixRegistersPass',
    'NoHazardsAllocationPass',
    'RandomAllocationPass',
]
# Functions
# Classes
class DefaultRegisterAllocationPass(microprobe.passes.Pass):
    """Fallback allocation pass that gives a value to every unset operand.

    Immediate operands receive a fixed or a random value; register operands
    are chosen among the registers that do not touch a reserved register,
    optionally trying to honour a dependency distance between instructions.
    """

    def __init__(self,
                 minimize=False,
                 value=None,
                 dd: Union[int, float] = 0,
                 relax=False):
        """Create the pass.

        :param minimize: Not supported any more; any true value raises
            :class:`NotImplementedError`. (Default value = False)
        :param value: Value for unset immediate operands. ``None`` selects
            a random value per operand, ``"zero"`` selects ``0`` and any
            other value is used as is (callables are invoked until a
            non-callable result is obtained). (Default value = None)
        :param dd: Dependency distance. An ``int`` is used as a constant,
            a ``float`` is handed to
            ``microprobe.utils.distrib.discrete_average`` and a callable is
            used directly; the resulting callable is invoked once per
            instruction. (Default value = 0)
        :param relax: When true, do not fail if an instruction sets a
            reserved register it is not allowed to set.
            (Default value = False)
        :raises NotImplementedError: if *minimize* is true.
        :raises MicroprobeValueError: if *dd* is not an int, a float or a
            callable.
        """
        super(DefaultRegisterAllocationPass, self).__init__()
        self._min = minimize
        if self._min:
            raise NotImplementedError("Feature changed need to be"
                                      " reimplemented")

        if isinstance(dd, int):
            self._dd = lambda: dd  # Dependency distance
        elif isinstance(dd, float):
            self._dd = microprobe.utils.distrib.discrete_average(dd)
        elif callable(dd):
            self._dd = dd
        else:
            raise MicroprobeValueError("Invalid parameter")

        self._relax = relax

        if value is not None:
            self._immediate = value
            if value == "zero":
                self._immediate = 0
        else:
            self._immediate = "random"

        self._description = "Default register allocation (required always to" \
                            " set any remaining operand). minimize=%s, " \
                            "value=%s, dd=%s" % (minimize, value, dd)

    @staticmethod
    def _candidate_registers(operand, rregs):
        """Return ``(all_values, usable_values)`` for a register operand.

        A value is usable when none of the registers it accesses is
        reserved. A single-valued operand keeps its only value, and a
        read-only operand falls back to every value when nothing else is
        usable.
        """
        regs = list(operand.type.values())
        if len(regs) == 1:
            # There is not option
            valid_values = regs
        else:
            valid_values = [
                reg for reg in regs
                if len(rregs.intersection(operand.type.access(reg))) == 0
            ]

        if (len(valid_values) == 0 and operand.is_input
                and not operand.is_output):
            # Try to use reserved values if the operand is read only
            valid_values = list(operand.type.values())

        return regs, valid_values

    def _set_immediate(self, operand, instr):
        """Give an immediate operand the configured or a random value."""
        if self._immediate != "random":
            try:
                svalue = self._immediate
                while callable(svalue):
                    svalue = svalue()
                operand.set_value(svalue)
            except MicroprobeValueError:
                LOG.warning(
                    "Operand '%s' in instruction '%s' "
                    "not modeled properly. Tried to "
                    "'%s' value ...", operand.type, instr.name,
                    svalue)
                # Fall back to the first value the operand type offers
                value = list(operand.type.values())[0]
                operand.set_value(value)
                LOG.warning("Operand set to: '%s'", value)
        else:
            operand.set_value(operand.type.random_value())

    def __call__(self, building_block, target):
        """Set every operand that still has no value.

        :param building_block: Building block whose instructions are
            completed in place.
        :param target: Target providing the register set.
        :raises MicroprobeCodeGenerationError: if no register can be found
            for an operand, or if an instruction sets a reserved register
            it is not allowed to set and the pass is not relaxed.
        """
        allregs = target.registers

        # Per register type: register -> index of the instruction that last
        # defined / last used it. Entries are deleted and re-inserted on
        # update so the insertion order also reflects recency.
        lastdefined = {}
        lastused = {}

        rregs = set(building_block.context.reserved_registers)

        # TODO: All this pass has to be reimplemented
        for reg in allregs.values():
            if reg in rregs:
                # TODO: rregs can be used but only as input, now we discard
                # them for input and output
                continue

            if reg.type not in lastused:
                lastused[reg.type] = OrderedDict()
                lastdefined[reg.type] = OrderedDict()

            lastdefined[reg.type][reg] = 0
            lastused[reg.type][reg] = 1

        idx = 1
        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:

                # At most one operand per instruction is chosen by distance
                dependency_ok = False
                used = []
                defined = []
                distance = self._dd()

                for operand in instr.operands():

                    if operand.value is not None:
                        # operand already set
                        LOG.debug("Operand already set")
                    elif operand.type.immediate:
                        LOG.debug("Operand immediate")
                        self._set_immediate(operand, instr)
                    elif operand.type.address_relative:
                        LOG.warning(
                            "Operand '%s' in instruction '%s' "
                            "not modeled properly", operand.type, instr.name)
                        operand.set_value(list(operand.type.values())[0])
                    else:
                        LOG.debug("Operand is register")

                        regs, valid_values = self._candidate_registers(
                            operand, rregs)

                        # Both selection strategies below need a non-empty
                        # candidate list and tracking tables for the
                        # register type, so validate before branching.
                        if len(valid_values) == 0:
                            LOG.critical("Instruction: %s", instr)
                            LOG.critical("Operand: %s", operand)
                            LOG.critical("Possible all values: %s",
                                         sorted(operand.type.values()))
                            LOG.critical("Possible values: %s",
                                         sorted(regs))
                            LOG.critical("Reserved values: %s",
                                         sorted(rregs))
                            raise MicroprobeCodeGenerationError(
                                "Unable to find proper operand"
                                " values for register allocation.")

                        rtype = regs[0].type
                        if rtype not in lastused:
                            raise MicroprobeCodeGenerationError(
                                "Unable to find proper operand"
                                " values for register allocation."
                                " No registers of type '%s' "
                                "available." % rtype)

                        if operand.is_input and distance > 0 and \
                                not dependency_ok and idx > distance:

                            LOG.debug("Setting dependency distance")
                            LOG.debug("Current idx: %d", idx)
                            LOG.debug("Requested idx: %d", idx - distance)

                            reg = microprobe.utils.distrib.sort_by_distance(
                                valid_values, lastdefined[rtype],
                                lastused[rtype], idx - distance, instr,
                                idx)

                            LOG.debug("%s selected", reg)

                            if reg in lastdefined[rtype]:
                                LOG.debug("Last defined: %s",
                                          lastdefined[rtype][reg])

                            if reg in lastused[rtype]:
                                LOG.debug("Last used: %s",
                                          lastused[rtype][reg])

                            dependency_ok = True
                        else:
                            LOG.debug("No need for distance")
                            reg = microprobe.utils.distrib.sort_by_usage(
                                valid_values, lastused[rtype],
                                lastdefined[rtype])

                        operand.set_value(reg)

                    if self._min is False:

                        for reg in operand.uses():
                            LOG.debug("Updating usage of %s", reg)
                            LOG.debug(list(lastused[reg.type].keys()))
                            if reg in lastused[reg.type]:
                                del lastused[reg.type][reg]
                            lastused[reg.type][reg] = idx
                            used.append(reg)
                            LOG.debug(list(lastused[reg.type].keys()))

                        if operand.is_output:
                            for reg in operand.sets():
                                LOG.debug("Updating definition of %s", reg)
                                if reg in lastdefined[reg.type]:
                                    del lastdefined[reg.type][reg]
                                lastdefined[reg.type][reg] = idx
                                defined.append(reg)

                                LOG.debug("Updating usage of %s", reg)
                                LOG.debug(list(lastused[reg.type].keys()))
                                if reg in lastused[reg.type]:
                                    del lastused[reg.type][reg]
                                lastused[reg.type][reg] = idx
                                used.append(reg)
                                LOG.debug(list(lastused[reg.type].keys()))

                fail = rregs.intersection(set(instr.sets()))
                if len(fail) > 0 and not self._relax:
                    for reg in fail:
                        if not instr.allows(reg):
                            raise MicroprobeCodeGenerationError(
                                "Instruction '%s' sets a reserved"
                                " register but not allowed before:"
                                "Register: %s. Reserved: %s" %
                                (instr.name, reg.name,
                                 sorted([rreg.name for rreg in rregs])))

                # Registers touched by the instruction as a whole but not
                # through one of the operands handled above.
                for reg in instr.uses():
                    if reg not in used and reg not in rregs:
                        LOG.debug("Updating usage of %s", reg)
                        # NOTE(review): unlike the other updates, the entry
                        # is not deleted first, so its position in the
                        # ordered table is kept -- confirm this is intended.
                        lastused[reg.type][reg] = idx

                for reg in instr.sets():
                    if reg not in defined and reg not in rregs:
                        LOG.debug("Updating definition of %s", reg)
                        if reg in lastdefined[reg.type]:
                            del lastdefined[reg.type][reg]
                        lastdefined[reg.type][reg] = idx

                LOG.debug("Instruction: %s", instr)

                LOG.debug("Last used rank:")
                for rtype in lastused:
                    for reg in lastused[rtype]:
                        if lastused[rtype][reg] > 0 and reg.name.startswith(
                                "GPR"):
                            LOG.debug("%s: %d", reg.name, lastused[rtype][reg])

                LOG.debug("Last defined rank:")
                for rtype in lastdefined:
                    for reg in lastdefined[rtype]:
                        if lastdefined[rtype][reg] > 0 and reg.name.startswith(
                                "GPR"):
                            LOG.debug("%s: %d", reg.name,
                                      lastdefined[rtype][reg])

                idx = idx + 1
class CycleMinimalAllocationPass(microprobe.passes.Pass):
    """CycleMinimalAllocationPass pass.

    Allocates operands for consecutive groups of instructions, limiting how
    many times the same register is read and/or written within a group.
    """

    def __init__(self, size, reads, writes, value=None):
        """Create the pass.

        :param size: Number of instructions per group.
        :param reads: Maximum reads of the same register within a group.
        :param writes: Maximum writes of the same register within a group.
        :param value: Value for immediate operands; ``None`` selects a
            random value. (Default value = None)
        """
        super(CycleMinimalAllocationPass, self).__init__()
        self._size = size  # groups size
        self._reads = reads  # max reads in a group to the same register
        self._writes = writes  # max writes in a group to the same register
        # Limit for registers that are read and written by one operand
        self._rdwr = min(reads, writes)
        self._immediate = "random" if value is None else value

    def __call__(self, building_block, target):
        """Allocate the operands of *building_block* group by group.

        :param building_block: Building block to process in place.
        :param target: Target providing the register set.
        :return: An empty list.
        """
        # Per register type: register -> access count within the current
        # group. A register is dropped from a table once it may no longer
        # be chosen for that kind of access.
        lastdefined = {}
        lastreaded = {}
        lastdefread = {}

        # Registers ordered by the first number found in their textual
        # form; the order does not change between groups, so sort once.
        allregs = list(target.registers.values())
        allregs.sort(key=lambda x: int(re.findall(r'\d+', str(x))[0]))

        def reset_dictionaries():
            """ Reset the local dependency tracking dictionaries """
            for table in (lastdefined, lastreaded, lastdefread):
                table.clear()

            reserved = building_block.context.reserved_registers
            for reg in allregs:
                if reg.type not in lastreaded:
                    lastreaded[reg.type] = OrderedDict()
                if reg.type not in lastdefined:
                    lastdefined[reg.type] = OrderedDict()
                if reg.type not in lastdefread:
                    lastdefread[reg.type] = OrderedDict()

                lastreaded[reg.type][reg] = 0
                # Reserved registers may be read but never written
                if reg not in reserved:
                    lastdefined[reg.type][reg] = 0
                    lastdefread[reg.type][reg] = 0

        def remove_dict_key(reg):
            """Drop *reg* from every tracking table.

            :param reg: Register to drop.
            """
            for table in (lastdefined, lastreaded, lastdefread):
                if reg in table[reg.type]:
                    del table[reg.type][reg]

        def update_dictionaries(instr):
            """Account for the registers already set in *instr*.

            :param instr: Instruction whose operand fields are inspected.
            """
            for reg, arch_operands in zip(instr.operand_fields(),
                                          instr.architecture_type.operands()):
                operand, ins_input, output = arch_operands

                if reg is None or operand.immediate():
                    continue

                rtype = reg.type
                if ins_input and output:
                    # The register may be untracked (reserved or already
                    # exhausted), so guard the counter update.
                    if reg in lastdefread[rtype]:
                        lastdefread[rtype][reg] += 1
                        if lastdefread[rtype][reg] == self._rdwr:
                            remove_dict_key(reg)
                    if reg in lastdefined[rtype]:
                        del lastdefined[rtype][reg]
                    if reg in lastreaded[rtype]:
                        del lastreaded[rtype][reg]
                elif output:
                    if reg in lastdefined[rtype]:
                        lastdefined[rtype][reg] += 1
                        if lastdefined[rtype][reg] == self._writes:
                            remove_dict_key(reg)
                    if reg in lastdefread[rtype]:
                        del lastdefread[rtype][reg]
                    if reg in lastreaded[rtype]:
                        del lastreaded[rtype][reg]
                elif ins_input:
                    if reg in lastreaded[rtype]:
                        lastreaded[rtype][reg] += 1
                        if lastreaded[rtype][reg] == self._reads:
                            remove_dict_key(reg)
                    if reg in lastdefined[rtype]:
                        del lastdefined[rtype][reg]
                    if reg in lastdefread[rtype]:
                        del lastdefread[rtype][reg]

        def pick(table, candidates, position):
            """Return the *position*-th most accessed candidate in *table*."""
            ranked = [elem for elem in table.keys() if elem in candidates]
            ranked.sort(key=lambda x: table[x], reverse=True)
            return ranked[position]

        def update_group(group):
            """Set the operands of every instruction in *group*.

            :param group: Instructions forming the group.
            """
            for instr in group:
                instroperands = []
                # Position taken by the next operand of each kind
                oper_idx_io = 0
                oper_idx_i = 0
                oper_idx_o = 0

                for operand, ins_input, output in instr.operands():

                    if operand.immediate():
                        if self._immediate == "random":
                            instroperands.append(
                                operand.assembly(operand.random_value()))
                        elif operand.check(self._immediate, safe=True):
                            instroperands.append(self._immediate)
                        else:
                            # Requested value does not fit this operand
                            instroperands.append(operand.max)
                        continue

                    regs = list(operand.values())
                    rtype = regs[0].type

                    if ins_input and output:
                        reg = pick(lastdefread[rtype], regs, oper_idx_io)
                        if reg in lastreaded[rtype]:
                            del lastreaded[rtype][reg]
                        if reg in lastdefined[rtype]:
                            del lastdefined[rtype][reg]
                        oper_idx_io += 1
                    elif ins_input:
                        reg = pick(lastreaded[rtype], regs, oper_idx_i)
                        if reg in lastdefread[rtype]:
                            del lastdefread[rtype][reg]
                        if reg in lastdefined[rtype]:
                            del lastdefined[rtype][reg]
                        oper_idx_i += 1
                    elif output:
                        reg = pick(lastdefined[rtype], regs, oper_idx_o)
                        if reg in lastdefread[rtype]:
                            del lastdefread[rtype][reg]
                        if reg in lastreaded[rtype]:
                            del lastreaded[rtype][reg]
                        oper_idx_o += 1
                    else:
                        # Explicit raise: a bare assert disappears under -O
                        raise AssertionError("Something wrong")

                    instroperands.append(reg)

                instr.set_operands(instroperands)
                instr.check()
                update_dictionaries(instr)

        idx = 0
        group = []
        reset_dictionaries()

        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:
                update_dictionaries(instr)
                group.append(instr)

                # Group complete: allocate it and start a fresh one
                if idx % self._size == (self._size - 1):
                    update_group(group)
                    group = []
                    reset_dictionaries()

                idx = idx + 1

        # Allocate the trailing, incomplete group (if any)
        if len(group) > 0:
            update_group(group)
            group = []
            reset_dictionaries()

        return []
class NoHazardsAllocationPass(microprobe.passes.Pass):
    """Avoid all possible data hazards:

    read after write (RAW), a true dependency
    write after read (WAR), an anti-dependency
    write after write (WAW), an output dependency
    """

    def __init__(self):
        """Create the pass."""
        super(NoHazardsAllocationPass, self).__init__()
        self._description = "Perform register allocation avoiding RAW, WAR "\
                            "and WAW dependency hazards"

    def __call__(self, building_block, target):
        """Allocate registers keeping inputs and outputs in separate pools.

        The pass works in three steps: it records how the already set
        operands use registers, it reserves registers for the operands that
        still have no value, and it finally assigns the values, rotating
        each pool so that the least recently chosen register goes first.

        :param building_block: Building block to process in place.
        :param target: Target description (not used by this pass).
        :raises MicroprobeCodeGenerationError: if an unset operand is
            neither an input nor an output.
        """
        rregs = set(building_block.context.reserved_registers)
        inputs = set()
        inputoutputs = set()
        outputs = set()
        usedrregs = set()

        LOG.debug("-" * 80)
        LOG.debug("BEGIN: Initial State")
        LOG.debug("Inputs: %s", inputs)
        LOG.debug("InputsOutputs: %s", inputoutputs)
        LOG.debug("Outputs: %s", outputs)
        LOG.debug("Reserved %s", rregs)
        LOG.debug("Used Reserved %s", usedrregs)
        LOG.debug("-" * 80)

        # Step 1: classify the registers used by operands already set
        bbls = building_block.cfg.bbls
        for bbl in bbls:
            # Work on a copy: extending the list returned by bbl.instrs in
            # place could add the fini instructions to the basic block.
            allinstrs = list(bbl.instrs)
            if bbl == bbls[-1]:
                allinstrs += list(building_block.fini)

            for instr in allinstrs:
                for operand in instr.operands():

                    if operand.type.immediate:
                        continue

                    if isinstance(operand.value, Address):
                        continue

                    choices = list(operand.type.values())
                    if operand.value is None and len(choices) > 1:
                        continue

                    if len(choices) == 1:
                        # Only one possible value: set it right away
                        operand.set_value(choices[0])

                    # NOTE(review): the test looks at operand.value while
                    # the register recorded is the first one reported by
                    # uses()/sets(); confirm both always match.
                    skip = False
                    for reg in operand.uses():
                        if operand.value in rregs:
                            skip = True
                            usedrregs.add(reg)
                            break

                    for reg in operand.sets():
                        if operand.value in rregs:
                            skip = True
                            usedrregs.add(reg)
                            break

                    if skip:
                        continue

                    if operand.is_input and operand.is_output:
                        for reg in operand.sets():
                            inputoutputs.add(reg)
                        for reg in operand.uses():
                            inputoutputs.add(reg)
                    elif operand.is_input:
                        for reg in operand.uses():
                            inputs.add(reg)
                    elif operand.is_output:
                        for reg in operand.sets():
                            outputs.add(reg)

        LOG.debug("-" * 80)
        LOG.debug("REGUSTER USAGE Before allocating")
        LOG.debug("Inputs: %s", inputs)
        LOG.debug("InputsOutputs: %s", inputoutputs)
        LOG.debug("Outputs: %s", outputs)
        LOG.debug("Reserved %s", rregs)
        LOG.debug("Used Reserved %s", usedrregs)
        LOG.debug("-" * 80)

        #
        # All reserved registers not used in the main loop
        # can be unreserved
        #
        rregs = rregs.intersection(usedrregs)
        LOG.debug("New Reserved %s", rregs)

        # Iterate over a snapshot: registers are removed inside the loop
        for reg in list(building_block.context.reserved_registers):
            if reg not in rregs:
                building_block.context.remove_reserved_registers([reg])

        # Step 2: minimum allocation, make everybody happy with enough
        # operand values
        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:
                for operand in instr.operands():

                    if operand.type.immediate:
                        continue

                    if operand.value is not None:
                        continue

                    # Values that do not touch any reserved register
                    values = set()
                    for reg in set(operand.type.values()):
                        if len(rregs.intersection(
                                operand.type.access(reg))) > 0:
                            continue
                        values.add(reg)

                    if operand.is_input and operand.is_output:

                        values.difference_update(inputs)
                        values.difference_update(outputs)

                        # NOTE(review): unlike the two branches below there
                        # is no break here, so every valid value is added.
                        for value in values:
                            valid_val = True
                            to_add = []
                            for val in operand.type.access(value):
                                to_add.append(val)
                                if val in inputs or val in outputs:
                                    valid_val = False
                            if valid_val:
                                inputoutputs.update(set(to_add))

                    elif operand.is_output:

                        values.difference_update(inputoutputs)
                        values.difference_update(inputs)
                        values.difference_update(outputs)

                        for value in values:
                            valid_val = True
                            to_add = []
                            for val in operand.type.access(value):
                                to_add.append(val)
                                if (val in inputoutputs or val in inputs
                                        or val in outputs):
                                    valid_val = False
                            if valid_val:
                                outputs.update(set(to_add))
                                break

                    elif operand.is_input:

                        values.difference_update(inputoutputs)
                        values.difference_update(outputs)

                        for value in values:
                            valid_val = True
                            to_add = []
                            for val in operand.type.access(value):
                                to_add.append(val)
                                if val in inputoutputs or val in outputs:
                                    valid_val = False
                            if valid_val:
                                inputs.update(set(to_add))
                                break

                    else:
                        raise MicroprobeCodeGenerationError(
                            "Unknown operand type")

        # From here on the pools are ordered queues: the register at the
        # front is the next one to be chosen.
        inputs = sorted(inputs)
        inputoutputs = sorted(inputoutputs)
        outputs = sorted(outputs)

        LOG.debug("-" * 80)
        LOG.debug("BEFORE ALLOCATING")
        LOG.debug("Inputs: %s", inputs)
        LOG.debug("InputsOutputs: %s", inputoutputs)
        LOG.debug("Outputs: %s", outputs)
        LOG.debug("Reserved %s", rregs)
        LOG.debug("-" * 80)

        # Step 3: assign the values
        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:

                LOG.debug("#" * 80)
                LOG.debug("BEGIN INSTRUCTION %s", instr.name)
                LOG.debug("Inputs: %s", inputs)
                LOG.debug("InputsOutputs: %s", inputoutputs)
                LOG.debug("Outputs: %s", outputs)
                LOG.debug("Reserved %s", rregs)

                for operand in instr.operands():

                    LOG.debug("-" * 80)
                    LOG.debug("BEGIN OPERAND")
                    LOG.debug(operand)
                    LOG.debug("Inputs: %s", inputs)
                    LOG.debug("InputsOutputs: %s", inputoutputs)
                    LOG.debug("Outputs: %s", outputs)
                    LOG.debug("Reserved %s", rregs)
                    LOG.debug("BEGIN OPERAND")

                    if operand.type.immediate:
                        LOG.debug("IMMEDIATE")
                        continue

                    if isinstance(operand.value, Address):
                        LOG.debug("ADDRESS")
                        continue

                    if operand.type.constant:
                        LOG.debug("CONSTANT")
                        value = list(operand.type.values())[0]
                        operand.set_value(value)

                        # Update queues
                        for value in operand.type.access(operand.value):
                            if value in inputoutputs:
                                inputoutputs.remove(value)
                                inputoutputs.append(value)
                            elif value in inputs:
                                if operand.is_output:
                                    inputs.remove(value)
                                    inputoutputs.append(value)
                                else:
                                    inputs.remove(value)
                                    inputs.append(value)
                            elif value in outputs:
                                if operand.is_input:
                                    outputs.remove(value)
                                    inputoutputs.append(value)
                                else:
                                    outputs.remove(value)
                                    outputs.append(value)
                        continue

                    if operand.is_input and operand.is_output:
                        queue = inputoutputs
                    elif operand.is_input:
                        queue = inputs
                    elif operand.is_output:
                        queue = outputs
                    else:
                        # Neither input nor output: use the general pool
                        # instead of an unbound or stale queue.
                        queue = inputoutputs

                    if len(queue) == 0:
                        queue = inputoutputs

                    if operand.value is not None:
                        LOG.debug("ALREADY SET")
                        LOG.debug("Set value: %s", operand.value)
                        # Send the registers in use to the back of the
                        # queue that holds them
                        for value in operand.type.access(operand.value):
                            if value not in rregs and value in queue:
                                queue.remove(value)
                                queue.append(value)
                            elif value not in rregs and \
                                    value in inputs:
                                inputs.remove(value)
                                inputs.append(value)
                            elif value not in rregs and \
                                    value in outputs:
                                outputs.remove(value)
                                outputs.append(value)
                            elif value not in rregs and \
                                    value in inputoutputs:
                                inputoutputs.remove(value)
                                inputoutputs.append(value)
                        continue

                    # get the operand type
                    candidates = list(operand.type.values())
                    regt = candidates[0].type

                    LOG.debug("NOT SET")
                    LOG.debug("TYPE: %s", regt)

                    regs = [
                        reg for reg in queue if reg.type == regt
                        and reg in candidates and reg not in rregs
                    ]

                    # if values are not found, fail-back to input output list
                    if len(regs) == 0:
                        LOG.debug("SWITCH!")
                        LOG.debug(queue)
                        LOG.debug(inputoutputs)
                        queue = inputoutputs
                        regs = [
                            reg for reg in queue
                            if reg.type == regt and reg in candidates
                        ]

                    if len(regs) == 0:
                        # Nothing available: force the first value and
                        # explain why in a comment on the instruction.
                        if set(candidates).issubset(rregs):
                            # All possible values are reserved
                            operand.set_value(candidates[0])
                            LOG.debug("Operand forced to: %s",
                                      candidates[0])
                            instr.add_comment(
                                "Operand '%s' using reserved value" %
                                operand)
                        else:
                            LOG.debug("Operand forced to: %s",
                                      candidates[0])
                            operand.set_value(candidates[0])
                            instr.add_comment(
                                "Operand '%s' using first value" %
                                operand)

                        for reg in operand.sets():
                            if not instr.allows(reg):
                                instr.add_allow_register([reg])
                        continue

                    LOG.debug("Operand set to: %s", regs[0])
                    operand.set_value(regs[0])

                    # Move the registers just taken to the back of the
                    # queue they were taken from
                    for value in operand.type.access(operand.value):
                        if value in inputoutputs:
                            inputoutputs.remove(value)
                        if value in inputs:
                            inputs.remove(value)
                        if value in outputs:
                            outputs.remove(value)
                        queue.append(value)

                LOG.debug("*" * 80)
class FixRegistersPass(microprobe.passes.Pass):
    """FixRegistersPass pass.

    Unsets the value of every operand that writes or reads one of the
    forbidden registers.
    """

    def __init__(self, forbid_writes=None, forbid_reads=None):
        """Create the pass.

        :param forbid_writes: Registers that must not be written, given as
            names looked up in ``target.registers``. (Default value = None)
        :param forbid_reads: Registers that must not be read, given as
            names looked up in ``target.registers``. (Default value = None)
        """
        super(FixRegistersPass, self).__init__()
        self._description = "Fix readed/written register"
        self._writes = forbid_writes
        self._reads = forbid_reads

    @staticmethod
    def _resolve(registers, target):
        """Map register names to the register objects of *target*.

        Items that are not strings are kept untouched and ``None`` is
        returned as is.
        """
        if registers is None:
            return None
        return [
            target.registers[reg] if isinstance(reg, str) else reg
            for reg in registers
        ]

    def __call__(self, building_block, target):
        """Unset the operands that touch a forbidden register.

        :param building_block: Building block to process in place.
        :param target: Target used to resolve the register names.
        """
        # Resolve into locals so the pass can run more than once
        writes = self._resolve(self._writes, target)
        reads = self._resolve(self._reads, target)

        if writes is None and reads is None:
            return

        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:
                for operand in instr.operands():
                    # Written registers are checked against the forbidden
                    # writes and read registers against the forbidden
                    # reads; the operand is unset at most once.
                    forbidden = (
                        writes is not None
                        and any(elem in writes for elem in operand.sets())
                    ) or (
                        reads is not None
                        and any(elem in reads for elem in operand.uses())
                    )
                    if forbidden:
                        operand.unset_value()
class RandomAllocationPass(microprobe.passes.Pass):
    """RandomAllocationPass pass.

    Gives a random value to every operand of every instruction.
    """

    def __init__(self):
        """Create the pass."""
        super(RandomAllocationPass, self).__init__()
        self._description = "Random Allocation of operands"

    def __call__(self, building_block, dummy_target):
        """Set every operand to a random value of its type.

        :param building_block: Building block to process in place.
        :param dummy_target: Target description (not used).
        """
        for bbl in building_block.cfg.bbls:
            for instr in bbl.instrs:
                for operand in instr.operands():
                    operand.set_value(operand.type.random_value())

                    # TODO: This is a POWERPC hack (remove in future)
                    # Draw again while one of the excluded values comes up
                    if operand.type.name == "BO_Values":
                        while operand.value in [17, 19, 21]:
                            operand.set_value(operand.type.random_value())