# * Copyright 2015, NICTA
# *
# * This software may be distributed and modified according to the terms of
# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
# * See "LICENSE_BSD2.txt" for details.
# *
# * @TAG(NICTA_BSD)

from syntax import (Expr, mk_var, Node, true_term, false_term,
    fresh_name, word32T, word8T, mk_eq, mk_word32, builtinTs)
import syntax

from target_objects import functions, pairings, trace, printout
import sys
import logic
from logic import azip


class Abort(Exception):
    """Raised when construction or analysis of a problem is abandoned."""
    pass


# one-element list holding the most recently constructed Problem.
last_problem = [None]


class Problem:
    """A graph of nodes assembled from one or more functions, together
    with bookkeeping (tags, entry points, outputs, loop data and cached
    analyses) describing that graph."""

    def __init__ (self, pairing, name = None):
        """Create an empty problem for 'pairing'.

        If 'name' is not given it is taken from pairing.name. The new
        object is also recorded in last_problem[0]."""
        if name is None:
            name = pairing.name
        self.name = 'Problem (%s)' % name
        self.pairing = pairing

        self.nodes = {}
        self.vs = {}
        self.next_node_name = 1
        self.preds = {}
        self.loop_data = {}
        self.node_tags = {}
        self.node_tag_revs = {}
        self.inline_scripts = {}
        self.entries = []
        self.outputs = {}
        self.tarjan_order = []
        self.loop_var_analysis_cache = {}

        self.known_eqs = {}
        self.cached_analysis = {}
        self.hook_tag_hints = {}

        last_problem[0] = self

    def fail_msg (self):
        """Return a failure message naming this problem and its size."""
        return 'FAILED %s (size %05d)' % (self.name, len(self.nodes))

    def alloc_node (self, tag, detail, loop_id = None, hint = None):
        """Allocate and return a fresh node name.

        The name is recorded against (tag, detail) in node_tags and
        node_tag_revs. If 'loop_id' is given, the new node is marked
        as a member of that loop. 'hint' is accepted but not used."""
        name = self.next_node_name
        self.next_node_name = name + 1

        self.node_tags[name] = (tag, detail)
        self.node_tag_revs.setdefault ((tag, detail), [])
        self.node_tag_revs[(tag, detail)].append (name)

        if loop_id is not None:
            self.loop_data[name] = ('Mem', loop_id)

        return name

    def fresh_var (self, name, typ):
        """Return a variable of type 'typ' whose name is made fresh
        with respect to self.vs."""
        name = fresh_name (name, self.vs, typ)
        return mk_var (name, typ)

    def clone_function (self, fun, tag):
        """Replace this problem's nodes with the reachable nodes of
        'fun', keeping the function's own node names and variables,
        and register it as the single entry under 'tag'."""
        self.nodes = {}
        self.vs = syntax.get_vars (fun)
        for n in fun.reachable_nodes ():
            self.nodes[n] = fun.nodes[n]
            detail = (fun.name, n)
            self.node_tags[n] = (tag, detail)
            self.node_tag_revs.setdefault ((tag, detail), [])
            self.node_tag_revs[(tag, detail)].append (n)
        self.outputs[tag] = fun.outputs
        self.entries = [(fun.entry, tag, fun.name, fun.inputs)]
        self.next_node_name = max (self.nodes.keys () + [2]) + 1
        self.inline_scripts[tag] = []

    def add_function (self, fun, tag, node_renames, loop_id = None):
        """Copy the reachable nodes of 'fun' into this problem with
        fresh node names and fresh variable names.

        'node_renames' is updated in place; 'Ret' and 'Err' map to
        themselves unless the caller has already redirected them.
        Returns (new_node_renames, vs): the node renaming introduced
        by this call and the variable renaming.

        Raises Abort if the function has no entry point or (via
        check_no_symbols) if its nodes mention undefined symbols."""
        if not fun.entry:
            printout ('Aborting %s: underspecified %s' % (
                self.name, fun.name))
            raise Abort ()
        node_renames.setdefault('Ret', 'Ret')
        node_renames.setdefault('Err', 'Err')
        new_node_renames = {}
        vs = syntax.get_vars (fun)
        vs = dict ([(v, fresh_name (v, self.vs, vs[v])) for v in vs])
        ns = fun.reachable_nodes ()
        check_no_symbols ([fun.nodes[n] for n in ns], name = self.name)
        for n in ns:
            assert n not in node_renames
            node_renames[n] = self.alloc_node (tag, (fun.name, n),
                loop_id = loop_id, hint = n)
            new_node_renames[n] = node_renames[n]
        for n in ns:
            self.nodes[node_renames[n]] = syntax.copy_rename (
                fun.nodes[n], (vs, node_renames))

        return (new_node_renames, vs)

    def add_entry_function (self, fun, tag):
        """Add 'fun' as an entry point under 'tag'.

        Returns (args, rets, entry): the renamed input and output
        variables and the renamed entry node."""
        (ns, vs) = self.add_function (fun, tag, {})

        entry = ns[fun.entry]
        args = [(vs[v], typ) for (v, typ) in fun.inputs]
        rets = [(vs[v], typ) for (v, typ) in fun.outputs]
        self.entries.append((entry, tag, fun.name, args))
        self.outputs[tag] = rets
        self.inline_scripts[tag] = []

        return (args, rets, entry)

    def get_entry_details (self, tag):
        """Return (entry node, function name, args) for the unique
        entry registered under 'tag'. Fails if there is not exactly
        one such entry."""
        [(e, t, fname, args)] = [(e, t, fname, args)
            for (e, t, fname, args) in self.entries if t == tag]
        return (e, fname, args)

    def get_entry (self, tag):
        """Return the entry node registered under 'tag'."""
        (e, fname, args) = self.get_entry_details (tag)
        return e

    def tags (self):
        """Return the tags for which outputs have been registered."""
        return self.outputs.keys ()

    def entry_exit_renames (self, tags = None):
        """computes the rename set of a function's formal parameters
        to the actual input/output variable names at the various entry
        and exit points"""
        mk = lambda xs, ys: dict ([(x[0], y[0]) for (x, y) in
            azip (xs, ys)])
        renames = {}
        if tags is None:
            tags = self.tags ()
        for tag in tags:
            (_, fname, args) = self.get_entry_details (tag)
            fun = functions[fname]
            out = self.outputs[tag]
            renames[tag + '_IN'] = mk (fun.inputs, args)
            renames[tag + '_OUT'] = mk (fun.outputs, out)
        return renames

    def redirect_conts (self, reds):
        """Rewrite, in place, every node's continuation(s) through the
        mapping 'reds'; continuations not in 'reds' are unchanged."""
        for node in self.nodes.itervalues():
            if node.kind == 'Cond':
                node.left = reds.get(node.left, node.left)
                node.right = reds.get(node.right, node.right)
            else:
                node.cont = reds.get(node.cont, node.cont)

    def do_analysis (self):
        """Discard cached analyses, then recompute predecessors and
        loop data."""
        self.cached_analysis.clear ()
        self.compute_preds ()
        self.do_loop_analysis ()

    def mk_node_graph (self, node_subset = None):
        """Return a dict mapping each node in 'node_subset' (default:
        all nodes) to its continuations that lie in the subset."""
        if node_subset is None:
            node_subset = self.nodes
        return dict ([(n, [c for c in self.nodes[n].get_conts ()
                if c in node_subset])
            for n in node_subset])

    def do_loop_analysis (self):
        """Recompute loop_data and tarjan_order from the components
        returned by logic.tarjan over the node graph.

        A component is treated as a loop if it has a tail or its head
        continues to itself. Each loop is first passed through
        force_single_loop_return, which may add a node."""
        entries = [e for (e, tag, nm, args) in self.entries]
        self.loop_data = {}

        graph = self.mk_node_graph ()
        comps = logic.tarjan (graph, entries)
        self.tarjan_order = []

        for (head, tail) in comps:
            self.tarjan_order.append (head)
            self.tarjan_order.extend (tail)
            if not tail and head not in graph[head]:
                continue
            trace ('Loop (%d, %s)' % (head, tail))

            loop_set = set (tail)
            loop_set.add (head)

            r = self.force_single_loop_return (head, loop_set)
            if r is not None:
                tail.append (r)
                loop_set.add (r)
                self.tarjan_order.append (r)
                # a node was added, so predecessors are stale.
                self.compute_preds ()

            self.loop_data[head] = ('Head', loop_set)
            for t in tail:
                self.loop_data[t] = ('Mem', head)

        # put this in first-to-last order.
        self.tarjan_order.reverse ()

    def check_no_inner_loops (self):
        """Raise Abort (via check_no_inner_loop) if any loop contains
        an inner loop."""
        for loop in self.loop_heads ():
            check_no_inner_loop (self, loop)

    def force_single_loop_return (self, head, loop_set):
        """Ensure the loop at 'head' returns to its head through one
        no-op node.

        If the head already has exactly one in-loop predecessor which
        is a no-op and is not the head itself, nothing is changed and
        None is returned. Otherwise a new 'LoopReturn' node is
        allocated, all in-loop predecessors of the head are redirected
        to it, and its name is returned."""
        rets = [n for n in self.preds[head] if n in loop_set]
        if (len (rets) == 1 and rets[0] != head and
                self.nodes[rets[0]].is_noop ()):
            return None
        r = self.alloc_node (self.node_tags[head][0],
            'LoopReturn', loop_id = head)
        self.nodes[r] = Node ('Basic', head, [])
        for r2 in rets:
            self.nodes[r2] = syntax.copy_rename (self.nodes[r2],
                ({}, {head: r}))
        return r

    def splittable_points (self, n):
        """splittable points are points which when removed, the loop
        'splits' and ceases to be a loop.

        equivalently, the set of splittable points is the intersection
        of all sub-loops of the loop."""
        head = self.loop_id (n)
        assert head is not None
        k = ('Splittables', head)
        if k in self.cached_analysis:
            return self.cached_analysis[k]

        # check if the head point is a split (the inner loop
        # check does exactly that)
        if has_inner_loop (self, head):
            head = logic.get_one_loop_splittable (self,
                self.loop_body (head))
            if head is None:
                return set ()

        splits = self.get_loop_splittables (head)
        self.cached_analysis[k] = splits
        return splits

    def get_loop_splittables (self, head):
        """Return the set of splittable points of the loop containing
        'head', computed by following one arc around the loop from
        'head' and discarding arc points that some path skips over."""
        loop_set = self.loop_body (head)
        splittable = dict ([(n, False) for n in loop_set])
        # follow one path from head round the loop back to head.
        arc = [head]
        n = head
        while True:
            ns = [n2 for n2 in self.nodes[n].get_conts ()
                if n2 in loop_set]
            ns2 = [x for x in ns if x == head or x not in arc]
            n = ns2[0]
            arc.append (n)
            splittable[n] = True
            if n == head:
                break
        # head appears at both ends of arc; the later index wins.
        last_descs = {}
        for (i, arc_n) in enumerate (arc):
            last_descs[arc_n] = i
        def last_desc (n):
            if n in last_descs:
                return last_descs[n]
            n2s = [n2 for n2 in self.nodes[n].get_conts()
                if n2 in loop_set]
            # placeholder stops recursion revisiting this node.
            last_descs[n] = None
            for n2 in n2s:
                x = last_desc(n2)
                if last_descs[n] is None or x >= last_descs[n]:
                    last_descs[n] = x
            return last_descs[n]
        for i in range (len (arc)):
            max_arc = max ([last_desc (n)
                for n in self.nodes[arc[i]].get_conts ()
                if n in loop_set])
            # arc points strictly between i and max_arc can be
            # bypassed, so they are not splittable.
            for j in range (i + 1, max_arc):
                splittable[arc[j]] = False
        return set ([n for n in splittable if splittable[n]])

    def loop_heads (self):
        """Return the nodes recorded as loop heads."""
        return [n for n in self.loop_data
            if self.loop_data[n][0] == 'Head']

    def loop_id (self, n):
        """Return the head of the loop containing 'n', or None if 'n'
        is not recorded as being in a loop."""
        if n not in self.loop_data:
            return None
        elif self.loop_data[n][0] == 'Head':
            return n
        else:
            assert self.loop_data[n][0] == 'Mem'
            return self.loop_data[n][1]

    def loop_body (self, n):
        """Return the set of nodes of the loop containing 'n'."""
        head = self.loop_id (n)
        return self.loop_data[head][1]

    def compute_preds (self):
        """Recompute self.preds from the current nodes."""
        self.preds = logic.compute_preds (self.nodes)

    def var_dep_outputs (self, n):
        """Return the outputs registered for the tag of node 'n'."""
        return self.outputs[self.node_tags[n][0]]

    def compute_var_dependencies (self):
        """Return, for every node, a dict whose keys are the variables
        reported for it by logic.compute_var_deps (values are None).
        The result is cached in cached_analysis."""
        if 'var_dependencies' in self.cached_analysis:
            return self.cached_analysis['var_dependencies']
        var_deps = logic.compute_var_deps (self.nodes,
            self.var_dep_outputs, self.preds)
        var_deps2 = dict ([(n, dict ([(v, None)
                for v in var_deps.get (n, [])]))
            for n in self.nodes])
        self.cached_analysis['var_dependencies'] = var_deps2
        return var_deps2

    def get_loop_var_analysis (self, var_deps, n):
        """Return logic.compute_loop_var_analysis for split point 'n',
        reusing a cached result when the loop's nodes, predecessors
        and variable dependencies are unchanged.

        'n' must be a splittable point of its loop. At most the ten
        most recent results are kept per (n, loop) key."""
        head = self.loop_id (n)
        assert head, n
        assert n in self.splittable_points (n)
        loop_sort = tuple (sorted (self.loop_body (head)))
        # snapshot of each loop node, used to validate cache hits.
        # the predecessors recorded are those of the node itself
        # (n2), so a change anywhere in the loop invalidates the
        # cached result.
        node_data = [(self.nodes[n2], sorted (self.preds[n2]),
                sorted (var_deps[n2].keys ()))
            for n2 in loop_sort]
        k = (n, loop_sort)
        data = (node_data, n)
        if k in self.loop_var_analysis_cache:
            for (data2, va) in self.loop_var_analysis_cache[k]:
                if data2 == data:
                    return va
        va = logic.compute_loop_var_analysis (self, var_deps, n)
        group = self.loop_var_analysis_cache.setdefault (k, [])
        group.append ((data, va))
        # keep only the last ten entries.
        del group[:-10]
        return va

    def save_graph (self, fname):
        """Write a graph of all nodes to the file 'fname'."""
        cols = mk_graph_cols (self.node_tags)
        save_graph (self.nodes, fname, cols = cols,
            node_tags = self.node_tags)

    def save_graph_summ (self, fname):
        """Write a summarised graph to 'fname', omitting 'trivial'
        nodes: those with exactly one predecessor which are either
        'Basic' nodes or 'Cond' nodes whose right branch is 'Err'.
        Edges into an omitted chain are redirected to the first
        non-trivial node reached."""
        node_ids = {}
        def is_triv (n):
            # returns False, or (True, successor to skip to).
            if n not in self.nodes:
                return False
            if len (self.preds[n]) != 1:
                return False
            node = self.nodes[n]
            if node.kind == 'Basic':
                return (True, node.cont)
            elif node.kind == 'Cond' and node.right == 'Err':
                return (True, node.left)
            else:
                return False
        for n in self.nodes:
            if n in node_ids:
                continue
            ns = []
            while is_triv (n):
                ns.append (n)
                n = is_triv (n)[1]
            for n2 in ns:
                node_ids[n2] = n
        nodes = {}
        for n in self.nodes:
            if is_triv (n):
                continue
            nodes[n] = syntax.copy_rename (self.nodes[n],
                ({}, node_ids))
        cols = mk_graph_cols (self.node_tags)
        save_graph (nodes, fname, cols = cols,
            node_tags = self.node_tags)

    def serialise (self):
        """Return the problem as a list of text lines: 'Problem', one
        'Entry' line per entry point, one line per node, and
        'EndProblem'."""
        ss = ['Problem']
        for (n, tag, fname, inputs) in self.entries:
            xs = ['Entry', '%d' % n, tag, fname,
                '%d' % len (inputs)]
            for (nm, typ) in inputs:
                xs.append (nm)
                typ.serialise (xs)
            xs.append ('%d' % len (self.outputs[tag]))
            for (nm, typ) in self.outputs[tag]:
                xs.append (nm)
                typ.serialise (xs)
            ss.append (' '.join (xs))
        for n in self.nodes:
            xs = ['%d' % n]
            self.nodes[n].serialise (xs)
            ss.append (' '.join (xs))
        ss.append ('EndProblem')
        return ss

    def save_serialise (self, fname):
        """Write the serialised problem to the file 'fname'."""
        ss = self.serialise ()
        # 'with' closes the file even if a write fails.
        with open (fname, 'w') as f:
            for s in ss:
                f.write (s + '\n')

    def pad_merge_points (self):
        """Insert an empty 'Basic' node ('MergePadding') on each edge
        into a node with more than one predecessor, unless the edge
        already comes from a 'Basic' node with no updates."""
        self.compute_preds ()

        arcs = [(pred, n) for n in self.preds
            if len (self.preds[n]) > 1
            if n in self.nodes
            for pred in self.preds[n]
            if (self.nodes[pred].kind != 'Basic'
                or self.nodes[pred].upds != [])]

        for (pred, n) in arcs:
            (tag, _) = self.node_tags[pred]
            name = self.alloc_node (tag, 'MergePadding')
            self.nodes[name] = Node ('Basic', n, [])
            self.nodes[pred] = syntax.copy_rename (self.nodes[pred],
                ({}, {n: name}))

    def function_call_addrs (self):
        """Return (node, function name) for every 'Call' node."""
        return [(n, self.nodes[n].fname)
            for n in self.nodes if self.nodes[n].kind == 'Call']

    def function_calls (self):
        """Return the set of function names called in this problem."""
        return set ([fn for (n, fn) in self.function_call_addrs ()])

    def get_extensions (self):
        """Return the union of syntax.get_extensions over all nodes.
        The result is cached in cached_analysis."""
        if 'extensions' in self.cached_analysis:
            return self.cached_analysis['extensions']
        extensions = set ()
        for node in self.nodes.itervalues ():
            extensions.update (syntax.get_extensions (node))
        self.cached_analysis['extensions'] = extensions
        return extensions

    def replay_inline_script (self, tag, script):
        """Repeat the inlining steps in 'script', a list of
        (detail, idx, fname) as recorded in inline_scripts, then redo
        the analysis once if anything was inlined."""
        for (detail, idx, fname) in script:
            n = self.node_tag_revs[(tag, detail)][idx]
            assert self.nodes[n].kind == 'Call', self.nodes[n]
            assert self.nodes[n].fname == fname, self.nodes[n]
            inline_at_point (self, n, do_analysis = False)
        if script:
            self.do_analysis ()

    def is_reachable_from (self, source, target):
        '''discover if graph addr "target" is reachable
        from starting node "source"'''
        k = ('is_reachable_from', source)
        if k in self.cached_analysis:
            reachable = self.cached_analysis[k]
            if target in reachable:
                return reachable[target]

        reachable = {}
        visit = [source]
        while visit:
            n = visit.pop ()
            if n not in self.nodes:
                continue
            for n2 in self.nodes[n].get_conts ():
                if n2 not in reachable:
                    reachable[n2] = True
                    visit.append (n2)
        for n in list (self.nodes) + ['Ret', 'Err']:
            if n not in reachable:
                reachable[n] = False
        self.cached_analysis[k] = reachable
        return reachable[target]

    def is_reachable_without (self, cutpoint, target):
        '''discover if graph addr "target" is reachable
        without visiting node "cutpoint"
        (an oddity: cutpoint itself is considered reachable)'''
        k = ('is_reachable_without', cutpoint)
        if k in self.cached_analysis:
            reachable = self.cached_analysis[k]
            if target in reachable:
                return reachable[target]

        reachable = dict ([(self.get_entry (t), True)
            for t in self.tags ()])
        for n in self.tarjan_order + ['Ret', 'Err']:
            if n in reachable:
                continue
            reachable[n] = bool ([pred for pred in self.preds[n]
                if pred != cutpoint
                if reachable.get (pred) == True])
        self.cached_analysis[k] = reachable
        return reachable[target]


def deserialise (name, lines):
    """Rebuild a Problem called 'name' from 'lines', a list in the
    format written by Problem.serialise.

    Only entries, outputs and nodes are restored; the pairing is set
    to None."""
    assert lines[0] == 'Problem', lines[0]
    assert lines[-1] == 'EndProblem', lines[-1]
    i = 1
    # not easy to reconstruct pairing
    p = Problem (pairing = None, name = name)
    while lines[i].startswith ('Entry'):
        bits = lines[i].split ()
        en = int (bits[1])
        tag = bits[2]
        fname = bits[3]
        (n, inputs) = syntax.parse_list (syntax.parse_lval, bits, 4)
        (n, outputs) = syntax.parse_list (syntax.parse_lval, bits, n)
        assert n == len (bits), (n, bits)
        p.entries.append ((en, tag, fname, inputs))
        p.outputs[tag] = outputs
        i += 1
    for i in range (i, len (lines) - 1):
        bits = lines[i].split ()
        n = int (bits[0])
        node = syntax.parse_node (bits, 1)
        p.nodes[n] = node
    return p


# trivia

def check_no_symbols (nodes, name = None):
    """Raise Abort if pseudo_compile.nodes_symbols reports any symbols
    in 'nodes'.

    'name' is used only in the abort message. It is optional because
    this is a module-level function with no problem object in scope."""
    import pseudo_compile
    symbs = pseudo_compile.nodes_symbols (nodes)
    if not symbs:
        return
    if name is None:
        name = 'problem'
    printout ('Aborting %s: undefined symbols %s' % (name, symbs))
    raise Abort ()


# printing of problem graphs

def sanitise_str (s):
    """Return 's' with quote characters replaced by '_' and spaces
    removed."""
    return s.replace ('"', '_').replace ("'", "_").replace (' ', '')


def graph_name (nodes, node_tags, n, prev=None):
    """Return an identifier for node 'n' built from its number, its
    tag details (if any) and a prefix for its kind.

    String-named nodes get 't_<n>_<prev>', which requires 'prev' to be
    an integer."""
    if type (n) == str:
        return 't_%s_%d' % (n, prev)
    if n not in nodes:
        return 'unknown_%d' % n
    if n not in node_tags:
        ident = '%d' % n
    else:
        (tag, details) = node_tags[n]
        if len (details) > 1 and logic.is_int (details[1]):
            ident = '%d_%s_%s_0x%x' % (n, tag,
                details[0], details[1])
        elif type (details) != str:
            details = '_'.join (map (str, details))
            ident = '%d_%s_%s' % (n, tag, details)
        else:
            ident = '%d_%s_%s' % (n, tag, details)
    ident = sanitise_str (ident)
    node = nodes[n]
    if node.kind == 'Call':
        return 'fcall_%s' % ident
    if node.kind == 'Cond':
        return ident
    if node.kind == 'Basic':
        return 'ass_%s' % ident
    assert not 'node kind understood'


def graph_node_tooltip (nodes, n):
    """Return a short description of node 'n' for use as a tooltip."""
    if n == 'Err':
        return 'Error point'
    if n == 'Ret':
        return 'Return point'
    node = nodes[n]
    if node.kind == 'Call':
        return "%s: call to '%s'" % (n, sanitise_str (node.fname))
    if node.kind == 'Cond':
        return '%s: conditional node' % n
    if node.kind == 'Basic':
        var_names = [sanitise_str (x[0][0]) for x in node.upds]
        return '%s: assignment to [%s]' % (n, ', '.join (var_names))
    assert not 'node kind understood'


def graph_edges (nodes, n):
    """Return the outgoing edges of node 'n' as (target, label) pairs,
    labelled 'N' (no-op), 'T'/'F' (conditional) or 'C' (other)."""
    node = nodes[n]
    if node.is_noop ():
        return [(node.get_conts () [0], 'N')]
    elif node.kind == 'Cond':
        return [(node.left, 'T'), (node.right, 'F')]
    else:
        return [(node.cont, 'C')]


def get_graph_font (n, col):
    """Return the font/pen attribute string for a graph element,
    adding colour attributes if 'col' is set. 'n' is not used."""
    font = 'fontname = "Arial", fontsize = 20, penwidth=3'
    if col:
        font = font + ', color=%s, fontcolor=%s' % (col, col)
    return font


def get_graph_loops (nodes):
    """Return the set of edges (n, n2) whose endpoints lie in the same
    component as computed by logic.tarjan."""
    graph = dict ([(n, [c for c in nodes[n].get_conts ()
        if type (c) != str]) for n in nodes])
    # artificial root from which every node is reachable.
    graph['ENTRY'] = list (nodes)
    comps = logic.tarjan (graph, ['ENTRY'])
    comp_ids = {}
    for (head, tail) in comps:
        comp_ids[head] = head
        for n in tail:
            comp_ids[n] = head
    loops = set ([(n, n2) for n in graph for n2 in graph[n]
        if comp_ids[n] == comp_ids[n2]])
    return loops


def make_graph (nodes, cols, node_tags = {}, entries = []):
    """Return the lines of a 'digraph' description of 'nodes'.

    'cols' maps nodes to colours. 'entries' is a list of
    (node, tag, inputs) triples drawn as labelled entry points; note
    that Problem.entries holds 4-tuples and cannot be passed as is.
    Neither default argument is mutated."""
    graph = []
    graph.append ('digraph foo {')

    loops = get_graph_loops (nodes)

    for n in nodes:
        n_nm = graph_name (nodes, node_tags, n)
        f = get_graph_font (n, cols.get (n))
        graph.append ('%s [%s\n label="%s"\n tooltip="%s"];' % (n,
            f, n_nm, graph_node_tooltip (nodes, n)))
        for (c, l) in graph_edges (nodes, n):
            if c in ['Ret', 'Err']:
                # each edge to Ret/Err gets its own terminal node.
                c_nm = '%s_%s' % (c, n)
                if c == 'Ret':
                    f2 = f + ', shape=doubleoctagon'
                else:
                    f2 = f + ', shape=Mdiamond'
                graph.append ('%s [label="%s", %s];'
                    % (c_nm, c, f2))
            else:
                c_nm = c
            ft = f
            if (n, c) in loops:
                ft = f + ', penwidth=6'
            graph.append ('%s -> %s [label=%s, %s];' % (
                n, c_nm, l, ft))

    for (i, (n, tag, inps)) in enumerate (entries):
        f = get_graph_font (n, cols.get (n))
        nm1 = tag + ' ENTRY_POINT'
        nm2 = 'entry_point_%d' % i
        graph.extend (['%s -> %s [%s];' % (nm2, n, f),
            '%s [label = "%s", shape=none, %s];' % (nm2, nm1, f)])

    graph.append ('}')
    return graph


def print_graph (nodes, cols = {}, entries = []):
    """Print the graph description of 'nodes' to standard output."""
    # entries must be passed by keyword: the third positional
    # parameter of make_graph is node_tags.
    for line in make_graph (nodes, cols, entries = entries):
        print (line)


def save_graph (nodes, fname, cols = {}, entries = [], node_tags = {}):
    """Write the graph description of 'nodes' to the file 'fname'."""
    with open (fname, 'w') as f:
        for line in make_graph (nodes, cols = cols,
                node_tags = node_tags, entries = entries):
            f.write (line + '\n')


def mk_graph_cols (node_tags):
    """Return a dict mapping each node whose tag has a known colour
    to that colour."""
    known_cols = {'C': "forestgreen", 'ASM_adj': "darkblue",
        'ASM': "darkorange"}
    cols = {}
    for n in node_tags:
        if node_tags[n][0] in known_cols:
            cols[n] = known_cols[node_tags[n][0]]
    return cols


def make_graph_with_eqs (p, invis = False):
    """Return the graph description of problem 'p' with an extra blue
    edge for each entry of p.known_eqs (other than 'Hyps'). The extra
    edges are drawn invisible if 'invis' is set."""
    if invis:
        invis_s = ', style=invis'
    else:
        invis_s = ''
    cols = mk_graph_cols (p.node_tags)
    graph = make_graph (p.nodes, cols = cols)
    # drop the closing brace so edges can be appended.
    graph.pop ()
    for k in p.known_eqs:
        if k == 'Hyps':
            continue
        (n_vc_x, tag_x) = k
        nm1 = graph_name (p.nodes, p.node_tags, n_vc_x[0])
        for (x, n_vc_y, tag_y, y, hyps) in p.known_eqs[k]:
            nm2 = graph_name (p.nodes, p.node_tags, n_vc_y[0])
            graph.extend ([('%s -> %s [ dir = back, color = blue, '
                'penwidth = 3, weight = 0 %s ]')
                    % (nm2, nm1, invis_s)])
    graph.append ('}')
    return graph


def save_graph_with_eqs (p, fname = 'diagram.dot', invis = False):
    """Write the output of make_graph_with_eqs to the file 'fname'."""
    graph = make_graph_with_eqs (p, invis = invis)
    with open (fname, 'w') as f:
        for s in graph:
            f.write (s + '\n')


def get_problem_vars (p):
    """Return a dict of the variables of problem 'p': its entry
    arguments and outputs, extended by syntax.get_node_vars over
    every node."""
    inout = set.union (* ([set(xs) for xs in p.outputs.itervalues ()]
        + [set (args) for (_, _, _, args) in p.entries]))

    vs = dict(inout)
    for node in p.nodes.itervalues():
        syntax.get_node_vars(node, vs)
    return vs


def is_trivial_fun (fun):
    """Return True if 'fun' has no calls, only assigns variables or
    numbers, and only branches on variables or constant true/false."""
    for node in fun.nodes.itervalues ():
        if node.is_noop ():
            continue
        if node.kind == 'Call':
            return False
        elif node.kind == 'Basic':
            for (lv, v) in node.upds:
                if v.kind not in ['Var', 'Num']:
                    return False
        elif node.kind == 'Cond':
            if node.cond.kind != 'Var' and node.cond not in [
                    true_term, false_term]:
                return False
    return True


last_alt_nodes = [0]


def avail_val (vs, typ):
    """Return the first variable in 'vs' of type 'typ', or
    logic.default_val (typ) if there is none."""
    for (nm, typ2) in vs:
        if typ2 == typ:
            return mk_var (nm, typ2)
    return logic.default_val (typ)


def inline_at_point (p, n, do_analysis = True):
    """Inline the function called at node 'n' of problem 'p'.

    The call node becomes a 'Basic' node assigning the arguments and
    continuing to the copied function body, whose 'Ret' is redirected
    to a new node assigning the return values. The step is recorded
    in p.inline_scripts. Returns the new node names of the inlined
    body, or None if 'n' is not a 'Call' node."""
    node = p.nodes[n]
    if node.kind != 'Call':
        return

    f_nm = node.fname
    fun = functions[f_nm]
    (tag, detail) = p.node_tags[n]
    idx = p.node_tag_revs[(tag, detail)].index (n)
    p.inline_scripts[tag].append ((detail, idx, f_nm))

    trace ('Inlining %s into %s' % (f_nm, p.name))
    if n in p.loop_data:
        trace (' inlining into loop %d!' % p.loop_id (n))

    ex = p.alloc_node (tag, (f_nm, 'RetToCaller'))

    (ns, vs) = p.add_function (fun, tag, {'Ret': ex})
    en = ns[fun.entry]

    inp_lvs = [(vs[v], typ) for (v, typ) in fun.inputs]
    p.nodes[n] = Node ('Basic', en, azip (inp_lvs, node.args))

    out_vs = [mk_var (vs[v], typ) for (v, typ) in fun.outputs]
    p.nodes[ex] = Node ('Basic', node.cont, azip (node.rets, out_vs))

    p.cached_analysis.clear ()

    if do_analysis:
        p.do_analysis ()

    trace ('Problem size now %d' % len(p.nodes))
    # flush pending output; the original flushed sys.stdin, which
    # has nothing to flush.
    sys.stdout.flush ()

    return ns.values ()


def loop_body_inner_loops (p, head, loop_body):
    """Return the components with a non-empty tail found by
    logic.tarjan in 'loop_body' once edges back into 'head' are
    removed."""
    loop_set_all = set (loop_body)
    loop_set = loop_set_all - set ([head])
    graph = dict([(n, [c for c in p.nodes[n].get_conts ()
            if c in loop_set])
        for n in loop_set_all])

    comps = logic.tarjan (graph, [head])
    assert sum ([1 + len (t) for (_, t) in comps]) == len (loop_set_all)
    return [comp for comp in comps if comp[1]]


def loop_inner_loops (p, head):
    """Cached loop_body_inner_loops for the loop at 'head'."""
    k = ('inner_loop_set', head)
    if k in p.cached_analysis:
        return p.cached_analysis[k]
    res = loop_body_inner_loops (p, head, p.loop_body (head))
    p.cached_analysis[k] = res
    return res


def loop_heads_including_inner (p):
    """Return the loop heads of 'p' together with the heads of all
    inner loops, found by repeatedly applying
    loop_body_inner_loops."""
    heads = p.loop_heads ()
    check = [(head, p.loop_body (head)) for head in heads]
    while check:
        (head, body) = check.pop ()
        comps = loop_body_inner_loops (p, head, body)
        heads.extend ([head for (head, _) in comps])
        check.extend ([(head, [head] + list (body))
            for (head, body) in comps])
    return heads


def check_no_inner_loop (p, head):
    """Raise Abort if the loop at 'head' contains an inner loop."""
    subs = loop_inner_loops (p, head)
    if subs:
        printout ('Aborting %s, complex loop' % p.name)
        trace (' sub-loops %s of loop at %s' % (subs, head))
        for (h, _) in subs:
            trace (' head %d tagged %s' % (h, p.node_tags[h]))
        raise Abort ()


def has_inner_loop (p, head):
    """Return True if the loop at 'head' contains an inner loop."""
    return bool (loop_inner_loops (p, head))


def fun_has_inner_loop (f):
    """Return True if any loop of function 'f' has an inner loop."""
    p = f.as_problem (Problem)
    p.do_analysis ()
    return bool ([head for head in p.loop_heads ()
        if has_inner_loop (p, head)])


def loop_var_analysis (p, head, tail):
    """Return the variables that are both used before being created
    within the loop (head + tail) and created by the time control
    returns to the head."""
    # getting the set of variables that go round the loop
    nodes = set (tail)
    nodes.add (head)
    used_vs = set ([])
    created_vs_at = {}
    visit = []

    def process_node (n, created):
        if p.nodes[n].is_noop ():
            lvals = set ([])
        else:
            vs = syntax.get_node_rvals (p.nodes[n])
            for rv in vs.iteritems ():
                if rv not in created:
                    used_vs.add (rv)
            lvals = set (p.nodes[n].get_lvals ())

        created = set.union (created, lvals)
        created_vs_at[n] = created

        visit.extend (p.nodes[n].get_conts ())

    process_node (head, set ([]))

    while visit:
        n = visit.pop ()
        if (n not in nodes) or (n in created_vs_at):
            continue
        # wait until every predecessor has been processed; the node
        # is pushed again when the last one is.
        if not all ([pr in created_vs_at for pr in p.preds[n]]):
            continue

        pre_created = [created_vs_at[pr] for pr in p.preds[n]]
        process_node (n, set.union (* pre_created))

    final_pre_created = [created_vs_at[pr] for pr in p.preds[head]
        if pr in nodes]
    created = set.union (* final_pre_created)

    loop_vs = set.intersection (created, used_vs)
    trace ('Loop vars at head: %s' % loop_vs)

    return loop_vs