1/*
2 * Copyright 2019-2021 The OpenSSL Project Authors. All Rights Reserved.
3 * Copyright (c) 2019, Oracle and/or its affiliates.  All rights reserved.
4 *
5 * Licensed under the Apache License 2.0 (the "License").  You may not use
6 * this file except in compliance with the License.  You can obtain a copy
7 * in the file LICENSE in the source distribution or at
8 * https://www.openssl.org/source/license.html
9 */
10
11#include <stdio.h>
12#include <string.h>
13#include <limits.h>
14
15#include <openssl/crypto.h>
16#include "internal/nelem.h"
17#include "crypto/sparse_array.h"
18#include "testutil.h"
19
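/*
 * The sparse array macros below generate helper functions, not all of which
 * are used by this test.  One of the clang builds turns the resulting
 * -Wunused-function warning into an error, so that warning is disabled here.
 */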
23#ifdef __clang__
24#pragma clang diagnostic ignored "-Wunused-function"
25#endif
26
/*
 * Instantiate the sparse array of |char *| values: this provides the
 * SPARSE_ARRAY_OF(char) type and the ossl_sa_char_*() functions
 * (new, set, get, num, doall_arg, free) exercised by the tests below.
 */
DEFINE_SPARSE_ARRAY_OF(char);
28
29static int test_sparse_array(void)
30{
31    static const struct {
32        ossl_uintmax_t n;
33        char *v;
34    } cases[] = {
35        { 22, "a" }, { 0, "z" }, { 1, "b" }, { 290, "c" },
36        { INT_MAX, "m" }, { 6666666, "d" }, { (ossl_uintmax_t)-1, "H" },
37        { 99, "e" }
38    };
39    SPARSE_ARRAY_OF(char) *sa;
40    size_t i, j;
41    int res = 0;
42
43    if (!TEST_ptr(sa = ossl_sa_char_new())
44            || !TEST_ptr_null(ossl_sa_char_get(sa, 3))
45            || !TEST_ptr_null(ossl_sa_char_get(sa, 0))
46            || !TEST_ptr_null(ossl_sa_char_get(sa, UINT_MAX)))
47        goto err;
48
49    for (i = 0; i < OSSL_NELEM(cases); i++) {
50        if (!TEST_true(ossl_sa_char_set(sa, cases[i].n, cases[i].v))) {
51            TEST_note("iteration %zu", i + 1);
52            goto err;
53        }
54        for (j = 0; j <= i; j++)
55            if (!TEST_str_eq(ossl_sa_char_get(sa, cases[j].n), cases[j].v)) {
56                TEST_note("iteration %zu / %zu", i + 1, j + 1);
57                goto err;
58            }
59    }
60
61    res = 1;
62err:
63    ossl_sa_char_free(sa);
64    return res;
65}
66
67static int test_sparse_array_num(void)
68{
69    static const struct {
70        size_t num;
71        ossl_uintmax_t n;
72        char *v;
73    } cases[] = {
74        { 1, 22, "a" }, { 2, 1021, "b" }, { 3, 3, "c" }, { 2, 22, NULL },
75        { 2, 3, "d" }, { 3, 22, "e" }, { 3, 666, NULL }, { 4, 666, "f" },
76        { 3, 3, NULL }, { 2, 22, NULL }, { 1, 666, NULL }, { 2, 64000, "g" },
77        { 1, 1021, NULL }, { 0, 64000, NULL }, { 1, 23, "h" }, { 0, 23, NULL }
78    };
79    SPARSE_ARRAY_OF(char) *sa = NULL;
80    size_t i;
81    int res = 0;
82
83    if (!TEST_size_t_eq(ossl_sa_char_num(NULL), 0)
84            || !TEST_ptr(sa = ossl_sa_char_new())
85            || !TEST_size_t_eq(ossl_sa_char_num(sa), 0))
86        goto err;
87    for (i = 0; i < OSSL_NELEM(cases); i++)
88        if (!TEST_true(ossl_sa_char_set(sa, cases[i].n, cases[i].v))
89                || !TEST_size_t_eq(ossl_sa_char_num(sa), cases[i].num))
90            goto err;
91    res = 1;
92err:
93    ossl_sa_char_free(sa);
94    return res;
95}
96
/*
 * One expected entry for the doall test.  Field order matters: the case
 * table is written with positional initializers { n, v, del }.
 */
struct index_cases_st {
    ossl_uintmax_t n;   /* index in the sparse array */
    char *v;            /* value stored at that index */
    /*
     * Non-zero marks the entry as selected for deletion.  When
     * doall_st.all is zero, leaf_check_all() no longer accepts such entries.
     */
    int del;
};
102
/* State passed through ossl_sa_char_doall_arg() to the leaf callbacks. */
struct doall_st {
    SPARSE_ARRAY_OF(char) *sa;          /* array leaf_delete() modifies; NULL while only checking */
    size_t num_cases;                   /* number of entries in |cases| */
    const struct index_cases_st *cases; /* expected contents of the array */
    /*
     * Result flag written by the callbacks and checked by the caller after
     * each walk; 0 signals a failure.
     */
    int res;
    /* Non-zero: entries marked |del| are still accepted by leaf_check_all(). */
    int all;
};
110
111static void leaf_check_all(ossl_uintmax_t n, char *value, void *arg)
112{
113    struct doall_st *doall_data = (struct doall_st *)arg;
114    const struct index_cases_st *cases = doall_data->cases;
115    size_t i;
116
117    doall_data->res = 0;
118    for (i = 0; i < doall_data->num_cases; i++)
119        if ((doall_data->all || !cases[i].del)
120            && n == cases[i].n && strcmp(value, cases[i].v) == 0) {
121            doall_data->res = 1;
122            return;
123        }
124    TEST_error("Index %ju with value %s not found", n, value);
125}
126
127static void leaf_delete(ossl_uintmax_t n, char *value, void *arg)
128{
129    struct doall_st *doall_data = (struct doall_st *)arg;
130    const struct index_cases_st *cases = doall_data->cases;
131    size_t i;
132
133    doall_data->res = 0;
134    for (i = 0; i < doall_data->num_cases; i++)
135        if (n == cases[i].n && strcmp(value, cases[i].v) == 0) {
136            doall_data->res = 1;
137            ossl_sa_char_set(doall_data->sa, n, NULL);
138            return;
139        }
140    TEST_error("Index %ju with value %s not found", n, value);
141}
142
143static int test_sparse_array_doall(void)
144{
145    static const struct index_cases_st cases[] = {
146        { 22, "A", 1 }, { 1021, "b", 0 }, { 3, "c", 0 }, { INT_MAX, "d", 1 },
147        { (ossl_uintmax_t)-1, "H", 0 }, { (ossl_uintmax_t)-2, "i", 1 },
148        { 666666666, "s", 1 }, { 1234567890, "t", 0 },
149    };
150    struct doall_st doall_data;
151    size_t i;
152    SPARSE_ARRAY_OF(char) *sa = NULL;
153    int res = 0;
154
155    if (!TEST_ptr(sa = ossl_sa_char_new()))
156        goto err;
157    doall_data.num_cases = OSSL_NELEM(cases);
158    doall_data.cases = cases;
159    doall_data.all = 1;
160    doall_data.sa = NULL;
161    for (i = 0; i <  OSSL_NELEM(cases); i++)
162        if (!TEST_true(ossl_sa_char_set(sa, cases[i].n, cases[i].v))) {
163            TEST_note("failed at iteration %zu", i + 1);
164            goto err;
165    }
166
167    ossl_sa_char_doall_arg(sa, &leaf_check_all, &doall_data);
168    if (doall_data.res == 0) {
169        TEST_info("while checking all elements");
170        goto err;
171    }
172    doall_data.all = 0;
173    doall_data.sa = sa;
174    ossl_sa_char_doall_arg(sa, &leaf_delete, &doall_data);
175    if (doall_data.res == 0) {
176        TEST_info("while deleting selected elements");
177        goto err;
178    }
179    ossl_sa_char_doall_arg(sa, &leaf_check_all, &doall_data);
180    if (doall_data.res == 0) {
181        TEST_info("while checking for deleted elements");
182        goto err;
183    }
184    res = 1;
185
186err:
187    ossl_sa_char_free(sa);
188    return res;
189}
190
191int setup_tests(void)
192{
193    ADD_TEST(test_sparse_array);
194    ADD_TEST(test_sparse_array_num);
195    ADD_TEST(test_sparse_array_doall);
196    return 1;
197}
198