/*
 * Copyright 2016-2020 The OpenSSL Project Authors. All Rights Reserved.
 *
 * Licensed under the Apache License 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * https://www.openssl.org/source/license.html
 * or in the file LICENSE in the source distribution.
 */
10
#include <limits.h>
#include <string.h>
#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>

#include "internal/packet.h"

#include "helpers/ssltestlib.h"
#include "testutil.h"
20
21/* Should we fragment records or not? 0 = no, !0 = yes*/
22static int fragment = 0;
23
24static char *cert = NULL;
25static char *privkey = NULL;
26
27static int async_new(BIO *bi);
28static int async_free(BIO *a);
29static int async_read(BIO *b, char *out, int outl);
30static int async_write(BIO *b, const char *in, int inl);
31static long async_ctrl(BIO *b, int cmd, long num, void *ptr);
32static int async_gets(BIO *bp, char *buf, int size);
33static int async_puts(BIO *bp, const char *str);
34
35/* Choose a sufficiently large type likely to be unused for this custom BIO */
36# define BIO_TYPE_ASYNC_FILTER  (0x80 | BIO_TYPE_FILTER)
37
38static BIO_METHOD *methods_async = NULL;
39
/*
 * Per-BIO state for the async filter. rctr and wctr record whether the
 * previous read/write call was failed with a retry, so that successive
 * calls alternate between failing and passing data through.
 */
struct async_ctrs {
    unsigned int rctr, wctr;
};
44
45static const BIO_METHOD *bio_f_async_filter(void)
46{
47    if (methods_async == NULL) {
48        methods_async = BIO_meth_new(BIO_TYPE_ASYNC_FILTER, "Async filter");
49        if (   methods_async == NULL
50            || !BIO_meth_set_write(methods_async, async_write)
51            || !BIO_meth_set_read(methods_async, async_read)
52            || !BIO_meth_set_puts(methods_async, async_puts)
53            || !BIO_meth_set_gets(methods_async, async_gets)
54            || !BIO_meth_set_ctrl(methods_async, async_ctrl)
55            || !BIO_meth_set_create(methods_async, async_new)
56            || !BIO_meth_set_destroy(methods_async, async_free))
57            return NULL;
58    }
59    return methods_async;
60}
61
62static int async_new(BIO *bio)
63{
64    struct async_ctrs *ctrs;
65
66    ctrs = OPENSSL_zalloc(sizeof(struct async_ctrs));
67    if (ctrs == NULL)
68        return 0;
69
70    BIO_set_data(bio, ctrs);
71    BIO_set_init(bio, 1);
72    return 1;
73}
74
75static int async_free(BIO *bio)
76{
77    struct async_ctrs *ctrs;
78
79    if (bio == NULL)
80        return 0;
81    ctrs = BIO_get_data(bio);
82    OPENSSL_free(ctrs);
83    BIO_set_data(bio, NULL);
84    BIO_set_init(bio, 0);
85
86    return 1;
87}
88
89static int async_read(BIO *bio, char *out, int outl)
90{
91    struct async_ctrs *ctrs;
92    int ret = 0;
93    BIO *next = BIO_next(bio);
94
95    if (outl <= 0)
96        return 0;
97    if (next == NULL)
98        return 0;
99
100    ctrs = BIO_get_data(bio);
101
102    BIO_clear_retry_flags(bio);
103
104    if (ctrs->rctr > 0) {
105        ret = BIO_read(next, out, 1);
106        if (ret <= 0 && BIO_should_read(next))
107            BIO_set_retry_read(bio);
108        ctrs->rctr = 0;
109    } else {
110        ctrs->rctr++;
111        BIO_set_retry_read(bio);
112    }
113
114    return ret;
115}
116
/*
 * Layout of the one-byte records produced when fragmenting: content type,
 * version hi, version lo, length hi, length lo, then a single data byte.
 * MIN_RECORD_LEN is the total record size; the others are byte offsets.
 */
enum {
    MIN_RECORD_LEN = 6,

    CONTENTTYPEPOS = 0,
    VERSIONHIPOS   = 1,
    VERSIONLOPOS   = 2,
    DATAPOS        = 5
};
123
124static int async_write(BIO *bio, const char *in, int inl)
125{
126    struct async_ctrs *ctrs;
127    int ret = 0;
128    size_t written = 0;
129    BIO *next = BIO_next(bio);
130
131    if (inl <= 0)
132        return 0;
133    if (next == NULL)
134        return 0;
135
136    ctrs = BIO_get_data(bio);
137
138    BIO_clear_retry_flags(bio);
139
140    if (ctrs->wctr > 0) {
141        ctrs->wctr = 0;
142        if (fragment) {
143            PACKET pkt;
144
145            if (!PACKET_buf_init(&pkt, (const unsigned char *)in, inl))
146                return -1;
147
148            while (PACKET_remaining(&pkt) > 0) {
149                PACKET payload, wholebody, sessionid, extensions;
150                unsigned int contenttype, versionhi, versionlo, data;
151                unsigned int msgtype = 0, negversion = 0;
152
153                if (!PACKET_get_1(&pkt, &contenttype)
154                        || !PACKET_get_1(&pkt, &versionhi)
155                        || !PACKET_get_1(&pkt, &versionlo)
156                        || !PACKET_get_length_prefixed_2(&pkt, &payload))
157                    return -1;
158
159                /* Pretend we wrote out the record header */
160                written += SSL3_RT_HEADER_LENGTH;
161
162                wholebody = payload;
163                if (contenttype == SSL3_RT_HANDSHAKE
164                        && !PACKET_get_1(&wholebody, &msgtype))
165                    return -1;
166
167                if (msgtype == SSL3_MT_SERVER_HELLO) {
168                    if (!PACKET_forward(&wholebody,
169                                            SSL3_HM_HEADER_LENGTH - 1)
170                            || !PACKET_get_net_2(&wholebody, &negversion)
171                               /* Skip random (32 bytes) */
172                            || !PACKET_forward(&wholebody, 32)
173                               /* Skip session id */
174                            || !PACKET_get_length_prefixed_1(&wholebody,
175                                                             &sessionid)
176                               /*
177                                * Skip ciphersuite (2 bytes) and compression
178                                * method (1 byte)
179                                */
180                            || !PACKET_forward(&wholebody, 2 + 1)
181                            || !PACKET_get_length_prefixed_2(&wholebody,
182                                                             &extensions))
183                        return -1;
184
185                    /*
186                     * Find the negotiated version in supported_versions
187                     * extension, if present.
188                     */
189                    while (PACKET_remaining(&extensions)) {
190                        unsigned int type;
191                        PACKET extbody;
192
193                        if (!PACKET_get_net_2(&extensions, &type)
194                                || !PACKET_get_length_prefixed_2(&extensions,
195                                &extbody))
196                            return -1;
197
198                        if (type == TLSEXT_TYPE_supported_versions
199                                && (!PACKET_get_net_2(&extbody, &negversion)
200                                    || PACKET_remaining(&extbody) != 0))
201                            return -1;
202                    }
203                }
204
205                while (PACKET_get_1(&payload, &data)) {
206                    /* Create a new one byte long record for each byte in the
207                     * record in the input buffer
208                     */
209                    char smallrec[MIN_RECORD_LEN] = {
210                        0, /* Content type */
211                        0, /* Version hi */
212                        0, /* Version lo */
213                        0, /* Length hi */
214                        1, /* Length lo */
215                        0  /* Data */
216                    };
217
218                    smallrec[CONTENTTYPEPOS] = contenttype;
219                    smallrec[VERSIONHIPOS] = versionhi;
220                    smallrec[VERSIONLOPOS] = versionlo;
221                    smallrec[DATAPOS] = data;
222                    ret = BIO_write(next, smallrec, MIN_RECORD_LEN);
223                    if (ret <= 0)
224                        return -1;
225                    written++;
226                }
227                /*
228                 * We can't fragment anything after the ServerHello (or CCS <=
229                 * TLS1.2), otherwise we get a bad record MAC
230                 */
231                if (contenttype == SSL3_RT_CHANGE_CIPHER_SPEC
232                        || (negversion == TLS1_3_VERSION
233                            && msgtype == SSL3_MT_SERVER_HELLO)) {
234                    fragment = 0;
235                    break;
236                }
237            }
238        }
239        /* Write any data we have left after fragmenting */
240        ret = 0;
241        if ((int)written < inl) {
242            ret = BIO_write(next, in + written, inl - written);
243        }
244
245        if (ret <= 0 && BIO_should_write(next))
246            BIO_set_retry_write(bio);
247        else
248            ret += written;
249    } else {
250        ctrs->wctr++;
251        BIO_set_retry_write(bio);
252    }
253
254    return ret;
255}
256
257static long async_ctrl(BIO *bio, int cmd, long num, void *ptr)
258{
259    long ret;
260    BIO *next = BIO_next(bio);
261
262    if (next == NULL)
263        return 0;
264
265    switch (cmd) {
266    case BIO_CTRL_DUP:
267        ret = 0L;
268        break;
269    default:
270        ret = BIO_ctrl(next, cmd, num, ptr);
271        break;
272    }
273    return ret;
274}
275
276static int async_gets(BIO *bio, char *buf, int size)
277{
278    /* We don't support this - not needed anyway */
279    return -1;
280}
281
282static int async_puts(BIO *bio, const char *str)
283{
284    return async_write(bio, str, strlen(str));
285}
286
287#define MAX_ATTEMPTS    100
288
289static int test_asyncio(int test)
290{
291    SSL_CTX *serverctx = NULL, *clientctx = NULL;
292    SSL *serverssl = NULL, *clientssl = NULL;
293    BIO *s_to_c_fbio = NULL, *c_to_s_fbio = NULL;
294    int testresult = 0, ret;
295    size_t i, j;
296    const char testdata[] = "Test data";
297    char buf[sizeof(testdata)];
298
299    if (!TEST_true(create_ssl_ctx_pair(NULL, TLS_server_method(),
300                                       TLS_client_method(),
301                                       TLS1_VERSION, 0,
302                                       &serverctx, &clientctx, cert, privkey)))
303        goto end;
304
305    /*
306     * We do 2 test runs. The first time around we just do a normal handshake
307     * with lots of async io going on. The second time around we also break up
308     * all records so that the content is only one byte length (up until the
309     * CCS)
310     */
311    if (test == 1)
312        fragment = 1;
313
314
315    s_to_c_fbio = BIO_new(bio_f_async_filter());
316    c_to_s_fbio = BIO_new(bio_f_async_filter());
317    if (!TEST_ptr(s_to_c_fbio)
318            || !TEST_ptr(c_to_s_fbio)) {
319        BIO_free(s_to_c_fbio);
320        BIO_free(c_to_s_fbio);
321        goto end;
322    }
323
324    /* BIOs get freed on error */
325    if (!TEST_true(create_ssl_objects(serverctx, clientctx, &serverssl,
326                                      &clientssl, s_to_c_fbio, c_to_s_fbio))
327            || !TEST_true(create_ssl_connection(serverssl, clientssl,
328                          SSL_ERROR_NONE)))
329        goto end;
330
331    /*
332     * Send and receive some test data. Do the whole thing twice to ensure
333     * we hit at least one async event in both reading and writing
334     */
335    for (j = 0; j < 2; j++) {
336        int len;
337
338        /*
339         * Write some test data. It should never take more than 2 attempts
340         * (the first one might be a retryable fail).
341         */
342        for (ret = -1, i = 0, len = 0; len != sizeof(testdata) && i < 2;
343            i++) {
344            ret = SSL_write(clientssl, testdata + len,
345                sizeof(testdata) - len);
346            if (ret > 0) {
347                len += ret;
348            } else {
349                int ssl_error = SSL_get_error(clientssl, ret);
350
351                if (!TEST_false(ssl_error == SSL_ERROR_SYSCALL ||
352                                ssl_error == SSL_ERROR_SSL))
353                    goto end;
354            }
355        }
356        if (!TEST_size_t_eq(len, sizeof(testdata)))
357            goto end;
358
359        /*
360         * Now read the test data. It may take more attempts here because
361         * it could fail once for each byte read, including all overhead
362         * bytes from the record header/padding etc.
363         */
364        for (ret = -1, i = 0, len = 0; len != sizeof(testdata) &&
365                i < MAX_ATTEMPTS; i++) {
366            ret = SSL_read(serverssl, buf + len, sizeof(buf) - len);
367            if (ret > 0) {
368                len += ret;
369            } else {
370                int ssl_error = SSL_get_error(serverssl, ret);
371
372                if (!TEST_false(ssl_error == SSL_ERROR_SYSCALL ||
373                                ssl_error == SSL_ERROR_SSL))
374                    goto end;
375            }
376        }
377        if (!TEST_mem_eq(testdata, sizeof(testdata), buf, len))
378            goto end;
379    }
380
381    /* Also frees the BIOs */
382    SSL_free(clientssl);
383    SSL_free(serverssl);
384    clientssl = serverssl = NULL;
385
386    testresult = 1;
387
388 end:
389    SSL_free(clientssl);
390    SSL_free(serverssl);
391    SSL_CTX_free(clientctx);
392    SSL_CTX_free(serverctx);
393
394    return testresult;
395}
396
397OPT_TEST_DECLARE_USAGE("certname privkey\n")
398
399int setup_tests(void)
400{
401    if (!test_skip_common_options()) {
402        TEST_error("Error parsing test options\n");
403        return 0;
404    }
405
406    if (!TEST_ptr(cert = test_get_argument(0))
407            || !TEST_ptr(privkey = test_get_argument(1)))
408        return 0;
409
410    ADD_ALL_TESTS(test_asyncio, 2);
411    return 1;
412}
413
414void cleanup_tests(void)
415{
416    BIO_meth_free(methods_async);
417}
418