# Source code for qsynthesis.tables.base
# built-in libs
from __future__ import annotations
from pathlib import Path
from enum import IntEnum
import array
import hashlib
import threading
from collections import Counter
from time import time, sleep
import ctypes
import logging
# third-party libs
import psutil
# qsynthesis deps
from qsynthesis.grammar import TritonGrammar
from qsynthesis.tritonast import TritonAst
from qsynthesis.types import AstNode, Hash, Optional, List, Dict, Union, Tuple, Iterable, Input, Output, BitSize, Any, \
Generator
# Module-level logger: all messages emitted in this file go through the "qsynthesis" logger
logger = logging.getLogger("qsynthesis")
class _EvalCtx:
    """
    Small debugging Triton evaluation context. It is used when manipulating
    tables in a standalone manner. It enables obtaining TritonAst out of
    the database entries.
    """

    def __init__(self, grammar):
        """
        Build a fresh Triton context holding one symbolic variable per grammar variable.

        :param grammar: object exposing ``vars_dict``, a mapping of variable name to size
        """
        # Imported locally: triton is only needed once a context is really instantiated
        from triton import TritonContext, ARCH, AST_REPRESENTATION

        # Create the context
        self.ctx = TritonContext(ARCH.X86_64)
        self.ctx.setAstRepresentationMode(AST_REPRESENTATION.PYTHON)
        self.ast = self.ctx.getAstContext()

        # Create symbolic variables for grammar variables
        self.symvars = {}
        self.vars = {}
        for v, sz in grammar.vars_dict.items():
            sym_v = self.ctx.newSymbolicVariable(sz, v)
            self.symvars[v] = sym_v
            self.vars[v] = self.ast.variable(sym_v)

        # Create mapping to triton operators (every non-dunder attribute of the AST context)
        self.tops = {x: getattr(self.ast, x) for x in dir(self.ast) if not x.startswith("__")}

    def eval_str(self, s: str) -> AstNode:
        """
        Eval the string expression to create an AstNode object.

        Operators are resolved through ``self.tops`` (globals) and variables
        through ``self.vars`` (locals).

        :param s: expression string to evaluate
        :returns: the evaluated node; a plain integer result is wrapped in a
                  bitvector constant having the size of the first variable
        """
        # NOTE(review): eval() runs arbitrary code, only feed it expressions coming
        # from trusted tables.
        e = eval(s, self.tops, self.vars)
        if isinstance(e, int):  # In case the expression was in fact an int
            # The size is only looked up when needed, so that an expression which is
            # already a node does not fail on a grammar without variables.
            bv_size = list(self.vars.values())[0].getBitvectorSize()  # Assume all vars are of same size
            return self.ast.bv(e, bv_size)
        return e

    def set_symvar_values(self, args: Input) -> None:
        """
        Assign a concrete value to symbolic variables of the context.

        :param args: mapping of variable name to the concrete value to set
        :returns: None
        """
        for v_name, value in args.items():
            self.ctx.setConcreteVariableValue(self.symvars[v_name], value)
# [docs] (Sphinx viewcode link artifact)
class InputOutputOracle:
    """
    Base Lookup table class. Specify the interface that child classes have to
    implement to be interoperable with the synthesizer.
    """

    def __init__(self, gr: TritonGrammar, inputs: List[Input], f_name: Union[Path, str] = ""):
        """
        Constructor making an I/O oracle from a grammar and a set of inputs.

        :param gr: triton grammar
        :param inputs: List of inputs
        :param f_name: file name of the table (when being loaded)
        """
        self._name = Path(f_name)
        self.grammar = gr
        self._bitsize = self.grammar.size
        self.expr_cache = {}   # hash -> expression already built by lookup()
        self.lookup_count = 0  # number of calls to lookup()
        self.lookup_found = 0  # number of lookups for which the table had an entry
        self.cache_hit = 0     # number of lookups served from expr_cache
        self._ectx = None      # lazily created _EvalCtx

        # generation related fields
        self.watchdog = None   # memory watchdog thread (if any)
        self.max_mem = 0       # highest used memory observed by the watchdog
        self.stop = False      # flag shared between the watchdog and generate()

        self.inputs = inputs

    @property
    def size(self) -> int:
        """Size of the table (number of entries)

        :rtype: int
        """
        raise NotImplementedError("Should be implemented by child class")

    def _get_item(self, h: Hash) -> Optional[str]:
        """
        From a given hash return the associated expression string if
        found in the lookup table.

        :param h: hash of the item to get
        :returns: raw expression string if found
        """
        raise NotImplementedError("Should be implemented by child class")

    def is_expr_compatible(self, expr: TritonAst) -> bool:
        """
        Check the compatibility of the given expression with the table.
        The function checks sizes of expr variables against the ones of
        its own grammar: for each variable size, the expression must not
        use more variables than the grammar provides.

        :param expr: TritonAst expression to check
        :return: True if the table can decide on this expression
        """
        e_vars = Counter(x.getBitSize() for x in expr.symvars)
        e_table = Counter(self.grammar.vars_dict.values())
        # A size unknown to the grammar counts as 0 available variables
        return all(count <= e_table.get(sz, 0) for sz, count in e_vars.items())

    def lookup(self, outputs: List[Output], *args, use_cache: bool = True) -> Optional[TritonAst]:
        """
        Perform a lookup in the table with a given set of outputs corresponding
        to the evaluation of an AST against the Input of this exact same table.
        If an entry is found a TritonAst is created and returned.

        :param outputs: list of output result of evaluating an ast against the inputs of this table
        :type: List[:py:obj:`qsynthesis.types.Output`]
        :param args: args forwarded to grammar and ultimately to the tritonAst in charge of building a new TritonAst
        :param use_cache: Boolean enabling caching the hash of outputs. A second call with the same outputs
                          (which is common) will not trigger a lookup in the database
        :returns: optional TritonAst corresponding to the expression found in the table
        """
        self.lookup_count += 1
        h = self.hash(outputs)

        if use_cache and h in self.expr_cache:
            self.cache_hit += 1
            return self.expr_cache[h]

        v = self._get_item(h)
        if not v:
            return None

        self.lookup_found += 1
        try:
            e = self.grammar.str_to_expr(v, *args)
        except (NameError, TypeError):
            # The entry exists but cannot be turned into an expression: report a miss
            return None
        # The cache is filled even when use_cache is False (only reads are bypassed)
        self.expr_cache[h] = e
        return e

    def lookup_hash(self, h: Hash) -> Optional[str]:
        """
        Raw lookup for a given key in database.

        :param h: hash key to look for in database
        :type h: :py:obj:`qsynthesis.types.Hash`
        :returns: string of the expression if found
        :rtype: Optional[str]
        """
        return self._get_item(h)

    @property
    def is_writable(self) -> bool:
        """ Whether the table enables being written (with new expressions)

        :rtype: bool
        """
        return False

    @property
    def name(self) -> str:
        """ Name of the table

        :rtype: str
        """
        return str(self._name)

    @property
    def bitsize(self) -> BitSize:
        """ Size of expression in bit

        :rtype: :py:obj:`qsynthesis.types.BitSize`
        """
        return self._bitsize

    @property
    def var_number(self) -> int:
        """ Maximum number of variables contained in the table

        :rtype: int
        """
        return len(self.grammar.vars)

    @property
    def operator_number(self) -> int:
        """ Number of operators used in this table

        :rtype: int
        """
        return len(self.grammar.ops)

    @property
    def input_number(self) -> int:
        """ Number of inputs used in this table

        :rtype: int
        """
        return len(self.inputs)

    def hash(self, outs: List[Output]) -> Hash:
        """
        Main hashing method that converts outputs to a hash.
        The hash used is MD5. Note that hashed values are systematically
        cast in an array of 64bit unsigned integers.

        :param outs: list of outputs to hash
        :type outs: List[:py:obj:`qsynthesis.types.Output`]
        :returns: MD5 digest (bytes) of the outputs
        :rtype: :py:obj:`qsynthesis.types.Hash`
        """
        a = array.array('Q', outs)
        return hashlib.md5(a.tobytes()).digest()

    def __iter__(self) -> Iterable[Tuple[Hash, str]]:
        """ Iterator of all the entries as an iterator of pair, hash, expression as string

        :rtype: Iterable[Tuple[:py:obj:`qsynthesis.types.Hash, str]]`
        """
        raise NotImplementedError("Should be implemented by child class")

    def _get_expr(self, expr: str) -> AstNode:
        """
        Utility function that returns a TritonAst from a given expression string.
        A TritonContext local to the table is created to enable generating such ASTs.

        :param expr: Expression
        :returns: TritonAst resulting of the parsing of expr
        """
        if self._ectx is None:
            self._ectx = _EvalCtx(self.grammar)
        return self._ectx.eval_str(expr)

    def _set_input_lcontext(self, i: Union[int, Input]) -> None:
        """
        Set the given concrete values of variables in the local TritonContext.
        The parameter is either the ith input of the table, or directly an Input
        giving a valuation for each variable. This function must be called before
        performing any evaluation of an AST.

        :param i: index of the input, or Input object (dict)
        :returns: None
        """
        if self._ectx is None:
            self._ectx = _EvalCtx(self.grammar)
        self._ectx.set_symvar_values(self.inputs[i] if isinstance(i, int) else i)

    def _eval_expr_inputs(self, expr: AstNode) -> List[Output]:
        """
        Evaluate a given Triton AstNode object on all inputs of the
        table. The result is a list of Output values.

        :param expr: Triton AstNode to evaluate
        :type expr: :py:obj:`qsynthesis.types.AstNode`
        :returns: list of output values (ready to be hashed)
        :rtype: List[:py:obj:`qsynthesis.types.Output`]
        """
        outs = []
        for i in range(len(self.inputs)):
            # The local context must hold the i-th input before evaluating
            self._set_input_lcontext(i)
            outs.append(expr.evaluate())
        return outs

    def _watchdog_worker(self, threshold: Union[float, int]) -> None:
        """
        Function where the memory watchdog thread is running. This function
        allows interrupting table generation when it happens to fill the
        given threshold of RAM.

        :param threshold: percentage of RAM load that triggers the stop of generation
        """
        while not self.stop:
            sleep(2)
            mem = psutil.virtual_memory()
            self.max_mem = max(mem.used, self.max_mem)
            if mem.percent >= threshold:
                logger.warning(f"Threshold reached: {mem.percent}%")
                self.stop = True  # Should stop self and also main thread

    @staticmethod
    def _try_linearize(s: str, symbols: Dict[str, object]) -> str:
        """
        Try applying sympy to linearize ``s`` with the variable symbols
        ``symbols``. If a TypeError or AttributeError is raised in between,
        the expression string is returned unchanged.

        :param s: expression string to linearize
        :param symbols: dictionary of variable names to sympy symbol objects
                        (used as globals of ``eval``, which adds ``__builtins__`` to it)
        :returns: the linearized expression without spaces, or ``s`` on failure

        .. warning:: This function requires sympy to be installed !
        """
        # Imported from its defining module: reaching it through a top-level
        # ``sympy.boolalg`` alias would, if the alias is missing, raise an
        # AttributeError that the handler below silently swallows.
        from sympy.logic.boolalg import BooleanFalse
        try:
            lin = eval(s, symbols)
            if isinstance(lin, BooleanFalse):
                logger.error(f"[linearization] expression {s} False")
            logger.debug(f"[linearization] expression linearized {s} => {lin}")
            return str(lin).replace(" ", "")
        except (TypeError, AttributeError):
            return s

    @staticmethod
    def _to_signed(value: int) -> int:
        """Reinterpret ``value`` as a signed 64bit integer."""
        return ctypes.c_longlong(value).value

    @staticmethod
    def _to_unsigned(value: int) -> int:
        """Reinterpret ``value`` as an unsigned 64bit integer."""
        return ctypes.c_ulonglong(value).value

    @staticmethod
    def _is_constant(v1: str) -> bool:
        """Return True if the expression string ``v1`` can be parsed as an integer."""
        try:
            int(v1)
            return True
        except ValueError:
            return False

    @staticmethod
    def _custom_permutations(l: List[Any]) -> Generator[Tuple[bool, Any, Any], None, None]:
        """
        Custom generator generating all the possible tuples from a list. But instead
        of iterating item i with all others 0..n, iterates i with all the previous 0..i.
        It generates a somewhat sorted sequence that ensures pairs of items appearing
        first in the list will be yielded before.

        :param l: list of any item
        :returns: generator of tuples ``(same, x, y)`` covering all possible pairs,
                  ``same`` being True when both members are the very same item
        """
        # range(len(l)) is deliberately computed once: callers append to the list
        # while iterating, and items added meanwhile must not be visited.
        for i in range(len(l)):
            for j in range(0, i):
                yield False, l[i], l[j]
                yield False, l[j], l[i]
            yield True, l[i], l[i]

    def generate(self,
                 bitsize: int,
                 constants: Optional[List[int]] = None,
                 do_watch: bool = False,
                 watchdog_threshold: Union[int, float] = 90,
                 linearize: bool = False,
                 do_use_blacklist: bool = False,
                 limit: int = 0) -> None:
        """
        Generate a new lookup table from scratch with the variables and operators
        set in the constructor of the table.

        Generation stops when the entry ``limit`` is reached, when the memory
        watchdog (or anyone else) sets ``self.stop``, on KeyboardInterrupt, or when
        a whole depth did not produce any new expression. The resulting worklist is
        then handed to :meth:`add_entries`.

        :param bitsize: Bitsize of expressions to generate
        :param constants: List of constants to use in the generation (none by default)
        :param do_watch: Enable RAM watching thread to monitor memory
        :param watchdog_threshold: threshold to be sent to the memory watchdog
        :param linearize: whether or not to apply linearization on expressions
        :param do_use_blacklist: enable blacklist mechanism on commutative operators. Slower but less memory consuming
        :param limit: Maximum number of entries to generate (0 means no limit)
        :returns: None
        """
        constants = [] if constants is None else constants

        if do_watch:
            self.watchdog = threading.Thread(target=self._watchdog_worker, args=[watchdog_threshold], daemon=True)
            logger.debug("Start watchdog")
            self.watchdog.start()

        symbols = None  # only built, and only used, when linearization is enabled
        if linearize:
            logger.info("Linearization enabled")
            import sympy
            symbols = {x: sympy.symbols(x) for x in self.grammar.vars}

        t0 = time()

        from qsynthesis.grammar import jitting  # Import it locally to make sure pydffi is not mandatory
        CU = jitting.make_compilation_unit(bitsize)
        N = self.input_number
        ArTy = jitting.get_native_array_type(bitsize, N)

        # Initialize worklist with variables: one native array of outputs per variable
        worklist = [(ArTy(), k) for k in self.grammar.vars]
        for i, inp in enumerate(self.inputs):
            for v, k in worklist:
                v[i] = inp[k]

        # Initialize worklist with constants
        csts = [(ArTy(), str(c)) for c in constants]
        for (ar, c) in csts:
            jitting.init_array_cst(ar, int(c), N, bitsize)
        worklist.extend(csts)

        # initialize set of hash
        hash_set = set(self.hash(x[0]) for x in worklist)

        # NOTE(review): False sorts before True, hence this key places unary operators
        # (arity == 1) LAST, whereas the historical comment said "unary first".
        # Order kept as-is since it decides which expression wins for a hash; confirm intent.
        ops = sorted(self.grammar.non_terminal_operators, key=lambda x: x.arity == 1)
        # Resolve evaluation functions once instead of once per pair of expressions
        # (assumes get_op_eval_array is a pure lookup -- TODO confirm)
        op_evals = [(op, jitting.get_op_eval_array(CU, op)) for op in ops]

        cur_depth = 2
        blacklist = set()
        item_count = len(worklist)  # total number of expressions

        try:
            while cur_depth > 0:
                # Start a new depth
                n_items = len(worklist)  # number of items to process at a given depth
                t = time() - t0
                print(f"Depth {cur_depth} (size:{n_items}) (Time:{int(t/60)}m{t%60:.5f}s)")

                c = 0
                for same, (vals1, name1), (vals2, name2) in self._custom_permutations(worklist):
                    if same:
                        c += 1
                        print(f"process: {(c*100)/n_items:.2f}%\r", end="")

                    if 0 < limit <= item_count:
                        self.stop = True
                    if self.stop:
                        logger.warning("Threshold reached, generation interrupted")
                        raise KeyboardInterrupt()

                    # Check it here once then iterate operators
                    name1_cst, name2_cst = self._is_constant(name1), self._is_constant(name2)
                    is_both_constant = name1_cst and name2_cst

                    for op, op_eval in op_evals:  # Iterate over all operators
                        if op.arity == 1:
                            new_vals = ArTy()
                            op_eval(new_vals, vals1, N)

                            h = self.hash(new_vals)
                            if h not in hash_set:
                                if name1_cst:
                                    fmt = str(self._to_signed(new_vals[0]))  # any value is the new constant value
                                else:
                                    fmt = f"{op.symbol}({name1})" if len(name1) > 1 else f"{op.symbol}{name1}"
                                    fmt = self._try_linearize(fmt, symbols) if linearize else fmt

                                logger.debug(f"[add] {fmt: <20} {h}")
                                hash_set.add(h)
                                item_count += 1
                                worklist.append((new_vals, fmt))  # add it in worklist if not already in LUT
                            else:
                                logger.debug(f"[drop] {op.symbol}{name1} ")

                        else:  # arity is 2
                            # for identity (a op a) ignore it if the result is known to be 0 or a
                            if same and (op.id_eq or op.id_zero):
                                continue

                            sn1 = f'{name1}' if len(name1) == 1 else f'({name1})'
                            sn2 = f'{name2}' if len(name2) == 1 else f'({name2})'
                            fmt = f"{op.symbol}({name1},{name2})" if op.is_prefix else f"{sn1}{op.symbol}{sn2}"

                            if not linearize:
                                if fmt in blacklist:  # Ignore expression if they are in the blacklist
                                    continue

                            new_vals = ArTy()
                            op_eval(new_vals, vals1, vals2, N)

                            if is_both_constant:  # if both were constant use the constant as repr instead
                                fmt = str(self._to_signed(new_vals[0]))

                            h = self.hash(new_vals)
                            if h not in hash_set:
                                if linearize:
                                    fmt = self._try_linearize(fmt, symbols)
                                    if fmt in blacklist:  # if linearize check blacklist here
                                        continue

                                logger.debug(f"[add] {fmt: <20} {h}")
                                hash_set.add(h)
                                item_count += 1
                                worklist.append((new_vals, fmt))

                                if op.commutative and do_use_blacklist and not is_both_constant:
                                    fmt = f"{op.symbol}({name2},{name1})" if op.is_prefix else f"{sn2}{op.symbol}{sn1}"
                                    fmt = self._try_linearize(fmt, symbols) if linearize else fmt
                                    blacklist.add(fmt)  # blacklist commutative equivalent e.g for a+b blacklist: b+a
                                    logger.debug(f"[blacklist] {fmt}")
                            else:
                                logger.debug(f"[drop] {op.symbol}({name1},{name2})" if op.is_prefix else f"[drop] ({name1}){op.symbol}({name2})")

                if len(worklist) == n_items:
                    # Nothing was added at this depth: worklist, hash set and blacklist
                    # are unchanged, so any further depth would replay the very same
                    # pairs forever. The table is saturated.
                    logger.info("No new expression generated, table is saturated")
                    break

                cur_depth += 1

        except KeyboardInterrupt:
            logger.info("Stop required")

        # In the end
        self.stop = True  # also makes the watchdog thread leave its loop
        t = time() - t0
        print(f"Depth {cur_depth} (size:{len(worklist)}) (Time:{int(t/60)}m{t%60:.5f}s) [RAM:{self.__size_to_str(self.max_mem)}]")
        self.add_entries(worklist)
        if do_watch:
            self.watchdog.join()

    def add_entry(self, hash: Hash, value: str) -> None:
        """
        Abstract function to add an entry in the lookuptable.

        :param hash: already computed hash to add
        :type hash: :py:obj:`qsynthesis.types.Hash`
        :param value: expression value to add in the table
        :type value: str
        """
        raise NotImplementedError("Should be implemented by child class")

    def add_entries(self, worklist: List[Tuple[Hash, str]]) -> None:
        """
        Add the given list of entries in the database.

        :param worklist: list of entries to add
        :type worklist: List[Tuple[:py:obj:`qsynthesis.types.Hash`, str]]
        :returns: None
        """
        raise NotImplementedError("Should be implemented by child class")

    @staticmethod
    def create(filename: Union[str, Path], grammar: TritonGrammar, inputs: List[Input], constants: Optional[List[int]] = None) -> 'InputOutputOracle':
        """
        Create a new empty lookup table with the given initial parameters, grammars, inputs
        and hash_mode.

        :param filename: filename of the table to create
        :param grammar: TritonGrammar object representing variables and operators
        :param inputs: list of inputs on which to perform evaluation
        :type inputs: List[:py:obj:`qsynthesis.types.Input`]
        :param constants: list of constants used (none by default)
        :returns: lookuptable instance object
        """
        raise NotImplementedError("Should be implemented by child class")

    @staticmethod
    def load(file: Union[Path, str]) -> 'InputOutputOracle':
        """
        Load the given lookup table and returns an instance object.

        :param file: Database file to load
        :returns: InputOutputOracle object
        """
        raise NotImplementedError("Should be implemented by child class")

    @staticmethod
    def __size_to_str(value: int) -> str:
        """ Return pretty printed representation of RAM usage for table generation """
        units = [(float(1024), "Kb"), (float(1024 ** 2), "Mb"), (float(1024 ** 3), "Gb")]
        # Try the biggest unit first and keep the first one the value fills
        for unit, s in units[::-1]:
            if value / unit >= 1:
                return f"{value/unit:.2f}{s}"
        return f"{value}B"