# * Copyright 2015, NICTA
# *
# * This software may be distributed and modified according to the terms of
# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
# * See "LICENSE_BSD2.txt" for details.
# *
# * @TAG(NICTA_BSD)

"""Problem graphs.

A Problem is a graph of nodes (syntax.Node) assembled from one or more
functions, each under a tag. This module provides the Problem class
(node allocation, copying/inlining of functions, loop analysis,
reachability queries, serialisation) and helpers that render problem
graphs in Graphviz dot format.
"""

from syntax import (Expr, mk_var, Node, true_term, false_term,
    fresh_name, word32T, word8T, mk_eq, mk_word32, builtinTs)
import syntax

from target_objects import functions, pairings, trace, printout
import sys
import logic
from logic import azip


class Abort(Exception):
    """Raised to abandon work on a problem. The raisers in this module
    print an explanation via printout before raising it."""
    pass


# last_problem[0] is the most recently constructed Problem (set by
# Problem.__init__); presumably kept for interactive debugging.
last_problem = [None]


class Problem:
    """A graph of nodes built from one or more functions.

    Nodes allocated here are named by integers; 'Ret' and 'Err' are the
    special exit continuations. Every node carries a (tag, detail) pair
    in node_tags, and get_entry_details expects exactly one entry per
    tag."""

    def __init__ (self, pairing, name = None):
        """Create an empty problem named after name, or after
        pairing.name when name is None (so pairing may be None only if
        name is given). The problem is recorded in last_problem[0]."""
        if name is None:
            name = pairing.name
        self.name = 'Problem (%s)' % name
        self.pairing = pairing

        # node name -> Node
        self.nodes = {}
        # variables in use; consulted by syntax.fresh_name
        self.vs = {}
        # next integer handed out by alloc_node
        self.next_node_name = 1
        # node name -> predecessors (filled in by compute_preds)
        self.preds = {}
        # node -> ('Head', loop body set) or ('Mem', loop head)
        self.loop_data = {}
        # node -> (tag, detail)
        self.node_tags = {}
        # (tag, detail) -> nodes carrying that tag, in allocation order
        self.node_tag_revs = {}
        # tag -> list of (detail, idx, fname) inlining steps, appended
        # by inline_at_point and replayed by replay_inline_script
        self.inline_scripts = {}
        # list of (entry node, tag, function name, [(name, type)] inputs)
        self.entries = []
        # tag -> list of (name, type) outputs
        self.outputs = {}
        # node order computed by do_loop_analysis
        self.tarjan_order = []
        # memo table for get_loop_var_analysis
        self.loop_var_analysis_cache = {}

        # known equalities, drawn by make_graph_with_eqs; populated
        # outside this module (TODO confirm)
        self.known_eqs = {}
        # memoised analysis results; cleared by do_analysis
        self.cached_analysis = {}
        # not referenced elsewhere in this module
        self.hook_tag_hints = {}

        last_problem[0] = self

    def fail_msg (self):
        """Return a failure message naming the problem and its size."""
        return 'FAILED %s (size %05d)' % (self.name, len(self.nodes))

    def alloc_node (self, tag, detail, loop_id = None, hint = None):
        """Allocate a fresh integer node name tagged (tag, detail).

        If loop_id is given the node is recorded as a member of that
        loop. The node itself is not created: callers assign
        self.nodes[name]. hint is accepted but unused."""
        name = self.next_node_name
        self.next_node_name = name + 1

        self.node_tags[name] = (tag, detail)
        self.node_tag_revs.setdefault ((tag, detail), [])
        self.node_tag_revs[(tag, detail)].append (name)

        if loop_id is not None:
            self.loop_data[name] = ('Mem', loop_id)

        return name

    def fresh_var (self, name, typ):
        """Return a variable of type typ whose name is derived from name
        by syntax.fresh_name against self.vs."""
        name = fresh_name (name, self.vs, typ)
        return mk_var (name, typ)

    def clone_function (self, fun, tag):
        """Replace this problem's nodes and variables with fun's
        reachable nodes, keeping fun's node names, and make fun's entry
        the sole entry (under tag).

        NOTE: the Node objects are shared with fun, not copied."""
        self.nodes = {}
        self.vs = syntax.get_vars (fun)
        for n in fun.reachable_nodes ():
            self.nodes[n] = fun.nodes[n]
            detail = (fun.name, n)
            self.node_tags[n] = (tag, detail)
            self.node_tag_revs.setdefault ((tag, detail), [])
            self.node_tag_revs[(tag, detail)].append (n)
        self.outputs[tag] = fun.outputs
        self.entries = [(fun.entry, tag, fun.name, fun.inputs)]
        # continue numbering above every existing node name
        self.next_node_name = max (list (self.nodes) + [2]) + 1
        self.inline_scripts[tag] = []

    def add_function (self, fun, tag, node_renames, loop_id = None):
        """Add a renamed copy of fun's reachable nodes under tag.

        Each node gets a fresh name from alloc_node, and each variable
        a fresh name from fresh_name. node_renames is updated in place:
        'Ret' and 'Err' default to themselves, and the caller may
        pre-map them elsewhere (e.g. {'Ret': exit_node}).

        Returns (new_node_renames, vs): maps from fun's node names and
        variable names to the names used in this problem.

        Raises Abort if fun has no entry or check_no_symbols objects
        to its nodes."""
        if not fun.entry:
            printout ('Aborting %s: underspecified %s' % (
                self.name, fun.name))
            raise Abort ()
        node_renames.setdefault('Ret', 'Ret')
        node_renames.setdefault('Err', 'Err')
        new_node_renames = {}
        vs = syntax.get_vars (fun)
        vs = dict ([(v, fresh_name (v, self.vs, vs[v])) for v in vs])
        ns = fun.reachable_nodes ()
        check_no_symbols ([fun.nodes[n] for n in ns], self.name)
        for n in ns:
            assert n not in node_renames
            node_renames[n] = self.alloc_node (tag, (fun.name, n),
                loop_id = loop_id, hint = n)
            new_node_renames[n] = node_renames[n]
        # second pass: all names exist now, so conts can be renamed
        for n in ns:
            self.nodes[node_renames[n]] = syntax.copy_rename (
                fun.nodes[n], (vs, node_renames))

        return (new_node_renames, vs)

    def add_entry_function (self, fun, tag):
        """Add fun via add_function and register its entry as the entry
        point for tag.

        Returns (args, rets, entry): the renamed (name, type) inputs and
        outputs, and the new entry node."""
        (ns, vs) = self.add_function (fun, tag, {})

        entry = ns[fun.entry]
        args = [(vs[v], typ) for (v, typ) in fun.inputs]
        rets = [(vs[v], typ) for (v, typ) in fun.outputs]
        self.entries.append((entry, tag, fun.name, args))
        self.outputs[tag] = rets

        self.inline_scripts[tag] = []

        return (args, rets, entry)

    def get_entry_details (self, tag):
        """Return (entry node, function name, args) for tag's entry.
        Raises ValueError unless exactly one entry has this tag."""
        matches = [(e, fname, args)
            for (e, t, fname, args) in self.entries if t == tag]
        [details] = matches
        return details

    def get_entry (self, tag):
        """Return the entry node for tag."""
        (e, fname, args) = self.get_entry_details (tag)
        return e

    def tags (self):
        """Return the tags that have registered outputs."""
        return list (self.outputs)

    def entry_exit_renames (self, tags = None):
        """Compute renames of each function's formal parameters to the
        actual input/output variable names at its entry and exit.

        Returns {tag + '_IN': {formal input: actual},
        tag + '_OUT': {formal output: actual}} for each tag in tags
        (default: all tags)."""
        mk = lambda xs, ys: dict ([(x[0], y[0]) for (x, y) in
            azip (xs, ys)])
        renames = {}
        if tags is None:
            tags = self.tags ()
        for tag in tags:
            (_, fname, args) = self.get_entry_details (tag)
            fun = functions[fname]
            out = self.outputs[tag]
            renames[tag + '_IN'] = mk (fun.inputs, args)
            renames[tag + '_OUT'] = mk (fun.outputs, out)
        return renames

    def redirect_conts (self, reds):
        """Rewrite, in place, every node's continuation(s) through the
        mapping reds (old target -> new target)."""
        for node in self.nodes.values ():
            if node.kind == 'Cond':
                node.left = reds.get(node.left, node.left)
                node.right = reds.get(node.right, node.right)
            else:
                node.cont = reds.get(node.cont, node.cont)

    def do_analysis (self):
        """Clear cached analyses and recompute preds and loop data."""
        self.cached_analysis.clear ()
        self.compute_preds ()
        self.do_loop_analysis ()

    def mk_node_graph (self, node_subset = None):
        """Return {n: [conts of n within node_subset]} for each n in
        node_subset (default: all nodes)."""
        if node_subset is None:
            node_subset = self.nodes
        return dict ([(n, [c for c in self.nodes[n].get_conts ()
                if c in node_subset])
            for n in node_subset])

    def do_loop_analysis (self):
        """Compute tarjan_order and loop_data.

        The strongly connected components reachable from the entries
        (logic.tarjan) are visited. Each component with a tail or a
        self-edge is a loop. force_single_loop_return may add a
        'LoopReturn' node to it, and the head is recorded as
        ('Head', body) with every other member as ('Mem', head)."""
        entries = [e for (e, tag, nm, args) in self.entries]
        self.loop_data = {}

        graph = self.mk_node_graph ()
        comps = logic.tarjan (graph, entries)
        self.tarjan_order = []

        for (head, tail) in comps:
            self.tarjan_order.append (head)
            self.tarjan_order.extend (tail)
            if not tail and head not in graph[head]:
                continue
            trace ('Loop (%d, %s)' % (head, tail))

            loop_set = set (tail)
            loop_set.add (head)

            r = self.force_single_loop_return (head, loop_set)
            if r is not None:
                tail.append (r)
                loop_set.add (r)
                self.tarjan_order.append (r)
                self.compute_preds ()

            self.loop_data[head] = ('Head', loop_set)
            for t in tail:
                self.loop_data[t] = ('Mem', head)

        # put this in first-to-last order.
        self.tarjan_order.reverse ()

    def check_no_inner_loops (self):
        """Raise Abort if any loop contains an inner loop."""
        for loop in self.loop_heads ():
            check_no_inner_loop (self, loop)

    def force_single_loop_return (self, head, loop_set):
        """Ensure head has a single in-loop predecessor which is a
        no-op node.

        If it already has exactly one such predecessor (other than head
        itself), return None. Otherwise, allocate an empty 'Basic' node
        r continuing to head, redirect every in-loop predecessor of
        head to r, and return r. self.preds is not updated here."""
        rets = [n for n in self.preds[head] if n in loop_set]
        if (len (rets) == 1 and rets[0] != head and
                self.nodes[rets[0]].is_noop ()):
            return None
        r = self.alloc_node (self.node_tags[head][0],
            'LoopReturn', loop_id = head)
        self.nodes[r] = Node ('Basic', head, [])
        for r2 in rets:
            self.nodes[r2] = syntax.copy_rename (self.nodes[r2],
                ({}, {head: r}))
        return r

    def splittable_points (self, n):
        """Return the splittable points of the loop containing n.

        Splittable points are points which, when removed, cause the
        loop to 'split' and cease to be a loop. Equivalently, the set
        of splittable points is the intersection of all sub-loops of
        the loop.

        The result is cached per loop head, except that set() is
        returned (uncached) when the loop has an inner loop and
        logic.get_one_loop_splittable finds no split point."""
        head = self.loop_id (n)
        assert head is not None
        k = ('Splittables', head)
        if k in self.cached_analysis:
            return self.cached_analysis[k]

        # check if the head point is a split (the inner loop
        # check does exactly that)
        if has_inner_loop (self, head):
            head = logic.get_one_loop_splittable (self,
                self.loop_body (head))
            if head is None:
                return set ()

        splits = self.get_loop_splittables (head)
        self.cached_analysis[k] = splits
        return splits

    def get_loop_splittables (self, head):
        """Compute the splittable points of the loop through head.

        First, walk one cycle ('arc') from head back to head. At each
        step, take the first in-loop successor that is head or not yet
        on the arc; every arc node is a candidate.

        Then, for each arc position i, find the largest arc position at
        which a path from arc[i]'s in-loop successors first meets the
        arc. The arc nodes strictly between the two can be bypassed and
        are not splittable."""
        loop_set = self.loop_body (head)
        splittable = dict ([(n, False) for n in loop_set])
        arc = [head]
        n = head
        while True:
            ns = [n2 for n2 in self.nodes[n].get_conts ()
                if n2 in loop_set]
            ns2 = [x for x in ns if x == head or x not in arc]
            n = ns2[0]
            arc.append (n)
            splittable[n] = True
            if n == head:
                break
        # arc position of each arc node; head appears first and last,
        # so it ends up mapped to the final position
        last_descs = {}
        for (i, m) in enumerate (arc):
            last_descs[m] = i
        def last_desc (n):
            # memoised; the None placeholder breaks recursion on cycles
            if n in last_descs:
                return last_descs[n]
            n2s = [n2 for n2 in self.nodes[n].get_conts()
                if n2 in loop_set]
            last_descs[n] = None
            for n2 in n2s:
                x = last_desc(n2)
                if last_descs[n] is None or x >= last_descs[n]:
                    last_descs[n] = x
            return last_descs[n]
        for i in range (len (arc)):
            max_arc = max ([last_desc (n)
                for n in self.nodes[arc[i]].get_conts ()
                if n in loop_set])
            for j in range (i + 1, max_arc):
                splittable[arc[j]] = False
        return set ([n for n in splittable if splittable[n]])

    def loop_heads (self):
        """Return the heads of all loops found by do_loop_analysis."""
        return [n for n in self.loop_data
            if self.loop_data[n][0] == 'Head']

    def loop_id (self, n):
        """Return the head of the loop containing n, or None if n is
        not in a loop."""
        if n not in self.loop_data:
            return None
        elif self.loop_data[n][0] == 'Head':
            return n
        else:
            assert self.loop_data[n][0] == 'Mem'
            return self.loop_data[n][1]

    def loop_body (self, n):
        """Return the set of nodes in the loop containing n (KeyError if
        n is not in a loop)."""
        head = self.loop_id (n)
        return self.loop_data[head][1]

    def compute_preds (self):
        """Recompute self.preds from the current nodes."""
        self.preds = logic.compute_preds (self.nodes)

    def var_dep_outputs (self, n):
        """Return the outputs of the tag that node n belongs to."""
        return self.outputs[self.node_tags[n][0]]

    def compute_var_dependencies (self):
        """Return {node: {var: None}} for the variables that
        logic.compute_var_deps reports for each node (cached)."""
        if 'var_dependencies' in self.cached_analysis:
            return self.cached_analysis['var_dependencies']
        var_deps = logic.compute_var_deps (self.nodes,
            self.var_dep_outputs, self.preds)
        var_deps2 = dict ([(n, dict ([(v, None)
                for v in var_deps.get (n, [])]))
            for n in self.nodes])
        self.cached_analysis['var_dependencies'] = var_deps2
        return var_deps2

    def get_loop_var_analysis (self, var_deps, n):
        """Return logic.compute_loop_var_analysis for split point n.

        Results are memoised per (n, sorted loop body). A memoised
        result is reused only if every loop node, its predecessor list
        and its var_deps keys are unchanged; at most the 10 most recent
        variants are kept per key."""
        head = self.loop_id (n)
        assert head, n
        assert n in self.splittable_points (n)
        loop_sort = tuple (sorted (self.loop_body (head)))
        # per-node fingerprint (was preds[n] for every node, a typo
        # that made the predecessor part of the check meaningless)
        node_data = [(self.nodes[n2], sorted (self.preds[n2]),
                sorted (var_deps[n2].keys ()))
            for n2 in loop_sort]
        k = (n, loop_sort)
        data = (node_data, n)
        if k in self.loop_var_analysis_cache:
            for (data2, va) in self.loop_var_analysis_cache[k]:
                if data2 == data:
                    return va
        va = logic.compute_loop_var_analysis (self, var_deps, n)
        group = self.loop_var_analysis_cache.setdefault (k, [])
        group.append ((data, va))
        del group[:-10]
        return va

    def save_graph (self, fname):
        """Write this problem's graph to fname in dot format, coloured
        by tag."""
        cols = mk_graph_cols (self.node_tags)
        save_graph (self.nodes, fname, cols = cols,
            node_tags = self.node_tags)

    def save_graph_summ (self, fname):
        """Write a dot graph to fname with chains of trivial nodes
        collapsed.

        A node is trivial if it has exactly one predecessor and is a
        'Basic' node or a 'Cond' node whose right branch is 'Err'. Edges
        into a chain are redirected to the chain's first non-trivial
        successor. Requires self.preds to be current."""
        def triv_succ (n):
            # successor of trivial node n, or None if n is not trivial
            if n not in self.nodes:
                return None
            if len (self.preds[n]) != 1:
                return None
            node = self.nodes[n]
            if node.kind == 'Basic':
                return node.cont
            elif node.kind == 'Cond' and node.right == 'Err':
                return node.left
            return None
        node_ids = {}
        for start in self.nodes:
            if start in node_ids:
                continue
            chain = []
            seen = set ()
            m = start
            succ = triv_succ (m)
            # the seen-check stops on a cycle made only of trivial nodes
            while succ is not None and m not in seen:
                chain.append (m)
                seen.add (m)
                m = succ
                succ = triv_succ (m)
            for n2 in chain:
                node_ids[n2] = m
        nodes = {}
        for n in self.nodes:
            if triv_succ (n) is not None:
                continue
            nodes[n] = syntax.copy_rename (self.nodes[n],
                ({}, node_ids))
        cols = mk_graph_cols (self.node_tags)
        save_graph (nodes, fname, cols = cols,
            node_tags = self.node_tags)

    def serialise (self):
        """Return the problem as a list of text lines: 'Problem', one
        'Entry' line per entry (with inputs and outputs), one line per
        node, and 'EndProblem'. deserialise reads this format."""
        ss = ['Problem']
        for (n, tag, fname, inputs) in self.entries:
            xs = ['Entry', '%d' % n, tag, fname,
                '%d' % len (inputs)]
            for (nm, typ) in inputs:
                xs.append (nm)
                typ.serialise (xs)
            xs.append ('%d' % len (self.outputs[tag]))
            for (nm, typ) in self.outputs[tag]:
                xs.append (nm)
                typ.serialise (xs)
            ss.append (' '.join (xs))
        for n in self.nodes:
            xs = ['%d' % n]
            self.nodes[n].serialise (xs)
            ss.append (' '.join (xs))
        ss.append ('EndProblem')
        return ss

    def save_serialise (self, fname):
        """Write serialise()'s lines to fname, one per line."""
        ss = self.serialise ()
        with open (fname, 'w') as f:
            for s in ss:
                f.write (s + '\n')

    def pad_merge_points (self):
        """Insert an empty 'Basic' padding node on each arc into a node
        with several predecessors, unless the arc's source is already
        an empty 'Basic' node.

        preds are recomputed first but not afterwards."""
        self.compute_preds ()

        arcs = [(pred, n) for n in self.preds
            if len (self.preds[n]) > 1
            if n in self.nodes
            for pred in self.preds[n]
            if (self.nodes[pred].kind != 'Basic'
                or self.nodes[pred].upds != [])]

        for (pred, n) in arcs:
            (tag, _) = self.node_tags[pred]
            name = self.alloc_node (tag, 'MergePadding')
            self.nodes[name] = Node ('Basic', n, [])
            self.nodes[pred] = syntax.copy_rename (self.nodes[pred],
                ({}, {n: name}))

    def function_call_addrs (self):
        """Return (node, called function name) for every 'Call'
        node."""
        return [(n, self.nodes[n].fname)
            for n in self.nodes if self.nodes[n].kind == 'Call']

    def function_calls (self):
        """Return the set of names of functions called."""
        return set ([fn for (n, fn) in self.function_call_addrs ()])

    def get_extensions (self):
        """Return the union of syntax.get_extensions over all nodes
        (cached)."""
        if 'extensions' in self.cached_analysis:
            return self.cached_analysis['extensions']
        extensions = set ()
        for node in self.nodes.values ():
            extensions.update (syntax.get_extensions (node))
        self.cached_analysis['extensions'] = extensions
        return extensions

    def replay_inline_script (self, tag, script):
        """Re-apply inlining steps (detail, idx, fname), as recorded by
        inline_at_point, to the nodes of tag, then rerun do_analysis if
        any step was applied."""
        for (detail, idx, fname) in script:
            n = self.node_tag_revs[(tag, detail)][idx]
            assert self.nodes[n].kind == 'Call', self.nodes[n]
            assert self.nodes[n].fname == fname, self.nodes[n]
            inline_at_point (self, n, do_analysis = False)
        if script:
            self.do_analysis ()

    def is_reachable_from (self, source, target):
        '''discover if graph addr "target" is reachable
        from starting node "source"

        source itself counts as reachable only if it lies on a cycle.
        Results are cached per source.'''
        k = ('is_reachable_from', source)
        if k in self.cached_analysis:
            reachable = self.cached_analysis[k]
            if target in reachable:
                return reachable[target]

        reachable = {}
        visit = [source]
        while visit:
            n = visit.pop ()
            if n not in self.nodes:
                continue
            for n2 in self.nodes[n].get_conts ():
                if n2 not in reachable:
                    reachable[n2] = True
                    visit.append (n2)
        for n in list (self.nodes) + ['Ret', 'Err']:
            if n not in reachable:
                reachable[n] = False
        self.cached_analysis[k] = reachable
        return reachable[target]

    def is_reachable_without (self, cutpoint, target):
        '''discover if graph addr "target" is reachable
        without visiting node "cutpoint"
        (an oddity: cutpoint itself is considered reachable)

        Starts from all entry points and makes one pass over
        tarjan_order, so it requires do_analysis to have run. Results
        are cached per cutpoint.'''
        k = ('is_reachable_without', cutpoint)
        if k in self.cached_analysis:
            reachable = self.cached_analysis[k]
            if target in reachable:
                return reachable[target]

        reachable = dict ([(self.get_entry (t), True)
            for t in self.tags ()])
        for n in self.tarjan_order + ['Ret', 'Err']:
            if n in reachable:
                continue
            reachable[n] = bool ([pred for pred in self.preds[n]
                if pred != cutpoint
                if reachable.get (pred) == True])
        self.cached_analysis[k] = reachable
        return reachable[target]


def deserialise (name, lines):
    """Rebuild a Problem from lines in the format of Problem.serialise.

    The pairing is not reconstructed (pairing=None), nor are node tags
    or any analysis. next_node_name is set above every node name, so
    alloc_node cannot overwrite a deserialised node."""
    assert lines[0] == 'Problem', lines[0]
    assert lines[-1] == 'EndProblem', lines[-1]
    i = 1
    # not easy to reconstruct pairing
    p = Problem (pairing = None, name = name)
    while lines[i].startswith ('Entry'):
        bits = lines[i].split ()
        en = int (bits[1])
        tag = bits[2]
        fname = bits[3]
        (n, inputs) = syntax.parse_list (syntax.parse_lval, bits, 4)
        (n, outputs) = syntax.parse_list (syntax.parse_lval, bits, n)
        assert n == len (bits), (n, bits)
        p.entries.append ((en, tag, fname, inputs))
        p.outputs[tag] = outputs
        i += 1
    for i in range (i, len (lines) - 1):
        bits = lines[i].split ()
        n = int (bits[0])
        node = syntax.parse_node (bits, 1)
        p.nodes[n] = node
    # same numbering rule as Problem.clone_function
    p.next_node_name = max (list (p.nodes) + [2]) + 1
    return p

# trivia

def check_no_symbols (nodes, name = None):
    """Raise Abort if pseudo_compile.nodes_symbols reports any symbols
    in nodes. name (e.g. the problem's name) is used in the message.

    This was previously a NameError: the message referred to self."""
    # imported lazily, presumably to avoid an import cycle (confirm)
    import pseudo_compile
    symbs = pseudo_compile.nodes_symbols (nodes)
    if not symbs:
        return
    if name is None:
        name = 'problem'
    printout ('Aborting %s: undefined symbols %s' % (name, symbs))
    raise Abort ()

# printing of problem graphs

def sanitise_str (s):
    """Replace quote characters with '_' and delete spaces, for use in
    dot identifiers and labels."""
    return s.replace ('"', '_').replace ("'", "_").replace (' ', '')

def graph_name (nodes, node_tags, n, prev=None):
    """Return a dot-friendly descriptive name for node n, built from
    its number, tag and detail, prefixed by node kind ('fcall_' for
    calls, 'ass_' for basic nodes, none for conditionals).

    Special (string) nodes are named 't_<n>_<prev>'."""
    if type (n) == str:
        # %s rather than %d: prev defaults to None
        return 't_%s_%s' % (n, prev)
    if n not in nodes:
        return 'unknown_%d' % n
    if n not in node_tags:
        ident = '%d' % n
    else:
        (tag, details) = node_tags[n]
        if len (details) > 1 and logic.is_int (details[1]):
            ident = '%d_%s_%s_0x%x' % (n, tag,
                details[0], details[1])
        elif type (details) != str:
            details = '_'.join (map (str, details))
            ident = '%d_%s_%s' % (n, tag, details)
        else:
            ident = '%d_%s_%s' % (n, tag, details)
    ident = sanitise_str (ident)
    node = nodes[n]
    if node.kind == 'Call':
        return 'fcall_%s' % ident
    if node.kind == 'Cond':
        return ident
    if node.kind == 'Basic':
        return 'ass_%s' % ident
    assert not 'node kind understood'

def graph_node_tooltip (nodes, n):
    """Return a short human-readable description of node n."""
    if n == 'Err':
        return 'Error point'
    if n == 'Ret':
        return 'Return point'
    node = nodes[n]
    if node.kind == 'Call':
        return "%s: call to '%s'" % (n, sanitise_str (node.fname))
    if node.kind == 'Cond':
        return '%s: conditional node' % n
    if node.kind == 'Basic':
        var_names = [sanitise_str (x[0][0]) for x in node.upds]
        return '%s: assignment to [%s]' % (n, ', '.join (var_names))
    assert not 'node kind understood'

def graph_edges (nodes, n):
    """Return (target, label) edges out of node n: 'N' for a no-op,
    'T'/'F' for the branches of a 'Cond', 'C' otherwise."""
    node = nodes[n]
    if node.is_noop ():
        return [(node.get_conts () [0], 'N')]
    elif node.kind == 'Cond':
        return [(node.left, 'T'), (node.right, 'F')]
    else:
        return [(node.cont, 'C')]

def get_graph_font (n, col):
    """Return dot font attributes, coloured col if given (n is
    unused)."""
    font = 'fontname = "Arial", fontsize = 20, penwidth=3'
    if col:
        font = font + ', color=%s, fontcolor=%s' % (col, col)
    return font

def get_graph_loops (nodes):
    """Return the set of edges (n, n2) whose ends lie in the same
    strongly connected component, ignoring edges to string targets
    such as 'Ret'/'Err'."""
    graph = dict ([(n, [c for c in nodes[n].get_conts ()
            if type (c) != str]) for n in nodes])
    # artificial root so that tarjan visits every node
    graph['ENTRY'] = list (nodes)
    comps = logic.tarjan (graph, ['ENTRY'])
    comp_ids = {}
    for (head, tail) in comps:
        comp_ids[head] = head
        for n in tail:
            comp_ids[n] = head
    loops = set ([(n, n2) for n in graph for n2 in graph[n]
        if comp_ids[n] == comp_ids[n2]])
    return loops

def make_graph (nodes, cols, node_tags = {}, entries = []):
    """Return the lines of a dot digraph drawing nodes.

    Dot node IDs are the node names; labels come from graph_name. Each
    edge to 'Ret'/'Err' gets its own marker node, and edges within a
    cycle are drawn thicker. cols maps nodes to colours. entries is a
    list of (node, tag, _) triples drawn as entry-point markers."""
    graph = []
    graph.append ('digraph foo {')

    loops = get_graph_loops (nodes)

    for n in nodes:
        n_nm = graph_name (nodes, node_tags, n)
        f = get_graph_font (n, cols.get (n))
        graph.append ('%s [%s\n label="%s"\n tooltip="%s"];' % (n,
            f, n_nm, graph_node_tooltip (nodes, n)))
        for (c, l) in graph_edges (nodes, n):
            if c in ['Ret', 'Err']:
                c_nm = '%s_%s' % (c, n)
                if c == 'Ret':
                    f2 = f + ', shape=doubleoctagon'
                else:
                    f2 = f + ', shape=Mdiamond'
                graph.append ('%s [label="%s", %s];'
                    % (c_nm, c, f2))
            else:
                c_nm = c
            ft = f
            if (n, c) in loops:
                ft = f + ', penwidth=6'
            graph.append ('%s -> %s [label=%s, %s];' % (
                n, c_nm, l, ft))

    for (i, (n, tag, inps)) in enumerate (entries):
        f = get_graph_font (n, cols.get (n))
        nm1 = tag + ' ENTRY_POINT'
        nm2 = 'entry_point_%d' % i
        graph.extend (['%s -> %s [%s];' % (nm2, n, f),
            '%s [label = "%s", shape=none, %s];' % (nm2, nm1, f)])

    graph.append ('}')
    return graph

def print_graph (nodes, cols = {}, entries = []):
    """Print make_graph's dot output to stdout."""
    # entries must be passed by keyword: make_graph's third positional
    # parameter is node_tags
    for line in make_graph (nodes, cols, entries = entries):
        print (line)

def save_graph (nodes, fname, cols = {}, entries = [], node_tags = {}):
    """Write make_graph's dot output to fname."""
    lines = make_graph (nodes, cols = cols, node_tags = node_tags,
        entries = entries)
    with open (fname, 'w') as f:
        for line in lines:
            f.write (line + '\n')

def mk_graph_cols (node_tags):
    """Map each node whose tag is 'C', 'ASM_adj' or 'ASM' to a dot
    colour name."""
    known_cols = {'C': "forestgreen", 'ASM_adj': "darkblue",
        'ASM': "darkorange"}
    cols = {}
    for n in node_tags:
        if node_tags[n][0] in known_cols:
            cols[n] = known_cols[node_tags[n][0]]
    return cols

def make_graph_with_eqs (p, invis = False):
    """Return dot lines for p's graph plus a blue back-edge for each
    known equality in p.known_eqs (key 'Hyps' skipped). With invis the
    equality edges are invisible."""
    if invis:
        invis_s = ', style=invis'
    else:
        invis_s = ''
    cols = mk_graph_cols (p.node_tags)
    graph = make_graph (p.nodes, cols = cols)
    # drop the closing '}' so more edges can be appended
    graph.pop ()
    for k in p.known_eqs:
        if k == 'Hyps':
            continue
        (n_vc_x, tag_x) = k
        # make_graph uses the raw node name as each dot node's ID
        # (graph_name only supplies the label), so the equality edges
        # must use the same IDs to attach to the drawn nodes
        nm1 = n_vc_x[0]
        for (x, n_vc_y, tag_y, y, hyps) in p.known_eqs[k]:
            nm2 = n_vc_y[0]
            graph.append (('%s -> %s [ dir = back, color = blue, '
                'penwidth = 3, weight = 0 %s ]')
                % (nm2, nm1, invis_s))
    graph.append ('}')
    return graph

def save_graph_with_eqs (p, fname = 'diagram.dot', invis = False):
    """Write make_graph_with_eqs's output to fname."""
    graph = make_graph_with_eqs (p, invis = invis)
    with open (fname, 'w') as f:
        for s in graph:
            f.write (s + '\n')

def get_problem_vars (p):
    """Return a dict of all variables in p: the entry inputs, the
    outputs, and whatever syntax.get_node_vars adds for each node."""
    # set ().union tolerates a problem with no entries or outputs
    inout = set ().union (* ([set (xs) for xs in p.outputs.values ()]
        + [set (args) for (_, _, _, args) in p.entries]))

    vs = dict(inout)
    for node in p.nodes.values ():
        syntax.get_node_vars(node, vs)
    return vs

def is_trivial_fun (fun):
    """True if fun makes no calls, its 'Basic' nodes assign only
    variables or numerals, and its 'Cond' nodes test only a variable or
    a literal true/false."""
    for node in fun.nodes.values ():
        if node.is_noop ():
            continue
        if node.kind == 'Call':
            return False
        elif node.kind == 'Basic':
            for (lv, v) in node.upds:
                if v.kind not in ['Var', 'Num']:
                    return False
        elif node.kind == 'Cond':
            if node.cond.kind != 'Var' and node.cond not in [
                    true_term, false_term]:
                return False
    return True

# not referenced elsewhere in this module
last_alt_nodes = [0]

def avail_val (vs, typ):
    """Return a variable of type typ from vs ((name, type) pairs) if
    there is one, else logic.default_val (typ)."""
    for (nm, typ2) in vs:
        if typ2 == typ:
            return mk_var (nm, typ2)
    return logic.default_val (typ)

def inline_at_point (p, n, do_analysis = True):
    """Inline the function called at node n of p.

    The step is recorded in p.inline_scripts. The call node becomes a
    'Basic' node that assigns the arguments to the callee's (renamed)
    inputs and jumps to its entry. The callee's 'Ret' goes to a new
    'RetToCaller' node, which assigns the outputs to the call's return
    variables and continues to the call's successor.

    Returns the list of new node names, or None if n is not a 'Call'
    node. Runs do_analysis afterwards unless do_analysis is False."""
    node = p.nodes[n]
    if node.kind != 'Call':
        return

    f_nm = node.fname
    fun = functions[f_nm]
    (tag, detail) = p.node_tags[n]
    idx = p.node_tag_revs[(tag, detail)].index (n)
    p.inline_scripts[tag].append ((detail, idx, f_nm))

    trace ('Inlining %s into %s' % (f_nm, p.name))
    if n in p.loop_data:
        trace (' inlining into loop %d!' % p.loop_id (n))

    ex = p.alloc_node (tag, (f_nm, 'RetToCaller'))

    (ns, vs) = p.add_function (fun, tag, {'Ret': ex})
    en = ns[fun.entry]

    inp_lvs = [(vs[v], typ) for (v, typ) in fun.inputs]
    p.nodes[n] = Node ('Basic', en, azip (inp_lvs, node.args))

    out_vs = [mk_var (vs[v], typ) for (v, typ) in fun.outputs]
    p.nodes[ex] = Node ('Basic', node.cont, azip (node.rets, out_vs))

    p.cached_analysis.clear ()

    if do_analysis:
        p.do_analysis ()

    trace ('Problem size now %d' % len(p.nodes))
    # flush the trace output (this used to flush stdin, which is
    # meaningless)
    sys.stdout.flush ()

    return list (ns.values ())

def loop_body_inner_loops (p, head, loop_body):
    """Return the (head, tail) strongly connected components with a
    non-empty tail that remain in loop_body once edges into head are
    removed, i.e. the inner loops. Single-node self-loops are not
    reported."""
    loop_set_all = set (loop_body)
    loop_set = loop_set_all - set ([head])
    graph = dict([(n, [c for c in p.nodes[n].get_conts ()
            if c in loop_set])
        for n in loop_set_all])

    comps = logic.tarjan (graph, [head])
    assert sum ([1 + len (t) for (_, t) in comps]) == len (loop_set_all)
    return [comp for comp in comps if comp[1]]

def loop_inner_loops (p, head):
    """Return loop_body_inner_loops for the loop at head (cached)."""
    k = ('inner_loop_set', head)
    if k in p.cached_analysis:
        return p.cached_analysis[k]
    res = loop_body_inner_loops (p, head, p.loop_body (head))
    p.cached_analysis[k] = res
    return res

def loop_heads_including_inner (p):
    """Return the loop heads of p plus the heads of all nested inner
    loops, found repeatedly via loop_body_inner_loops."""
    heads = p.loop_heads ()
    check = [(head, p.loop_body (head)) for head in heads]
    while check:
        (head, body) = check.pop ()
        comps = loop_body_inner_loops (p, head, body)
        heads.extend ([h for (h, _) in comps])
        check.extend ([(h, [h] + list (t)) for (h, t) in comps])
    return heads

def check_no_inner_loop (p, head):
    """Raise Abort (after tracing the sub-loops) if the loop at head
    has inner loops."""
    subs = loop_inner_loops (p, head)
    if subs:
        printout ('Aborting %s, complex loop' % p.name)
        trace (' sub-loops %s of loop at %s' % (subs, head))
        for (h, _) in subs:
            trace (' head %d tagged %s' % (h, p.node_tags[h]))
        raise Abort ()

def has_inner_loop (p, head):
    """True if the loop at head has inner loops."""
    return bool (loop_inner_loops (p, head))

def fun_has_inner_loop (f):
    """True if any loop of function f has an inner loop."""
    p = f.as_problem (Problem)
    p.do_analysis ()
    return bool ([head for head in p.loop_heads ()
        if has_inner_loop (p, head)])

def loop_var_analysis (p, head, tail):
    """Return the set of variables that go round the loop (head plus
    tail).

    Walking forward from head, each node records the variables created
    (assigned) on the way to it; rvals read before being created are
    'used'. The result is the intersection of the used variables with
    those created by the loop's in-loop predecessors of head. Variables
    are the items of syntax.get_node_rvals, presumably (name, type)
    pairs (TODO confirm)."""
    # getting the set of variables that go round the loop
    nodes = set (tail)
    nodes.add (head)
    used_vs = set ([])
    created_vs_at = {}
    visit = []

    def process_node (n, created):
        if p.nodes[n].is_noop ():
            lvals = set ([])
        else:
            vs = syntax.get_node_rvals (p.nodes[n])
            for rv in vs.items ():
                if rv not in created:
                    used_vs.add (rv)
            lvals = set (p.nodes[n].get_lvals ())

        created = set.union (created, lvals)
        created_vs_at[n] = created

        visit.extend (p.nodes[n].get_conts ())

    process_node (head, set ([]))

    while visit:
        n = visit.pop ()
        if (n not in nodes) or (n in created_vs_at):
            continue
        # only process a node once all its predecessors are done
        if not all ([pr in created_vs_at for pr in p.preds[n]]):
            continue

        pre_created = [created_vs_at[pr] for pr in p.preds[n]]
        process_node (n, set.union (* pre_created))

    final_pre_created = [created_vs_at[pr] for pr in p.preds[head]
        if pr in nodes]
    created = set.union (* final_pre_created)

    loop_vs = set.intersection (created, used_vs)
    trace ('Loop vars at head: %s' % loop_vs)

    return loop_vs