1# * Copyright 2015, NICTA
2# *
3# * This software may be distributed and modified according to the terms of
4# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
5# * See "LICENSE_BSD2.txt" for details.
6# *
7# * @TAG(NICTA_BSD)
8
9from syntax import (Expr, mk_var, Node, true_term, false_term,
10  fresh_name, word32T, word8T, mk_eq, mk_word32, builtinTs)
11import syntax
12
13from target_objects import functions, pairings, trace, printout
14import sys
15import logic
16from logic import azip
17
class Abort(Exception):
	"""Raised to abandon the current problem, e.g. when a function is
	underspecified, undefined symbols are found, or a loop is too
	complex to handle."""
20
# Single-element holder for the most recently constructed Problem
# (Problem.__init__ stores itself in last_problem[0]); presumably a list
# so it can be updated without a `global` statement.
last_problem = [None]
22
class Problem:
	"""A graph problem assembled from one or more functions.

	Holds the nodes (keyed by integer node name), the entry points and
	output variables of each tagged function copy, per-node tags, and
	cached analyses (predecessors, loop structure, variable
	dependencies), which do_analysis recomputes."""

	def __init__ (self, pairing, name = None):
		"""Create an empty problem.

		name defaults to pairing.name. The new problem is also recorded
		in last_problem[0]."""
		if name is None:
			name = pairing.name
		self.name = 'Problem (%s)' % name
		self.pairing = pairing

		self.nodes = {}
		self.vs = {}
		self.next_node_name = 1
		self.preds = {}
		self.loop_data = {}
		self.node_tags = {}
		self.node_tag_revs = {}
		self.inline_scripts = {}
		self.entries = []
		self.outputs = {}
		self.tarjan_order = []
		self.loop_var_analysis_cache = {}

		self.known_eqs = {}
		self.cached_analysis = {}
		self.hook_tag_hints = {}

		last_problem[0] = self

	def fail_msg (self):
		"""Return a failure message naming this problem and its size."""
		return 'FAILED %s (size %05d)' % (self.name, len(self.nodes))

	def alloc_node (self, tag, detail, loop_id = None, hint = None):
		"""Allocate and return a fresh integer node name.

		The name is tagged (tag, detail) in node_tags and appended to
		node_tag_revs[(tag, detail)]. If loop_id is given, the node is
		recorded as a member of the loop headed by loop_id. hint is
		currently unused."""
		name = self.next_node_name
		self.next_node_name = name + 1

		self.node_tags[name] = (tag, detail)
		self.node_tag_revs.setdefault ((tag, detail), []).append (name)

		if loop_id is not None:
			self.loop_data[name] = ('Mem', loop_id)

		return name

	def fresh_var (self, name, typ):
		"""Return a variable of type typ.

		Its name is derived from name by syntax.fresh_name against
		self.vs (which presumably also reserves it there)."""
		name = fresh_name (name, self.vs, typ)
		return mk_var (name, typ)

	def clone_function (self, fun, tag):
		"""Replace this problem's graph with fun's reachable nodes.

		The nodes keep their original node names and are tagged with
		tag. Node objects are shared with fun, not copied, and fun
		becomes the only entry point."""
		self.nodes = {}
		self.vs = syntax.get_vars (fun)
		for n in fun.reachable_nodes ():
			self.nodes[n] = fun.nodes[n]
			detail = (fun.name, n)
			self.node_tags[n] = (tag, detail)
			self.node_tag_revs.setdefault ((tag, detail), []).append (n)
		self.outputs[tag] = fun.outputs
		self.entries = [(fun.entry, tag, fun.name, fun.inputs)]
		# new names are allocated above every cloned name (and above 2)
		self.next_node_name = max (list (self.nodes) + [2]) + 1
		self.inline_scripts[tag] = []

	def add_function (self, fun, tag, node_renames, loop_id = None):
		"""Copy fun's reachable nodes in under fresh node and variable names.

		node_renames maps fun's node names to problem node names and is
		extended in place. 'Ret' and 'Err' map to themselves unless the
		caller already mapped them.

		Returns (new_node_renames, vs): the node renames allocated by
		this call and the variable renaming used.

		Raises Abort if fun has no entry, or (via check_no_symbols) if
		its nodes mention undefined symbols."""
		if not fun.entry:
			printout ('Aborting %s: underspecified %s' % (
				self.name, fun.name))
			raise Abort ()
		node_renames.setdefault('Ret', 'Ret')
		node_renames.setdefault('Err', 'Err')
		new_node_renames = {}
		vs = syntax.get_vars (fun)
		vs = dict ([(v, fresh_name (v, self.vs, vs[v])) for v in vs])
		ns = fun.reachable_nodes ()
		check_no_symbols ([fun.nodes[n] for n in ns])
		for n in ns:
			assert n not in node_renames
			node_renames[n] = self.alloc_node (tag, (fun.name, n),
				loop_id = loop_id, hint = n)
			new_node_renames[n] = node_renames[n]
		# rename only after every name is allocated, so that forward
		# continuations resolve
		for n in ns:
			self.nodes[node_renames[n]] = syntax.copy_rename (
				fun.nodes[n], (vs, node_renames))

		return (new_node_renames, vs)

	def add_entry_function (self, fun, tag):
		"""Add fun (via add_function) as a new entry point tagged tag.

		Returns (args, rets, entry): the renamed input and output
		(name, type) lists and the entry node name."""
		(ns, vs) = self.add_function (fun, tag, {})

		entry = ns[fun.entry]
		args = [(vs[v], typ) for (v, typ) in fun.inputs]
		rets = [(vs[v], typ) for (v, typ) in fun.outputs]
		self.entries.append((entry, tag, fun.name, args))
		self.outputs[tag] = rets

		self.inline_scripts[tag] = []

		return (args, rets, entry)

	def get_entry_details (self, tag):
		"""Return (entry node, function name, args) for the entry tagged tag.

		The unpacking fails unless exactly one entry matches."""
		[(e, t, fname, args)] = [(e, t, fname, args)
			for (e, t, fname, args) in self.entries if t == tag]
		return (e, fname, args)

	def get_entry (self, tag):
		"""Return the entry node name of the entry tagged tag."""
		return self.get_entry_details (tag)[0]

	def tags (self):
		"""Return a list of the tags that have recorded outputs."""
		return list (self.outputs.keys ())

	def entry_exit_renames (self, tags = None):
		"""Compute, for each tag, renamings of the function's formal
		parameters to the actual variable names.

		'<tag>_IN' renames the inputs to the names used at the entry
		point; '<tag>_OUT' renames the outputs to the names used at the
		exit point. tags defaults to all tags."""
		mk = lambda xs, ys: dict ([(x[0], y[0]) for (x, y) in
			azip (xs, ys)])
		renames = {}
		if tags is None:
			tags = self.tags ()
		for tag in tags:
			(_, fname, args) = self.get_entry_details (tag)
			fun = functions[fname]
			out = self.outputs[tag]
			renames[tag + '_IN'] = mk (fun.inputs, args)
			renames[tag + '_OUT'] = mk (fun.outputs, out)
		return renames

	def redirect_conts (self, reds):
		"""Retarget continuations in place.

		Any continuation that is a key of reds is replaced by its value;
		Cond nodes have two continuations (left and right)."""
		for node in self.nodes.values ():
			if node.kind == 'Cond':
				node.left = reds.get(node.left, node.left)
				node.right = reds.get(node.right, node.right)
			else:
				node.cont = reds.get(node.cont, node.cont)

	def do_analysis (self):
		"""Clear cached analyses, then recompute preds and loop data."""
		self.cached_analysis.clear ()
		self.compute_preds ()
		self.do_loop_analysis ()

	def mk_node_graph (self, node_subset = None):
		"""Return the successor graph restricted to node_subset.

		The result is {n: [successors of n that are in node_subset]};
		node_subset defaults to all nodes."""
		if node_subset is None:
			node_subset = self.nodes
		return dict ([(n, [c for c in self.nodes[n].get_conts ()
				if c in node_subset])
			for n in node_subset])

	def do_loop_analysis (self):
		"""Recompute loop_data and tarjan_order from the entry points.

		Each component found by logic.tarjan that is a real loop (more
		than one node, or a node with a self-edge) has its head recorded
		as ('Head', loop_set) and its other members as ('Mem', head).
		force_single_loop_return may add a LoopReturn node to the
		loop."""
		entries = [e for (e, tag, nm, args) in self.entries]
		self.loop_data = {}

		graph = self.mk_node_graph ()
		comps = logic.tarjan (graph, entries)
		self.tarjan_order = []

		for (head, tail) in comps:
			self.tarjan_order.append (head)
			self.tarjan_order.extend (tail)
			# a single node without a self-edge is not a loop
			if not tail and head not in graph[head]:
				continue
			trace ('Loop (%d, %s)' % (head, tail))

			loop_set = set (tail)
			loop_set.add (head)

			r = self.force_single_loop_return (head, loop_set)
			if r is not None:
				tail.append (r)
				loop_set.add (r)
				self.tarjan_order.append (r)
				self.compute_preds ()

			self.loop_data[head] = ('Head', loop_set)
			for t in tail:
				self.loop_data[t] = ('Mem', head)

		# put this in first-to-last order.
		self.tarjan_order.reverse ()

	def check_no_inner_loops (self):
		"""Raise Abort if any loop contains an inner loop."""
		for loop in self.loop_heads ():
			check_no_inner_loop (self, loop)

	def force_single_loop_return (self, head, loop_set):
		"""Ensure head is re-entered from inside the loop through a
		single no-op node.

		If head's only in-loop predecessor is already a no-op node other
		than head, return None. Otherwise allocate a 'LoopReturn' Basic
		node that jumps to head, redirect every in-loop predecessor of
		head to it, and return the new node name. self.preds must be
		current."""
		rets = [n for n in self.preds[head] if n in loop_set]
		if (len (rets) == 1 and rets[0] != head and
				self.nodes[rets[0]].is_noop ()):
			return None
		r = self.alloc_node (self.node_tags[head][0],
			'LoopReturn', loop_id = head)
		self.nodes[r] = Node ('Basic', head, [])
		for r2 in rets:
			self.nodes[r2] = syntax.copy_rename (self.nodes[r2],
				({}, {head: r}))
		return r

	def splittable_points (self, n):
		"""Return the splittable points of the loop containing n.

		Splittable points are points which, when removed, 'split' the
		loop so that it ceases to be a loop. Equivalently, the set of
		splittable points is the intersection of all sub-loops of the
		loop."""
		head = self.loop_id (n)
		assert head is not None
		k = ('Splittables', head)
		if k in self.cached_analysis:
			return self.cached_analysis[k]

		# check if the head point is a split (the inner loop
		# check does exactly that)
		if has_inner_loop (self, head):
			head = logic.get_one_loop_splittable (self,
				self.loop_body (head))
			if head is None:
				return set ()

		splits = self.get_loop_splittables (head)
		self.cached_analysis[k] = splits
		return splits

	def get_loop_splittables (self, head):
		"""Compute the splittable points of the loop body at head.

		See splittable_points for the meaning of splittable."""
		loop_set = self.loop_body (head)
		splittable = dict ([(n, False) for n in loop_set])
		# walk one path around the loop from head back to head,
		# preferring successors not yet on the path; nodes on this
		# arc are the candidate splits
		arc = [head]
		n = head
		while True:
			ns = [n2 for n2 in self.nodes[n].get_conts ()
				if n2 in loop_set]
			ns2 = [x for x in ns if x == head or x not in arc]
			n = ns2[0]
			arc.append (n)
			splittable[n] = True
			if n == head:
				break
		# last_descs[n]: the furthest arc index reachable from n inside
		# the loop. Arc nodes map to their own index; None marks a node
		# whose computation is still in progress.
		last_descs = {}
		for i in range (len (arc)):
			last_descs[arc[i]] = i
		def last_desc (n):
			if n in last_descs:
				return last_descs[n]
			n2s = [n2 for n2 in self.nodes[n].get_conts()
				if n2 in loop_set]
			last_descs[n] = None
			for n2 in n2s:
				x = last_desc(n2)
				if last_descs[n] == None or x >= last_descs[n]:
					last_descs[n] = x
			return last_descs[n]
		# arc[i] can reach arc index max_arc without passing through
		# arc[i + 1 .. max_arc - 1], so those points do not split
		for i in range (len (arc)):
			max_arc = max ([last_desc (n)
				for n in self.nodes[arc[i]].get_conts ()
				if n in loop_set])
			for j in range (i + 1, max_arc):
				splittable[arc[j]] = False
		return set ([n for n in splittable if splittable[n]])

	def loop_heads (self):
		"""Return the list of loop head nodes."""
		return [n for n in self.loop_data
			if self.loop_data[n][0] == 'Head']

	def loop_id (self, n):
		"""Return the head of the loop containing n, or None if n is
		not in a loop."""
		if n not in self.loop_data:
			return None
		elif self.loop_data[n][0] == 'Head':
			return n
		else:
			assert self.loop_data[n][0] == 'Mem'
			return self.loop_data[n][1]

	def loop_body (self, n):
		"""Return the set of nodes in the loop containing n."""
		head = self.loop_id (n)
		return self.loop_data[head][1]

	def compute_preds (self):
		"""Recompute self.preds (predecessors) via logic.compute_preds."""
		self.preds = logic.compute_preds (self.nodes)

	def var_dep_outputs (self, n):
		"""Return the output variables of the function copy that n is
		tagged with."""
		return self.outputs[self.node_tags[n][0]]

	def compute_var_dependencies (self):
		"""Return {node: {var: None}} from logic.compute_var_deps.

		Every node gets an entry, possibly empty. The result is cached
		in cached_analysis."""
		if 'var_dependencies' in self.cached_analysis:
			return self.cached_analysis['var_dependencies']
		var_deps = logic.compute_var_deps (self.nodes,
			self.var_dep_outputs, self.preds)
		var_deps2 = dict ([(n, dict ([(v, None)
				for v in var_deps.get (n, [])]))
			for n in self.nodes])
		self.cached_analysis['var_dependencies'] = var_deps2
		return var_deps2

	def get_loop_var_analysis (self, var_deps, n):
		"""Return logic.compute_loop_var_analysis for split point n,
		memoised in self.loop_var_analysis_cache.

		The cache key is (n, sorted loop body). A hit additionally
		requires identical node contents, predecessors and variable
		dependencies for every loop node. At most the 10 most recent
		results are kept per key."""
		head = self.loop_id (n)
		assert head, n
		assert n in self.splittable_points (n)
		loop_sort = tuple (sorted (self.loop_body (head)))
		# fingerprint each loop node by its own contents, its own
		# predecessors and its variable dependencies
		node_data = [(self.nodes[n2], sorted (self.preds[n2]),
				sorted (var_deps[n2].keys ()))
			for n2 in loop_sort]
		k = (n, loop_sort)
		data = (node_data, n)
		if k in self.loop_var_analysis_cache:
			for (data2, va) in self.loop_var_analysis_cache[k]:
				if data2 == data:
					return va
		va = logic.compute_loop_var_analysis (self, var_deps, n)
		group = self.loop_var_analysis_cache.setdefault (k, [])
		group.append ((data, va))
		del group[:-10]
		return va

	def save_graph (self, fname):
		"""Save the problem graph to fname in graphviz format."""
		cols = mk_graph_cols (self.node_tags)
		save_graph (self.nodes, fname, cols = cols,
			node_tags = self.node_tags)

	def save_graph_summ (self, fname):
		"""Save a summarised graph to fname.

		Chains of trivial nodes are collapsed. A trivial node has one
		predecessor and is either a Basic node or a Cond node whose
		right branch is Err. Requires self.preds to be current."""
		node_ids = {}
		def is_triv (n):
			if n not in self.nodes:
				return False
			if len (self.preds[n]) != 1:
				return False
			node = self.nodes[n]
			if node.kind == 'Basic':
				return (True, node.cont)
			elif node.kind == 'Cond' and node.right == 'Err':
				return (True, node.left)
			else:
				return False
		for n in self.nodes:
			if n in node_ids:
				continue
			ns = []
			while is_triv (n):
				ns.append (n)
				n = is_triv (n)[1]
			for n2 in ns:
				node_ids[n2] = n
		nodes = {}
		for n in self.nodes:
			if is_triv (n):
				continue
			nodes[n] = syntax.copy_rename (self.nodes[n],
				({}, node_ids))
		cols = mk_graph_cols (self.node_tags)
		save_graph (nodes, fname, cols = cols,
			node_tags = self.node_tags)

	def serialise (self):
		"""Return the problem as a list of text lines.

		The lines are 'Problem', one 'Entry ...' line per entry, one
		line per node, then 'EndProblem'."""
		ss = ['Problem']
		for (n, tag, fname, inputs) in self.entries:
			xs = ['Entry', '%d' % n, tag, fname,
				'%d' % len (inputs)]
			for (nm, typ) in inputs:
				xs.append (nm)
				typ.serialise (xs)
			xs.append ('%d' % len (self.outputs[tag]))
			for (nm, typ) in self.outputs[tag]:
				xs.append (nm)
				typ.serialise (xs)
			ss.append (' '.join (xs))
		for n in self.nodes:
			xs = ['%d' % n]
			self.nodes[n].serialise (xs)
			ss.append (' '.join (xs))
		ss.append ('EndProblem')
		return ss

	def save_serialise (self, fname):
		"""Write the lines of serialise () to fname, one per line."""
		ss = self.serialise ()
		# the with block closes the file even if a write fails
		with open (fname, 'w') as f:
			for s in ss:
				f.write (s + '\n')

	def pad_merge_points (self):
		"""Insert empty 'MergePadding' Basic nodes on incoming edges.

		This applies to every node with several predecessors, on each
		edge whose source is not already an update-free Basic node."""
		self.compute_preds ()

		arcs = [(pred, n) for n in self.preds
			if len (self.preds[n]) > 1
			if n in self.nodes
			for pred in self.preds[n]
			if (self.nodes[pred].kind != 'Basic'
				or self.nodes[pred].upds != [])]

		for (pred, n) in arcs:
			(tag, _) = self.node_tags[pred]
			name = self.alloc_node (tag, 'MergePadding')
			self.nodes[name] = Node ('Basic', n, [])
			self.nodes[pred] = syntax.copy_rename (self.nodes[pred],
				({}, {n: name}))

	def function_call_addrs (self):
		"""Return [(node, called function name)] for every Call node."""
		return [(n, self.nodes[n].fname)
			for n in self.nodes if self.nodes[n].kind == 'Call']

	def function_calls (self):
		"""Return the set of function names called by the problem."""
		return set ([fn for (n, fn) in self.function_call_addrs ()])

	def get_extensions (self):
		"""Return (and cache) the union of syntax.get_extensions over
		all nodes."""
		if 'extensions' in self.cached_analysis:
			return self.cached_analysis['extensions']
		extensions = set ()
		for node in self.nodes.values ():
			extensions.update (syntax.get_extensions (node))
		self.cached_analysis['extensions'] = extensions
		return extensions

	def replay_inline_script (self, tag, script):
		"""Repeat recorded inlinings for tag.

		script holds (detail, idx, fname) items, as appended by
		inline_at_point. Analysis is redone once at the end if anything
		was inlined."""
		for (detail, idx, fname) in script:
			n = self.node_tag_revs[(tag, detail)][idx]
			assert self.nodes[n].kind == 'Call', self.nodes[n]
			assert self.nodes[n].fname == fname, self.nodes[n]
			inline_at_point (self, n, do_analysis = False)
		if script:
			self.do_analysis ()

	def is_reachable_from (self, source, target):
		'''Return whether graph address target can be reached from
		node source by following one or more continuation edges.

		source itself counts only if it lies on a cycle. Results are
		cached per source.'''
		k = ('is_reachable_from', source)
		if k in self.cached_analysis:
			reachable = self.cached_analysis[k]
			if target in reachable:
				return reachable[target]

		reachable = {}
		visit = [source]
		while visit:
			n = visit.pop ()
			if n not in self.nodes:
				continue
			for n2 in self.nodes[n].get_conts ():
				if n2 not in reachable:
					reachable[n2] = True
					visit.append (n2)
		for n in list (self.nodes) + ['Ret', 'Err']:
			if n not in reachable:
				reachable[n] = False
		self.cached_analysis[k] = reachable
		return reachable[target]

	def is_reachable_without (self, cutpoint, target):
		'''Return whether graph address target is reachable from the
		entry points without passing through node cutpoint.

		An oddity: cutpoint itself is considered reachable. Nodes are
		considered in tarjan_order, so only predecessors already
		processed contribute. Results are cached per cutpoint.'''
		k = ('is_reachable_without', cutpoint)
		if k in self.cached_analysis:
			reachable = self.cached_analysis[k]
			if target in reachable:
				return reachable[target]

		reachable = dict ([(self.get_entry (t), True)
			for t in self.tags ()])
		for n in self.tarjan_order + ['Ret', 'Err']:
			if n in reachable:
				continue
			reachable[n] = bool ([pred for pred in self.preds[n]
				if pred != cutpoint
				if reachable.get (pred) == True])
		self.cached_analysis[k] = reachable
		return reachable[target]
484
def deserialise (name, lines):
	"""Rebuild a Problem from lines produced by Problem.serialise.

	The pairing is not part of the serialised form, so the result has
	pairing None; node tags are not serialised either. New node names
	allocated later start above every deserialised node name.
	"""
	assert lines[0] == 'Problem', lines[0]
	assert lines[-1] == 'EndProblem', lines[-1]
	i = 1
	# not easy to reconstruct pairing
	p = Problem (pairing = None, name = name)
	while lines[i].startswith ('Entry'):
		bits = lines[i].split ()
		en = int (bits[1])
		tag = bits[2]
		fname = bits[3]
		(n, inputs) = syntax.parse_list (syntax.parse_lval, bits, 4)
		(n, outputs) = syntax.parse_list (syntax.parse_lval, bits, n)
		assert n == len (bits), (n, bits)
		p.entries.append ((en, tag, fname, inputs))
		p.outputs[tag] = outputs
		# as in Problem.add_entry_function: an empty inlining record
		p.inline_scripts.setdefault (tag, [])
		i += 1
	for line in lines[i:-1]:
		bits = line.split ()
		n = int (bits[0])
		node = syntax.parse_node (bits, 1)
		p.nodes[n] = node
	# same rule as Problem.clone_function, so that alloc_node cannot
	# hand out a name that overwrites a deserialised node
	p.next_node_name = max (list (p.nodes) + [2]) + 1
	return p
508
509# trivia
510
def check_no_symbols (nodes, name = None):
	"""Raise Abort if pseudo_compile.nodes_symbols reports any symbols
	used by the given nodes.

	name, if given, identifies the problem in the abort message. There
	is no problem object in scope here, so it must be passed in.
	"""
	import pseudo_compile
	symbs = pseudo_compile.nodes_symbols (nodes)
	if not symbs:
		return
	if name is None:
		printout ('Aborting: undefined symbols %s' % (symbs, ))
	else:
		printout ('Aborting %s: undefined symbols %s' % (name, symbs))
	raise Abort ()
518
519# printing of problem graphs
520
def sanitise_str (s):
	"""Make s safe to embed in graphviz output.

	Double and single quotes become underscores and spaces are removed,
	applied in that order."""
	for (bad, good) in (('"', '_'), ("'", '_'), (' ', '')):
		s = s.replace (bad, good)
	return s
523
def graph_name (nodes, node_tags, n, prev=None):
	"""Return a graphviz-friendly name for node n.

	String node names (such as 'Ret' or 'Err') are combined with prev.
	Names missing from nodes give 'unknown_<n>'. Otherwise the name is
	built from n and, when present, its (tag, details) entry in
	node_tags, then prefixed according to the node kind.
	"""
	if isinstance (n, str):
		# '%s' rather than '%d': prev defaults to None
		return 't_%s_%s' % (n, prev)
	if n not in nodes:
		return 'unknown_%d' % n
	if n not in node_tags:
		ident = '%d' % n
	else:
		(tag, details) = node_tags[n]
		if len (details) > 1 and logic.is_int (details[1]):
			ident = '%d_%s_%s_0x%x' % (n, tag,
				details[0], details[1])
		else:
			if type (details) != str:
				details = '_'.join (map (str, details))
			ident = '%d_%s_%s' % (n, tag, details)
	ident = sanitise_str (ident)
	node = nodes[n]
	if node.kind == 'Call':
		return 'fcall_%s' % ident
	if node.kind == 'Cond':
		return ident
	if node.kind == 'Basic':
		return 'ass_%s' % ident
	assert not 'node kind understood'
550
def graph_node_tooltip (nodes, n):
	"""Return the tooltip text describing node n in a graphviz rendering."""
	special = {'Err': 'Error point', 'Ret': 'Return point'}
	if n in special:
		return special[n]
	node = nodes[n]
	kind = node.kind
	if kind == 'Call':
		return "%s: call to '%s'" % (n, sanitise_str (node.fname))
	elif kind == 'Cond':
		return '%s: conditional node' % n
	elif kind == 'Basic':
		# each update's first element is the assigned lval; its first
		# element is the variable name
		assigned = ', '.join ([sanitise_str (upd[0][0])
			for upd in node.upds])
		return '%s: assignment to [%s]' % (n, assigned)
	assert not 'node kind understood'
565
def graph_edges (nodes, n):
	"""Return the outgoing edges of node n as (target, label) pairs.

	Labels: 'N' for a no-op node (its first continuation only), 'T' and
	'F' for the two branches of a Cond, and 'C' for the continuation of
	any other node."""
	node = nodes[n]
	if node.is_noop ():
		edges = [(node.get_conts ()[0], 'N')]
	elif node.kind == 'Cond':
		edges = [(node.left, 'T'), (node.right, 'F')]
	else:
		edges = [(node.cont, 'C')]
	return edges
574
def get_graph_font (n, col):
	"""Return the graphviz attribute string for drawing node n.

	A non-empty col sets both the outline and the text colour. n itself
	is not used."""
	attrs = ['fontname = "Arial"', 'fontsize = 20', 'penwidth=3']
	if col:
		attrs.extend (['color=%s' % col, 'fontcolor=%s' % col])
	return ', '.join (attrs)
580
def get_graph_loops (nodes):
	"""Return the set of edges (n, n2) whose endpoints lie in the same
	component reported by logic.tarjan, i.e. edges that lie on a cycle.

	String continuations ('Ret', 'Err') are ignored. A synthetic
	'ENTRY' vertex with an edge to every node is used as the only root,
	so every node is visited."""
	graph = {}
	for n in nodes:
		graph[n] = [c for c in nodes[n].get_conts () if type (c) != str]
	graph['ENTRY'] = list (nodes)

	component_of = {}
	for (head, tail) in logic.tarjan (graph, ['ENTRY']):
		for member in [head] + list (tail):
			component_of[member] = head

	return set ([(n, n2) for n in graph for n2 in graph[n]
		if component_of[n] == component_of[n2]])
594
def make_graph (nodes, cols, node_tags = {}, entries = []):
	"""Render nodes as graphviz 'dot' source, returned as a list of lines.

	Nodes use their node name as graphviz id and graph_name as label.
	Every edge into 'Ret'/'Err' gets its own exit vertex. Edges that lie
	on a loop are drawn thicker. cols maps node names to colours.
	entries is a list of (node, tag, inputs) triples drawn as labelled
	entry arrows."""
	loops = get_graph_loops (nodes)
	lines = ['digraph foo {']

	for n in nodes:
		label = graph_name (nodes, node_tags, n)
		font = get_graph_font (n, cols.get (n))
		lines.append ('%s [%s\n label="%s"\n tooltip="%s"];' % (n,
			font, label, graph_node_tooltip (nodes, n)))
		for (c, edge_label) in graph_edges (nodes, n):
			if c in ('Ret', 'Err'):
				# a separate exit vertex per predecessor keeps the
				# drawing readable
				target = '%s_%s' % (c, n)
				if c == 'Ret':
					shape = 'doubleoctagon'
				else:
					shape = 'Mdiamond'
				lines.append ('%s [label="%s", %s];'
					% (target, c, font + ', shape=' + shape))
			else:
				target = c
			if (n, c) in loops:
				edge_font = font + ', penwidth=6'
			else:
				edge_font = font
			lines.append ('%s -> %s [label=%s, %s];' % (
				n, target, edge_label, edge_font))

	for (i, (n, tag, inps)) in enumerate (entries):
		font = get_graph_font (n, cols.get (n))
		point = 'entry_point_%d' % i
		lines.append ('%s -> %s [%s];' % (point, n, font))
		lines.append ('%s [label = "%s", shape=none, %s];'
			% (point, tag + ' ENTRY_POINT', font))

	lines.append ('}')
	return lines
632
def print_graph (nodes, cols = {}, entries = []):
	"""Print the graphviz rendering of nodes to standard output."""
	# entries must go by keyword: make_graph's third positional
	# parameter is node_tags
	for line in make_graph (nodes, cols, entries = entries):
		# parenthesised single argument prints the same in Python 2 and 3
		print (line)
636
def save_graph (nodes, fname, cols = {}, entries = [], node_tags = {}):
	"""Write the graphviz rendering of nodes to fname, one line each."""
	# the with block closes the file even if rendering or writing fails
	with open (fname, 'w') as f:
		for line in make_graph (nodes, cols = cols,
				node_tags = node_tags, entries = entries):
			f.write (line + '\n')
643
def mk_graph_cols (node_tags):
	"""Map each node to a display colour chosen by its tag.

	Tags 'C', 'ASM_adj' and 'ASM' have colours; nodes with any other
	tag are omitted from the result."""
	palette = {'C': "forestgreen", 'ASM_adj': "darkblue",
		'ASM': "darkorange"}
	cols = {}
	for (n, tag_detail) in node_tags.items ():
		colour = palette.get (tag_detail[0])
		if colour is not None:
			cols[n] = colour
	return cols
652
def make_graph_with_eqs (p, invis = False):
	"""Return make_graph output for problem p with an extra blue edge
	for each known equality in p.known_eqs (the 'Hyps' entry is skipped).

	With invis the extra edges are drawn invisible. Edge endpoints are
	node names, because make_graph uses node names as graphviz ids.
	graph_name only supplies the labels, so using it here would attach
	the edges to new, unconnected vertices.
	"""
	if invis:
		invis_s = ', style=invis'
	else:
		invis_s = ''
	cols = mk_graph_cols (p.node_tags)
	graph = make_graph (p.nodes, cols = cols)
	# drop the closing brace, add the equality edges, then re-close
	graph.pop ()
	for k in p.known_eqs:
		if k == 'Hyps':
			continue
		(n_vc_x, tag_x) = k
		for (x, n_vc_y, tag_y, y, hyps) in p.known_eqs[k]:
			graph.append (('%s -> %s [ dir = back, color = blue, '
				'penwidth = 3, weight = 0 %s ]')
					% (n_vc_y[0], n_vc_x[0], invis_s))
	graph.append ('}')
	return graph
673
def save_graph_with_eqs (p, fname = 'diagram.dot', invis = False):
	"""Write make_graph_with_eqs (p, invis) to fname, one line each."""
	graph = make_graph_with_eqs (p, invis = invis)
	# the with block closes the file even if a write fails
	with open (fname, 'w') as f:
		for s in graph:
			f.write (s + '\n')
680
def get_problem_vars (p):
	"""Return a dict mapping variable name to type for problem p.

	It starts from the (name, type) pairs of all entry arguments and
	outputs. Each node is then passed to syntax.get_node_vars together
	with this dict, which presumably adds that node's variables.

	A problem with no entries and no outputs yields an empty starting
	dict rather than raising.
	"""
	inout = set ()
	for xs in p.outputs.values ():
		inout.update (xs)
	for (_, _, _, args) in p.entries:
		inout.update (args)

	vs = dict (inout)
	for node in p.nodes.values ():
		syntax.get_node_vars (node, vs)
	return vs
689
def is_trivial_fun (fun):
	"""Return True if fun is trivial, ignoring no-op nodes.

	Trivial means it makes no calls, its Basic nodes assign only 'Var'
	or 'Num' expressions, and its Cond nodes branch only on a 'Var' or
	on the constant true/false terms."""
	for node in fun.nodes.values ():
		if node.is_noop ():
			continue
		kind = node.kind
		if kind == 'Call':
			return False
		if kind == 'Basic':
			if any (v.kind not in ['Var', 'Num']
					for (lv, v) in node.upds):
				return False
		elif kind == 'Cond':
			cond = node.cond
			if cond.kind != 'Var' and cond not in [
					true_term, false_term]:
				return False
	return True
705
# NOTE(review): not referenced in the visible part of this module;
# presumably a mutable single-element holder used elsewhere - confirm.
last_alt_nodes = [0]
707
def avail_val (vs, typ):
	"""Return a value of type typ.

	vs is a sequence of (name, type) pairs. The first one whose type
	equals typ is returned as a variable expression; otherwise the
	result is logic.default_val (typ)."""
	for (name, var_typ) in vs:
		if var_typ == typ:
			break
	else:
		return logic.default_val (typ)
	return mk_var (name, var_typ)
713
def inline_at_point (p, n, do_analysis = True):
	"""Inline the function called at node n of problem p, in place.

	Does nothing (returns None) unless n is a Call node. Otherwise:
	- the callee is added via p.add_function, with its 'Ret' redirected
	  to a new 'RetToCaller' node;
	- node n becomes a Basic node that assigns the call arguments to the
	  callee's renamed inputs and jumps to the callee's entry;
	- the 'RetToCaller' node assigns the callee's outputs to the call's
	  return variables and continues at the call's continuation.

	The inlining is recorded in p.inline_scripts[tag] so that
	Problem.replay_inline_script can repeat it. Cached analyses are
	cleared and recomputed here only when do_analysis is true.

	Returns the list of node names allocated for the callee's nodes.
	"""
	node = p.nodes[n]
	if node.kind != 'Call':
		return

	f_nm = node.fname
	fun = functions[f_nm]
	(tag, detail) = p.node_tags[n]
	idx = p.node_tag_revs[(tag, detail)].index (n)
	p.inline_scripts[tag].append ((detail, idx, f_nm))

	trace ('Inlining %s into %s' % (f_nm, p.name))
	if n in p.loop_data:
		trace ('  inlining into loop %d!' % p.loop_id (n))

	ex = p.alloc_node (tag, (f_nm, 'RetToCaller'))

	(ns, vs) = p.add_function (fun, tag, {'Ret': ex})
	en = ns[fun.entry]

	inp_lvs = [(vs[v], typ) for (v, typ) in fun.inputs]
	p.nodes[n] = Node ('Basic', en, azip (inp_lvs, node.args))

	out_vs = [mk_var (vs[v], typ) for (v, typ) in fun.outputs]
	p.nodes[ex] = Node ('Basic', node.cont, azip (node.rets, out_vs))

	p.cached_analysis.clear ()

	if do_analysis:
		p.do_analysis ()

	trace ('Problem size now %d' % len(p.nodes))
	# flush stdout so the progress output appears promptly; flushing an
	# input stream such as stdin can discard buffered input
	sys.stdout.flush ()

	return list (ns.values ())
749
def loop_body_inner_loops (p, head, loop_body):
	"""Return the components nested inside the loop at head.

	logic.tarjan is run from head over loop_body, with every edge back
	into head removed. Only the (head, tail) components with a non-empty
	tail are returned."""
	whole_body = set (loop_body)
	without_head = whole_body - set ([head])
	graph = {}
	for member in whole_body:
		graph[member] = [c for c in p.nodes[member].get_conts ()
			if c in without_head]

	comps = logic.tarjan (graph, [head])
	# the components must account for every node of the body exactly once
	assert sum ([1 + len (tail) for (_, tail) in comps]) == len (whole_body)
	return [comp for comp in comps if comp[1]]
760
def loop_inner_loops (p, head):
	"""Return loop_body_inner_loops for the loop at head.

	The result is cached in p.cached_analysis."""
	key = ('inner_loop_set', head)
	cache = p.cached_analysis
	if key not in cache:
		cache[key] = loop_body_inner_loops (p, head, p.loop_body (head))
	return cache[key]
768
def loop_heads_including_inner (p):
	"""Return the heads of all loops in p, including inner loops.

	Inner loops are found recursively with loop_body_inner_loops;
	outermost heads come first."""
	heads = p.loop_heads ()
	pending = [(h, p.loop_body (h)) for h in heads]
	while pending:
		(outer, body) = pending.pop ()
		for (inner, tail) in loop_body_inner_loops (p, outer, body):
			heads.append (inner)
			pending.append ((inner, [inner] + list (tail)))
	return heads
779
def check_no_inner_loop (p, head):
	"""Raise Abort, after logging the sub-loops, if the loop at head
	contains any inner loop."""
	subs = loop_inner_loops (p, head)
	if not subs:
		return
	printout ('Aborting %s, complex loop' % p.name)
	trace ('  sub-loops %s of loop at %s' % (subs, head))
	for (sub_head, _) in subs:
		trace ('    head %d tagged %s' % (sub_head, p.node_tags[sub_head]))
	raise Abort ()
788
def has_inner_loop (p, head):
	"""Return True if the loop at head contains an inner loop."""
	return len (loop_inner_loops (p, head)) > 0
791
def fun_has_inner_loop (f):
	"""Return True if function f, converted to a Problem and analysed,
	has a loop that contains an inner loop."""
	p = f.as_problem (Problem)
	p.do_analysis ()
	# every head is checked (no early exit), as each check fills caches
	found = False
	for head in p.loop_heads ():
		if has_inner_loop (p, head):
			found = True
	return found
797
def loop_var_analysis (p, head, tail):
	"""Estimate the set of variables carried around the loop at head.

	tail holds the loop's other members. The body is walked from head;
	a node is processed only once all its predecessors have been
	processed. Along the way the walk tracks which variables have been
	assigned ('created'), taking the union over predecessors. A variable
	read before being created is recorded as used.

	Returns the intersection of the used set with the union of the
	variables created at head's in-loop predecessors.
	"""
	# getting the set of variables that go round the loop
	nodes = set (tail)
	nodes.add (head)
	used_vs = set ([])
	created_vs_at = {}
	visit = []

	def process_node (n, created):
		# record the variables n reads that are not yet created, then
		# extend created with the variables n assigns
		if p.nodes[n].is_noop ():
			lvals = set ([])
		else:
			vs = syntax.get_node_rvals (p.nodes[n])
			# NOTE(review): iteritems yields (key, value) pairs, so rv is
			# presumably a (name, type) pair; it is compared against the
			# get_lvals results - confirm those have the same shape
			for rv in vs.iteritems ():
				if rv not in created:
					used_vs.add (rv)
			lvals = set (p.nodes[n].get_lvals ())

		created = set.union (created, lvals)
		created_vs_at[n] = created

		visit.extend (p.nodes[n].get_conts ())

	process_node (head, set ([]))

	while visit:
		n = visit.pop ()
		# skip nodes outside the loop and nodes already processed
		if (n not in nodes) or (n in created_vs_at):
			continue
		# wait until every predecessor is processed; n is pushed again
		# each time one of its predecessors is processed. A predecessor
		# outside the loop is never processed, so such an n is skipped.
		if not all ([pr in created_vs_at for pr in p.preds[n]]):
			continue

		pre_created = [created_vs_at[pr] for pr in p.preds[n]]
		process_node (n, set.union (* pre_created))

	# NOTE(review): this raises KeyError if an in-loop predecessor of
	# head was never processed, and set.union () with no arguments
	# raises TypeError if head has no in-loop predecessor
	final_pre_created = [created_vs_at[pr] for pr in p.preds[head]
		if pr in nodes]
	created = set.union (* final_pre_created)

	loop_vs = set.intersection (created, used_vs)
	trace ('Loop vars at head: %s' % loop_vs)

	return loop_vs
841
842
843