1/*
2 * Copyright 2017-2021 The OpenSSL Project Authors. All Rights Reserved.
3 * Copyright (c) 2017, Oracle and/or its affiliates.  All rights reserved.
4 *
5 * Licensed under the Apache License 2.0 (the "License").  You may not use
6 * this file except in compliance with the License.  You can obtain a copy
7 * in the file LICENSE in the source distribution or at
8 * https://www.openssl.org/source/license.html
9 */
10
11#include <stdio.h>
12#include <string.h>
13
14#include <openssl/opensslconf.h>
15#include <openssl/safestack.h>
16#include <openssl/err.h>
17#include <openssl/crypto.h>
18
19#include "internal/nelem.h"
20#include "testutil.h"
21
/*
 * The stack macros below generate functions that this file never calls,
 * and the resulting unused-function warnings make one of the clang builds
 * fail.  Suppress that warning here.
 */
25#ifdef __clang__
26#pragma clang diagnostic ignored "-Wunused-function"
27#endif
28
/* Plain struct element type used by test_SS_stack (heap-allocated elements). */
typedef struct {
    int n;
    char c;
} SS;

/* Union element type used by test_SU_stack (elements live in a local array). */
typedef union {
    int n;
    char c;
} SU;

/*
 * Instantiate the typed stack wrappers exercised below: sk_sint_* (elements
 * of type int), sk_uchar_* (elements of type unsigned char), sk_SS_* and
 * sk_SU_*.  The SPECIAL variants give the stack a name that differs from the
 * element type; the _CONST variants are used with const element pointers
 * (e.g. test_uchar_stack pushes pointers into a static const array).
 */
DEFINE_SPECIAL_STACK_OF(sint, int)
DEFINE_SPECIAL_STACK_OF_CONST(uchar, unsigned char)
DEFINE_STACK_OF(SS)
DEFINE_STACK_OF_CONST(SU)
43
/*
 * Comparison callback for the sint stack: orders two element pointers by
 * the int values they refer to.  Returns -1, 0 or 1 and cannot overflow,
 * since it only compares and never subtracts.
 */
static int int_compare(const int *const *a, const int *const *b)
{
    const int lhs = **a;
    const int rhs = **b;

    return (lhs > rhs) - (lhs < rhs);
}
52
53static int test_int_stack(int reserve)
54{
55    static int v[] = { 1, 2, -4, 16, 999, 1, -173, 1, 9 };
56    static int notpresent = -1;
57    const int n = OSSL_NELEM(v);
58    static struct {
59        int value;
60        int unsorted;
61        int sorted;
62        int ex;
63    } finds[] = {
64        { 2,    1,  5,  5   },
65        { 9,    7,  6,  6   },
66        { -173, 5,  0,  0   },
67        { 999,  3,  8,  8   },
68        { 0,   -1, -1,  1   }
69    };
70    const int n_finds = OSSL_NELEM(finds);
71    static struct {
72        int value;
73        int ex;
74    } exfinds[] = {
75        { 3,    5   },
76        { 1000, 8   },
77        { 20,   8   },
78        { -999, 0   },
79        { -5,   0   },
80        { 8,    5   }
81    };
82    const int n_exfinds = OSSL_NELEM(exfinds);
83    STACK_OF(sint) *s = sk_sint_new_null();
84    int i;
85    int testresult = 0;
86
87    if (!TEST_ptr(s)
88        || (reserve > 0 && !TEST_true(sk_sint_reserve(s, 5 * reserve))))
89        goto end;
90
91    /* Check push and num */
92    for (i = 0; i < n; i++) {
93        if (!TEST_int_eq(sk_sint_num(s), i)) {
94            TEST_info("int stack size %d", i);
95            goto end;
96        }
97        sk_sint_push(s, v + i);
98    }
99    if (!TEST_int_eq(sk_sint_num(s), n))
100        goto end;
101
102    /* check the values */
103    for (i = 0; i < n; i++)
104        if (!TEST_ptr_eq(sk_sint_value(s, i), v + i)) {
105            TEST_info("int value %d", i);
106            goto end;
107        }
108
109    /* find unsorted -- the pointers are compared */
110    for (i = 0; i < n_finds; i++) {
111        int *val = (finds[i].unsorted == -1) ? &notpresent
112                                             : v + finds[i].unsorted;
113
114        if (!TEST_int_eq(sk_sint_find(s, val), finds[i].unsorted)) {
115            TEST_info("int unsorted find %d", i);
116            goto end;
117        }
118    }
119
120    /* find_ex unsorted */
121    for (i = 0; i < n_finds; i++) {
122        int *val = (finds[i].unsorted == -1) ? &notpresent
123                                             : v + finds[i].unsorted;
124
125        if (!TEST_int_eq(sk_sint_find_ex(s, val), finds[i].unsorted)) {
126            TEST_info("int unsorted find_ex %d", i);
127            goto end;
128        }
129    }
130
131    /* sorting */
132    if (!TEST_false(sk_sint_is_sorted(s)))
133        goto end;
134    (void)sk_sint_set_cmp_func(s, &int_compare);
135    sk_sint_sort(s);
136    if (!TEST_true(sk_sint_is_sorted(s)))
137        goto end;
138
139    /* find sorted -- the value is matched so we don't need to locate it */
140    for (i = 0; i < n_finds; i++)
141        if (!TEST_int_eq(sk_sint_find(s, &finds[i].value), finds[i].sorted)) {
142            TEST_info("int sorted find %d", i);
143            goto end;
144        }
145
146    /* find_ex sorted */
147    for (i = 0; i < n_finds; i++)
148        if (!TEST_int_eq(sk_sint_find_ex(s, &finds[i].value), finds[i].ex)) {
149            TEST_info("int sorted find_ex present %d", i);
150            goto end;
151        }
152    for (i = 0; i < n_exfinds; i++)
153        if (!TEST_int_eq(sk_sint_find_ex(s, &exfinds[i].value), exfinds[i].ex)){
154            TEST_info("int sorted find_ex absent %d", i);
155            goto end;
156        }
157
158    /* shift */
159    if (!TEST_ptr_eq(sk_sint_shift(s), v + 6))
160        goto end;
161
162    testresult = 1;
163end:
164    sk_sint_free(s);
165    return testresult;
166}
167
/*
 * Comparison callback for the uchar stack: returns the signed difference of
 * the two bytes referred to, i.e. negative, zero or positive.  Both operands
 * are widened to int first, so the result lies in [-255, 255].
 */
static int uchar_compare(const unsigned char *const *a,
                         const unsigned char *const *b)
{
    int lhs = **a;
    int rhs = **b;

    return lhs - rhs;
}
173
174static int test_uchar_stack(int reserve)
175{
176    static const unsigned char v[] = { 1, 3, 7, 5, 255, 0 };
177    const int n = OSSL_NELEM(v);
178    STACK_OF(uchar) *s = sk_uchar_new(&uchar_compare), *r = NULL;
179    int i;
180    int testresult = 0;
181
182    if (!TEST_ptr(s)
183        || (reserve > 0 && !TEST_true(sk_uchar_reserve(s, 5 * reserve))))
184        goto end;
185
186    /* unshift and num */
187    for (i = 0; i < n; i++) {
188        if (!TEST_int_eq(sk_uchar_num(s), i)) {
189            TEST_info("uchar stack size %d", i);
190            goto end;
191        }
192        sk_uchar_unshift(s, v + i);
193    }
194    if (!TEST_int_eq(sk_uchar_num(s), n))
195        goto end;
196
197    /* dup */
198    r = sk_uchar_dup(NULL);
199    if (sk_uchar_num(r) != 0)
200        goto end;
201    sk_uchar_free(r);
202    r = sk_uchar_dup(s);
203    if (!TEST_int_eq(sk_uchar_num(r), n))
204        goto end;
205    sk_uchar_sort(r);
206
207    /* pop */
208    for (i = 0; i < n; i++)
209        if (!TEST_ptr_eq(sk_uchar_pop(s), v + i)) {
210            TEST_info("uchar pop %d", i);
211            goto end;
212        }
213
214    /* free -- we rely on the debug malloc to detect leakage here */
215    sk_uchar_free(s);
216    s = NULL;
217
218    /* dup again */
219    if (!TEST_int_eq(sk_uchar_num(r), n))
220        goto end;
221
222    /* zero */
223    sk_uchar_zero(r);
224    if (!TEST_int_eq(sk_uchar_num(r), 0))
225        goto end;
226
227    /* insert */
228    sk_uchar_insert(r, v, 0);
229    sk_uchar_insert(r, v + 2, -1);
230    sk_uchar_insert(r, v + 1, 1);
231    for (i = 0; i < 3; i++)
232        if (!TEST_ptr_eq(sk_uchar_value(r, i), v + i)) {
233            TEST_info("uchar insert %d", i);
234            goto end;
235        }
236
237    /* delete */
238    if (!TEST_ptr_null(sk_uchar_delete(r, 12)))
239        goto end;
240    if (!TEST_ptr_eq(sk_uchar_delete(r, 1), v + 1))
241        goto end;
242
243    /* set */
244    (void)sk_uchar_set(r, 1, v + 1);
245    for (i = 0; i < 2; i++)
246        if (!TEST_ptr_eq(sk_uchar_value(r, i), v + i)) {
247            TEST_info("uchar set %d", i);
248            goto end;
249        }
250
251    testresult = 1;
252end:
253    sk_uchar_free(r);
254    sk_uchar_free(s);
255    return testresult;
256}
257
258static SS *SS_copy(const SS *p)
259{
260    SS *q = OPENSSL_malloc(sizeof(*q));
261
262    if (q != NULL)
263        memcpy(q, p, sizeof(*q));
264    return q;
265}
266
267static void SS_free(SS *p) {
268    OPENSSL_free(p);
269}
270
271static int test_SS_stack(void)
272{
273    STACK_OF(SS) *s = sk_SS_new_null();
274    STACK_OF(SS) *r = NULL;
275    SS *v[10], *p;
276    const int n = OSSL_NELEM(v);
277    int i;
278    int testresult = 0;
279
280    /* allocate and push */
281    for (i = 0; i < n; i++) {
282        v[i] = OPENSSL_malloc(sizeof(*v[i]));
283
284        if (!TEST_ptr(v[i]))
285            goto end;
286        v[i]->n = i;
287        v[i]->c = 'A' + i;
288        if (!TEST_int_eq(sk_SS_num(s), i)) {
289            TEST_info("SS stack size %d", i);
290            goto end;
291        }
292        sk_SS_push(s, v[i]);
293    }
294    if (!TEST_int_eq(sk_SS_num(s), n))
295        goto end;
296
297    /* deepcopy */
298    r = sk_SS_deep_copy(NULL, &SS_copy, &SS_free);
299    if (sk_SS_num(r) != 0)
300        goto end;
301    sk_SS_free(r);
302    r = sk_SS_deep_copy(s, &SS_copy, &SS_free);
303    if (!TEST_ptr(r))
304        goto end;
305    for (i = 0; i < n; i++) {
306        p = sk_SS_value(r, i);
307        if (!TEST_ptr_ne(p, v[i])) {
308            TEST_info("SS deepcopy non-copy %d", i);
309            goto end;
310        }
311        if (!TEST_int_eq(p->n, v[i]->n)) {
312            TEST_info("test SS deepcopy int %d", i);
313            goto end;
314        }
315        if (!TEST_char_eq(p->c, v[i]->c)) {
316            TEST_info("SS deepcopy char %d", i);
317            goto end;
318        }
319    }
320
321    /* pop_free - we rely on the malloc debug to catch the leak */
322    sk_SS_pop_free(r, &SS_free);
323    r = NULL;
324
325    /* delete_ptr */
326    p = sk_SS_delete_ptr(s, v[3]);
327    if (!TEST_ptr(p))
328        goto end;
329    SS_free(p);
330    if (!TEST_int_eq(sk_SS_num(s), n - 1))
331        goto end;
332    for (i = 0; i < n-1; i++)
333        if (!TEST_ptr_eq(sk_SS_value(s, i), v[i<3 ? i : 1+i])) {
334            TEST_info("SS delete ptr item %d", i);
335            goto end;
336        }
337
338    testresult = 1;
339end:
340    sk_SS_pop_free(r, &SS_free);
341    sk_SS_pop_free(s, &SS_free);
342    return testresult;
343}
344
345static int test_SU_stack(void)
346{
347    STACK_OF(SU) *s = sk_SU_new_null();
348    SU v[10];
349    const int n = OSSL_NELEM(v);
350    int i;
351    int testresult = 0;
352
353    /* allocate and push */
354    for (i = 0; i < n; i++) {
355        if ((i & 1) == 0)
356            v[i].n = i;
357        else
358            v[i].c = 'A' + i;
359        if (!TEST_int_eq(sk_SU_num(s), i)) {
360            TEST_info("SU stack size %d", i);
361            goto end;
362        }
363        sk_SU_push(s, v + i);
364    }
365    if (!TEST_int_eq(sk_SU_num(s), n))
366        goto end;
367
368    /* check the pointers are correct */
369    for (i = 0; i < n; i++)
370        if (!TEST_ptr_eq(sk_SU_value(s, i),  v + i)) {
371            TEST_info("SU pointer check %d", i);
372            goto end;
373        }
374
375    testresult = 1;
376end:
377    sk_SU_free(s);
378    return testresult;
379}
380
/*
 * Register the stack tests with the test framework and return 1.
 * test_int_stack and test_uchar_stack are registered through ADD_ALL_TESTS
 * with a count of 4; each takes a |reserve| argument, presumably the
 * iteration index 0..3 supplied by the framework (see testutil.h), so they
 * run both without and with a pre-reserved stack.
 */
int setup_tests(void)
{
    ADD_ALL_TESTS(test_int_stack, 4);
    ADD_ALL_TESTS(test_uchar_stack, 4);
    ADD_TEST(test_SS_stack);
    ADD_TEST(test_SU_stack);
    return 1;
}
389