1/*
2 * Copyright 2008-2009 Katholieke Universiteit Leuven
3 *
4 * Use of this software is governed by the MIT license
5 *
6 * Written by Sven Verdoolaege, K.U.Leuven, Departement
7 * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
8 */
9
10#include <isl_ctx_private.h>
11#include <isl_map_private.h>
12#include <isl/set.h>
13#include <isl/map.h>
14#include <isl/mat.h>
15#include <isl/seq.h>
16#include "isl_piplib.h"
17#include "isl_map_piplib.h"
18
19static void copy_values_from(isl_int *dst, Entier *src, unsigned n)
20{
21	int i;
22
23	for (i = 0; i < n; ++i)
24		entier_assign(dst[i], src[i]);
25}
26
27static void add_value(isl_int *dst, Entier *src)
28{
29	mpz_add(*dst, *dst, *src);
30}
31
32static void copy_constraint_from(isl_int *dst, PipVector *src,
33		unsigned nparam, unsigned n_in, unsigned n_out,
34		unsigned extra, int *pos)
35{
36	int i;
37
38	copy_values_from(dst, src->the_vector+src->nb_elements-1, 1);
39	copy_values_from(dst+1, src->the_vector, nparam+n_in);
40	isl_seq_clr(dst+1+nparam+n_in, n_out);
41	isl_seq_clr(dst+1+nparam+n_in+n_out, extra);
42	for (i = 0; i + n_in + nparam < src->nb_elements-1; ++i) {
43		int p = pos[i];
44		add_value(&dst[1+nparam+n_in+n_out+p],
45			  &src->the_vector[n_in+nparam+i]);
46	}
47}
48
49static int add_inequality(struct isl_ctx *ctx,
50		   struct isl_basic_map *bmap, int *pos, PipVector *vec)
51{
52	unsigned nparam = isl_basic_map_n_param(bmap);
53	unsigned n_in = isl_basic_map_n_in(bmap);
54	unsigned n_out = isl_basic_map_n_out(bmap);
55	unsigned n_div = isl_basic_map_n_div(bmap);
56	int i = isl_basic_map_alloc_inequality(bmap);
57	if (i < 0)
58		return -1;
59	copy_constraint_from(bmap->ineq[i], vec,
60	    nparam, n_in, n_out, n_div, pos);
61
62	return i;
63}
64
65/* For a div d = floor(f/m), add the constraints
66 *
67 *		f - m d >= 0
68 *		-(f-(n-1)) + m d >= 0
69 *
70 * Note that the second constraint is the negation of
71 *
72 *		f - m d >= n
73 */
74static int add_div_constraints(struct isl_ctx *ctx,
75	struct isl_basic_map *bmap, int *pos, PipNewparm *p, unsigned div)
76{
77	int i, j;
78	unsigned total = isl_basic_map_total_dim(bmap);
79	unsigned div_pos = 1 + total - bmap->n_div + div;
80
81	i = add_inequality(ctx, bmap, pos, p->vector);
82	if (i < 0)
83		return -1;
84	copy_values_from(&bmap->ineq[i][div_pos], &p->deno, 1);
85	isl_int_neg(bmap->ineq[i][div_pos], bmap->ineq[i][div_pos]);
86
87	j = isl_basic_map_alloc_inequality(bmap);
88	if (j < 0)
89		return -1;
90	isl_seq_neg(bmap->ineq[j], bmap->ineq[i], 1 + total);
91	isl_int_add(bmap->ineq[j][0], bmap->ineq[j][0], bmap->ineq[j][div_pos]);
92	isl_int_sub_ui(bmap->ineq[j][0], bmap->ineq[j][0], 1);
93	return j;
94}
95
96static int add_equality(struct isl_ctx *ctx,
97		   struct isl_basic_map *bmap, int *pos,
98		   unsigned var, PipVector *vec)
99{
100	int i;
101	unsigned nparam = isl_basic_map_n_param(bmap);
102	unsigned n_in = isl_basic_map_n_in(bmap);
103	unsigned n_out = isl_basic_map_n_out(bmap);
104
105	isl_assert(ctx, var < n_out, return -1);
106
107	i = isl_basic_map_alloc_equality(bmap);
108	if (i < 0)
109		return -1;
110	copy_constraint_from(bmap->eq[i], vec,
111	    nparam, n_in, n_out, bmap->extra, pos);
112	isl_int_set_si(bmap->eq[i][1+nparam+n_in+var], -1);
113
114	return i;
115}
116
117static int find_div(struct isl_ctx *ctx,
118		   struct isl_basic_map *bmap, int *pos, PipNewparm *p)
119{
120	int i, j;
121	unsigned nparam = isl_basic_map_n_param(bmap);
122	unsigned n_in = isl_basic_map_n_in(bmap);
123	unsigned n_out = isl_basic_map_n_out(bmap);
124
125	i = isl_basic_map_alloc_div(bmap);
126	if (i < 0)
127		return -1;
128
129	copy_constraint_from(bmap->div[i]+1, p->vector,
130	    nparam, n_in, n_out, bmap->extra, pos);
131
132	copy_values_from(bmap->div[i], &p->deno, 1);
133	for (j = 0; j < i; ++j)
134		if (isl_seq_eq(bmap->div[i], bmap->div[j],
135				1+1+isl_basic_map_total_dim(bmap)+j)) {
136			isl_basic_map_free_div(bmap, 1);
137			return j;
138		}
139
140	if (add_div_constraints(ctx, bmap, pos, p, i) < 0)
141		return -1;
142
143	return i;
144}
145
146/* Count some properties of a quast
147 * - maximal number of new parameters
148 * - maximal depth
149 * - total number of solutions
150 * - total number of empty branches
151 */
152static void quast_count(PipQuast *q, int *maxnew, int depth, int *maxdepth,
153		        int *sol, int *nosol)
154{
155	PipNewparm *p;
156
157	for (p = q->newparm; p; p = p->next)
158		if (p->rank > *maxnew)
159			*maxnew = p->rank;
160	if (q->condition) {
161		if (++depth > *maxdepth)
162			*maxdepth = depth;
163		quast_count(q->next_else, maxnew, depth, maxdepth, sol, nosol);
164		quast_count(q->next_then, maxnew, depth, maxdepth, sol, nosol);
165	} else {
166		if (q->list)
167			++(*sol);
168		else
169			++(*nosol);
170	}
171}
172
/*
 * State shared by the recursive scan of a quast.
 *
 * ctx: the isl context handed to the helper functions
 * bmap: collects the currently active constraints
 * rest: collects the empty leaves of the quast (if not NULL)
 * pos: array of length bmap->extra, mapping each of the existential
 *		variables PIP proposes to an existential variable in bmap
 */
struct scan_data {
	struct isl_ctx *ctx;
	struct isl_basic_map *bmap;
	struct isl_set **rest;
	int *pos;
};
185
186/*
187 * New existentially quantified variables are places after the existing ones.
188 */
189static struct isl_map *scan_quast_r(struct scan_data *data, PipQuast *q,
190				    struct isl_map *map)
191{
192	PipNewparm *p;
193	struct isl_basic_map *bmap = data->bmap;
194	unsigned old_n_div = bmap->n_div;
195	unsigned nparam = isl_basic_map_n_param(bmap);
196	unsigned n_in = isl_basic_map_n_in(bmap);
197	unsigned n_out = isl_basic_map_n_out(bmap);
198
199	if (!map)
200		goto error;
201
202	for (p = q->newparm; p; p = p->next) {
203		int pos;
204		unsigned pip_param = nparam + n_in;
205
206		pos = find_div(data->ctx, bmap, data->pos, p);
207		if (pos < 0)
208			goto error;
209		data->pos[p->rank - pip_param] = pos;
210	}
211
212	if (q->condition) {
213		int pos = add_inequality(data->ctx, bmap, data->pos,
214					 q->condition);
215		if (pos < 0)
216			goto error;
217		map = scan_quast_r(data, q->next_then, map);
218
219		if (isl_inequality_negate(bmap, pos))
220			goto error;
221		map = scan_quast_r(data, q->next_else, map);
222
223		if (isl_basic_map_free_inequality(bmap, 1))
224			goto error;
225	} else if (q->list) {
226		PipList *l;
227		int j;
228		/* if bmap->n_out is zero, we are only interested in the domains
229		 * where a solution exists and not in the actual solution
230		 */
231		for (j = 0, l = q->list; j < n_out && l; ++j, l = l->next)
232			if (add_equality(data->ctx, bmap, data->pos, j,
233						l->vector) < 0)
234				goto error;
235		map = isl_map_add_basic_map(map, isl_basic_map_copy(bmap));
236		if (isl_basic_map_free_equality(bmap, n_out))
237			goto error;
238	} else if (data->rest) {
239		struct isl_basic_set *bset;
240		bset = isl_basic_set_from_basic_map(isl_basic_map_copy(bmap));
241		bset = isl_basic_set_drop_dims(bset, n_in, n_out);
242		if (!bset)
243			goto error;
244		*data->rest = isl_set_add_basic_set(*data->rest, bset);
245	}
246
247	if (isl_basic_map_free_inequality(bmap, 2*(bmap->n_div - old_n_div)))
248		goto error;
249	if (isl_basic_map_free_div(bmap, bmap->n_div - old_n_div))
250		goto error;
251	return map;
252error:
253	isl_map_free(map);
254	return NULL;
255}
256
257/*
258 * Returns a map of dimension "keep_dim" with "context" as domain and
259 * as range the first "isl_space_dim(keep_dim, isl_dim_out)" variables
260 * in the quast lists.
261 */
262static struct isl_map *isl_map_from_quast(struct isl_ctx *ctx, PipQuast *q,
263		isl_space *keep_dim,
264		struct isl_basic_set *context,
265		struct isl_set **rest)
266{
267	int		pip_param;
268	int		nexist;
269	int		max_depth;
270	int		n_sol, n_nosol;
271	struct scan_data	data;
272	struct isl_map		*map = NULL;
273	isl_space		*dims;
274	unsigned		nparam;
275	unsigned		dim;
276	unsigned		keep;
277
278	data.ctx = ctx;
279	data.rest = rest;
280	data.bmap = NULL;
281	data.pos = NULL;
282
283	if (!context || !keep_dim)
284		goto error;
285
286	dim = isl_basic_set_n_dim(context);
287	nparam = isl_basic_set_n_param(context);
288	keep = isl_space_dim(keep_dim, isl_dim_out);
289	pip_param = nparam + dim;
290
291	max_depth = 0;
292	n_sol = 0;
293	n_nosol = 0;
294	nexist = pip_param-1;
295	quast_count(q, &nexist, 0, &max_depth, &n_sol, &n_nosol);
296	nexist -= pip_param-1;
297
298	if (rest) {
299		*rest = isl_set_alloc_space(isl_space_copy(context->dim), n_nosol,
300					ISL_MAP_DISJOINT);
301		if (!*rest)
302			goto error;
303	}
304	map = isl_map_alloc_space(isl_space_copy(keep_dim), n_sol,
305				ISL_MAP_DISJOINT);
306	if (!map)
307		goto error;
308
309	dims = isl_space_reverse(isl_space_copy(context->dim));
310	data.bmap = isl_basic_map_from_basic_set(context, dims);
311	data.bmap = isl_basic_map_extend_space(data.bmap,
312		keep_dim, nexist, keep, max_depth+2*nexist);
313	if (!data.bmap)
314		goto error2;
315
316	if (data.bmap->extra) {
317		int i;
318		data.pos = isl_alloc_array(ctx, int, data.bmap->extra);
319		if (!data.pos)
320			goto error;
321		for (i = 0; i < data.bmap->n_div; ++i)
322			data.pos[i] = i;
323	}
324
325	map = scan_quast_r(&data, q, map);
326	map = isl_map_finalize(map);
327	if (!map)
328		goto error2;
329	if (rest) {
330		*rest = isl_set_finalize(*rest);
331		if (!*rest)
332			goto error2;
333	}
334	isl_basic_map_free(data.bmap);
335	if (data.pos)
336		free(data.pos);
337	return map;
338error:
339	isl_basic_set_free(context);
340	isl_space_free(keep_dim);
341error2:
342	if (data.pos)
343		free(data.pos);
344	isl_basic_map_free(data.bmap);
345	isl_map_free(map);
346	if (rest) {
347		isl_set_free(*rest);
348		*rest = NULL;
349	}
350	return NULL;
351}
352
353static void copy_values_to(Entier *dst, isl_int *src, unsigned n)
354{
355	int i;
356
357	for (i = 0; i < n; ++i)
358		entier_assign(dst[i], src[i]);
359}
360
361static void copy_constraint_to(Entier *dst, isl_int *src,
362		unsigned pip_param, unsigned pip_var,
363		unsigned extra_front, unsigned extra_back)
364{
365	copy_values_to(dst+1+extra_front+pip_var+pip_param+extra_back, src, 1);
366	copy_values_to(dst+1+extra_front+pip_var, src+1, pip_param);
367	copy_values_to(dst+1+extra_front, src+1+pip_param, pip_var);
368}
369
370PipMatrix *isl_basic_map_to_pip(struct isl_basic_map *bmap, unsigned pip_param,
371			 unsigned extra_front, unsigned extra_back)
372{
373	int i;
374	unsigned nrow;
375	unsigned ncol;
376	PipMatrix *M;
377	unsigned off;
378	unsigned pip_var = isl_basic_map_total_dim(bmap) - pip_param;
379
380	nrow = extra_front + bmap->n_eq + bmap->n_ineq;
381	ncol = 1 + extra_front + pip_var + pip_param + extra_back + 1;
382	M = pip_matrix_alloc(nrow, ncol);
383	if (!M)
384		return NULL;
385
386	off = extra_front;
387	for (i = 0; i < bmap->n_eq; ++i) {
388		entier_set_si(M->p[off+i][0], 0);
389		copy_constraint_to(M->p[off+i], bmap->eq[i],
390				   pip_param, pip_var, extra_front, extra_back);
391	}
392	off += bmap->n_eq;
393	for (i = 0; i < bmap->n_ineq; ++i) {
394		entier_set_si(M->p[off+i][0], 1);
395		copy_constraint_to(M->p[off+i], bmap->ineq[i],
396				   pip_param, pip_var, extra_front, extra_back);
397	}
398	return M;
399}
400
401PipMatrix *isl_basic_set_to_pip(struct isl_basic_set *bset, unsigned pip_param,
402			 unsigned extra_front, unsigned extra_back)
403{
404	return isl_basic_map_to_pip((struct isl_basic_map *)bset,
405					pip_param, extra_front, extra_back);
406}
407
408struct isl_map *isl_pip_basic_map_lexopt(
409		struct isl_basic_map *bmap, struct isl_basic_set *dom,
410		struct isl_set **empty, int max)
411{
412	PipOptions	*options;
413	PipQuast	*sol;
414	struct isl_map	*map;
415	struct isl_ctx  *ctx;
416	PipMatrix *domain = NULL, *context = NULL;
417	unsigned	 nparam, n_in, n_out;
418
419	bmap = isl_basic_map_detect_equalities(bmap);
420	if (!bmap || !dom)
421		goto error;
422
423	ctx = bmap->ctx;
424	isl_assert(ctx, isl_basic_map_compatible_domain(bmap, dom), goto error);
425	nparam = isl_basic_map_n_param(bmap);
426	n_in = isl_basic_map_n_in(bmap);
427	n_out = isl_basic_map_n_out(bmap);
428
429	domain = isl_basic_map_to_pip(bmap, nparam + n_in, 0, dom->n_div);
430	if (!domain)
431		goto error;
432	context = isl_basic_map_to_pip((struct isl_basic_map *)dom, 0, 0, 0);
433	if (!context)
434		goto error;
435
436	options = pip_options_init();
437	options->Simplify = 1;
438	options->Maximize = max;
439	options->Urs_unknowns = -1;
440	options->Urs_parms = -1;
441	sol = pip_solve(domain, context, -1, options);
442
443	if (sol) {
444		struct isl_basic_set *copy;
445		copy = isl_basic_set_copy(dom);
446		map = isl_map_from_quast(ctx, sol,
447				isl_space_copy(bmap->dim), copy, empty);
448	} else {
449		map = isl_map_empty_like_basic_map(bmap);
450		if (empty)
451			*empty = NULL;
452	}
453	if (!map)
454		goto error;
455	if (map->n == 0 && empty) {
456		isl_set_free(*empty);
457		*empty = isl_set_from_basic_set(dom);
458	} else
459		isl_basic_set_free(dom);
460	isl_basic_map_free(bmap);
461
462	pip_quast_free(sol);
463	pip_options_free(options);
464	pip_matrix_free(domain);
465	pip_matrix_free(context);
466
467	return map;
468error:
469	if (domain)
470		pip_matrix_free(domain);
471	if (context)
472		pip_matrix_free(context);
473	isl_basic_map_free(bmap);
474	isl_basic_set_free(dom);
475	return NULL;
476}
477