1/*
2 * Copyright (c) Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under both the BSD-style license (found in the
6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7 * in the COPYING file in the root directory of this source tree).
8 * You may select, at your option, one of the above-listed licenses.
9 */
10
11/// Zstandard educational decoder implementation
12/// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
13
14#include <stdint.h>   // uint8_t, etc.
15#include <stdlib.h>   // malloc, free, exit
16#include <stdio.h>    // fprintf
17#include <string.h>   // memset, memcpy
18#include "zstd_decompress.h"
19
20
21/******* IMPORTANT CONSTANTS *********************************************/
22
23// Zstandard frame
24// "Magic_Number
25// 4 Bytes, little-endian format. Value : 0xFD2FB528"
26#define ZSTD_MAGIC_NUMBER 0xFD2FB528U
27
// The size of `Block_Content` is limited by `Block_Maximum_Size` (128 KB here)
29#define ZSTD_BLOCK_SIZE_MAX ((size_t)128 * 1024)
30
31// literal blocks can't be larger than their block
32#define MAX_LITERALS_SIZE ZSTD_BLOCK_SIZE_MAX
33
34
35/******* UTILITY MACROS AND TYPES *********************************************/
36#define MAX(a, b) ((a) > (b) ? (a) : (b))
37#define MIN(a, b) ((a) < (b) ? (a) : (b))
38
39#if defined(ZDEC_NO_MESSAGE)
40#define MESSAGE(...)
41#else
42#define MESSAGE(...)  fprintf(stderr, "" __VA_ARGS__)
43#endif
44
45/// This decoder calls exit(1) when it encounters an error, however a production
46/// library should propagate error codes
47#define ERROR(s)                                                               \
48    do {                                                                       \
49        MESSAGE("Error: %s\n", s);                                     \
50        exit(1);                                                               \
51    } while (0)
52#define INP_SIZE()                                                             \
53    ERROR("Input buffer smaller than it should be or input is "                \
54          "corrupted")
55#define OUT_SIZE() ERROR("Output buffer too small for output")
56#define CORRUPTION() ERROR("Corruption detected while decompressing")
57#define BAD_ALLOC() ERROR("Memory allocation error")
58#define IMPOSSIBLE() ERROR("An impossibility has occurred")
59
60typedef uint8_t  u8;
61typedef uint16_t u16;
62typedef uint32_t u32;
63typedef uint64_t u64;
64
65typedef int8_t  i8;
66typedef int16_t i16;
67typedef int32_t i32;
68typedef int64_t i64;
69/******* END UTILITY MACROS AND TYPES *****************************************/
70
71/******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
72/// The implementations for these functions can be found at the bottom of this
73/// file.  They implement low-level functionality needed for the higher level
74/// decompression functions.
75
76/*** IO STREAM OPERATIONS *************/
77
78/// ostream_t/istream_t are used to wrap the pointers/length data passed into
79/// ZSTD_decompress, so that all IO operations are safely bounds checked
80/// They are written/read forward, and reads are treated as little-endian
81/// They should be used opaquely to ensure safety
82typedef struct {
83    u8 *ptr;
84    size_t len;
85} ostream_t;
86
87typedef struct {
88    const u8 *ptr;
89    size_t len;
90
91    // Input often reads a few bits at a time, so maintain an internal offset
92    int bit_offset;
93} istream_t;
94
/// The following three functions are the only ones that may be used while the
/// istream is not byte aligned
97
/// Reads `num_bits` bits from a bitstream, and updates the internal offset
99static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
/// Backs up the stream by `num_bits` bits so they can be read again
101static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
102/// If the remaining bits in a byte will be unused, advance to the end of the
103/// byte
104static inline void IO_align_stream(istream_t *const in);
105
106/// Write the given byte into the output stream
107static inline void IO_write_byte(ostream_t *const out, u8 symb);
108
109/// Returns the number of bytes left to be read in this stream.  The stream must
110/// be byte aligned.
111static inline size_t IO_istream_len(const istream_t *const in);
112
113/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
114/// was skipped.  The stream must be byte aligned.
115static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
116/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
117/// was skipped so it can be written to.
118static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
119
120/// Advance the inner state by `len` bytes.  The stream must be byte aligned.
121static inline void IO_advance_input(istream_t *const in, size_t len);
122
123/// Returns an `ostream_t` constructed from the given pointer and length.
124static inline ostream_t IO_make_ostream(u8 *out, size_t len);
125/// Returns an `istream_t` constructed from the given pointer and length.
126static inline istream_t IO_make_istream(const u8 *in, size_t len);
127
128/// Returns an `istream_t` with the same base as `in`, and length `len`.
129/// Then, advance `in` to account for the consumed bytes.
130/// `in` must be byte aligned.
131static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
132/*** END IO STREAM OPERATIONS *********/
133
134/*** BITSTREAM OPERATIONS *************/
/// Read `num_bits` bits (up to 64) from `src + offset`, where `offset` is in
/// bits, and return them interpreted as a little-endian unsigned integer.
137static inline u64 read_bits_LE(const u8 *src, const int num_bits,
138                               const size_t offset);
139
140/// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
141/// it updates `offset` to `offset - bits`, and then reads `bits` bits from
142/// `src + offset`.  If the offset becomes negative, the extra bits at the
143/// bottom are filled in with `0` bits instead of reading from before `src`.
144static inline u64 STREAM_read_bits(const u8 *src, const int bits,
145                                   i64 *const offset);
146/*** END BITSTREAM OPERATIONS *********/
147
148/*** BIT COUNTING OPERATIONS **********/
149/// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
150static inline int highest_set_bit(const u64 num);
151/*** END BIT COUNTING OPERATIONS ******/
152
153/*** HUFFMAN PRIMITIVES ***************/
154// Table decode method uses exponential memory, so we need to limit depth
155#define HUF_MAX_BITS (16)
156
157// Limit the maximum number of symbols to 256 so we can store a symbol in a byte
158#define HUF_MAX_SYMBS (256)
159
160/// Structure containing all tables necessary for efficient Huffman decoding
161typedef struct {
162    u8 *symbols;
163    u8 *num_bits;
164    int max_bits;
165} HUF_dtable;
166
167/// Decode a single symbol and read in enough bits to refresh the state
168static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
169                                   u16 *const state, const u8 *const src,
170                                   i64 *const offset);
171/// Read in a full state's worth of bits to initialize it
172static inline void HUF_init_state(const HUF_dtable *const dtable,
173                                  u16 *const state, const u8 *const src,
174                                  i64 *const offset);
175
/// Decompresses a single Huffman stream, returns the number of bytes decoded.
/// The stream `in` must span exactly the Huffman-coded block.
178static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
179                                     ostream_t *const out, istream_t *const in);
/// Same as previous but decodes 4 streams, formatted as in the Zstandard
/// specification.
/// The stream `in` must span exactly the Huffman-coded block.
183static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
184                                     ostream_t *const out, istream_t *const in);
185
186/// Initialize a Huffman decoding table using the table of bit counts provided
187static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
188                            const int num_symbs);
189/// Initialize a Huffman decoding table using the table of weights provided
190/// Weights follow the definition provided in the Zstandard specification
191static void HUF_init_dtable_usingweights(HUF_dtable *const table,
192                                         const u8 *const weights,
193                                         const int num_symbs);
194
195/// Free the malloc'ed parts of a decoding table
196static void HUF_free_dtable(HUF_dtable *const dtable);
197/*** END HUFFMAN PRIMITIVES ***********/
198
199/*** FSE PRIMITIVES *******************/
200/// For more description of FSE see
201/// https://github.com/Cyan4973/FiniteStateEntropy/
202
203// FSE table decoding uses exponential memory, so limit the maximum accuracy
204#define FSE_MAX_ACCURACY_LOG (15)
205// Limit the maximum number of symbols so they can be stored in a single byte
206#define FSE_MAX_SYMBS (256)
207
208/// The tables needed to decode FSE encoded streams
209typedef struct {
210    u8 *symbols;
211    u8 *num_bits;
212    u16 *new_state_base;
213    int accuracy_log;
214} FSE_dtable;
215
216/// Return the symbol for the current state
217static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
218                                 const u16 state);
219/// Read the number of bits necessary to update state, update, and shift offset
220/// back to reflect the bits read
221static inline void FSE_update_state(const FSE_dtable *const dtable,
222                                    u16 *const state, const u8 *const src,
223                                    i64 *const offset);
224
225/// Combine peek and update: decode a symbol and update the state
226static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
227                                   u16 *const state, const u8 *const src,
228                                   i64 *const offset);
229
230/// Read bits from the stream to initialize the state and shift offset back
231static inline void FSE_init_state(const FSE_dtable *const dtable,
232                                  u16 *const state, const u8 *const src,
233                                  i64 *const offset);
234
/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
/// using an FSE decoding table.  The stream `in` must span exactly the
/// block.
238static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
239                                          ostream_t *const out,
240                                          istream_t *const in);
241
242/// Initialize a decoding table using normalized frequencies.
243static void FSE_init_dtable(FSE_dtable *const dtable,
244                            const i16 *const norm_freqs, const int num_symbs,
245                            const int accuracy_log);
246
247/// Decode an FSE header as defined in the Zstandard format specification and
248/// use the decoded frequencies to initialize a decoding table.
249static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
250                                const int max_accuracy_log);
251
252/// Initialize an FSE table that will always return the same symbol and consume
253/// 0 bits per symbol, to be used for RLE mode in sequence commands
254static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
255
256/// Free the malloc'ed parts of a decoding table
257static void FSE_free_dtable(FSE_dtable *const dtable);
258/*** END FSE PRIMITIVES ***************/
259
260/******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
261
262/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
263
264/// A small structure that can be reused in various places that need to access
265/// frame header information
266typedef struct {
267    // The size of window that we need to be able to contiguously store for
268    // references
269    size_t window_size;
270    // The total output size of this compressed frame
271    size_t frame_content_size;
272
273    // The dictionary id if this frame uses one
274    u32 dictionary_id;
275
276    // Whether or not the content of this frame has a checksum
277    int content_checksum_flag;
278    // Whether or not the output for this frame is in a single segment
279    int single_segment_flag;
280} frame_header_t;
281
282/// The context needed to decode blocks in a frame
283typedef struct {
284    frame_header_t header;
285
    // The total amount of data available for backreferences, to determine if an
    // offset is too large to be correct
288    size_t current_total_output;
289
290    const u8 *dict_content;
291    size_t dict_content_len;
292
293    // Entropy encoding tables so they can be repeated by future blocks instead
294    // of retransmitting
295    HUF_dtable literals_dtable;
296    FSE_dtable ll_dtable;
297    FSE_dtable ml_dtable;
298    FSE_dtable of_dtable;
299
300    // The last 3 offsets for the special "repeat offsets".
301    u64 previous_offsets[3];
302} frame_context_t;
303
304/// The decoded contents of a dictionary so that it doesn't have to be repeated
305/// for each frame that uses it
306struct dictionary_s {
307    // Entropy tables
308    HUF_dtable literals_dtable;
309    FSE_dtable ll_dtable;
310    FSE_dtable ml_dtable;
311    FSE_dtable of_dtable;
312
313    // Raw content for backreferences
314    u8 *content;
315    size_t content_size;
316
317    // Offset history to prepopulate the frame's history
318    u64 previous_offsets[3];
319
320    u32 dictionary_id;
321};
322
323/// A tuple containing the parts necessary to decode and execute a ZSTD sequence
324/// command
325typedef struct {
326    u32 literal_length;
327    u32 match_length;
328    u32 offset;
329} sequence_command_t;
330
331/// The decoder works top-down, starting at the high level like Zstd frames, and
332/// working down to lower more technical levels such as blocks, literals, and
333/// sequences.  The high-level functions roughly follow the outline of the
334/// format specification:
335/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
336
337/// Before the implementation of each high-level function declared here, the
338/// prototypes for their helper functions are defined and explained
339
340/// Decode a single Zstd frame, or error if the input is not a valid frame.
341/// Accepts a dict argument, which may be NULL indicating no dictionary.
342/// See
343/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
344static void decode_frame(ostream_t *const out, istream_t *const in,
345                         const dictionary_t *const dict);
346
347// Decode data in a compressed block
348static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
349                             istream_t *const in);
350
351// Decode the literals section of a block
352static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
353                              u8 **const literals);
354
355// Decode the sequences part of a block
356static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
357                               sequence_command_t **const sequences);
358
359// Execute the decoded sequences on the literals block
360static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
361                              const u8 *const literals,
362                              const size_t literals_len,
363                              const sequence_command_t *const sequences,
364                              const size_t num_sequences);
365
366// Copies literals and returns the total literal length that was copied
367static u32 copy_literals(const size_t seq, istream_t *litstream,
368                         ostream_t *const out);
369
370// Given an offset code from a sequence command (either an actual offset value
371// or an index for previous offset), computes the correct offset and updates
372// the offset history
373static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
374
375// Given an offset, match length, and total output, as well as the frame
376// context for the dictionary, determines if the dictionary is used and
377// executes the copy operation
378static void execute_match_copy(frame_context_t *const ctx, size_t offset,
379                              size_t match_length, size_t total_output,
380                              ostream_t *const out);
381
382/******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
383
384size_t ZSTD_decompress(void *const dst, const size_t dst_len,
385                       const void *const src, const size_t src_len) {
386    dictionary_t* const uninit_dict = create_dictionary();
387    size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
388                                                         src_len, uninit_dict);
389    free_dictionary(uninit_dict);
390    return decomp_size;
391}
392
393size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
394                                 const void *const src, const size_t src_len,
395                                 dictionary_t* parsed_dict) {
396
397    istream_t in = IO_make_istream(src, src_len);
398    ostream_t out = IO_make_ostream(dst, dst_len);
399
400    // "A content compressed by Zstandard is transformed into a Zstandard frame.
401    // Multiple frames can be appended into a single file or stream. A frame is
402    // totally independent, has a defined beginning and end, and a set of
403    // parameters which tells the decoder how to decompress it."
404
405    /* this decoder assumes decompression of a single frame */
406    decode_frame(&out, &in, parsed_dict);
407
408    return (size_t)(out.ptr - (u8 *)dst);
409}
410
411/******* FRAME DECODING ******************************************************/
412
413static void decode_data_frame(ostream_t *const out, istream_t *const in,
414                              const dictionary_t *const dict);
415static void init_frame_context(frame_context_t *const context,
416                               istream_t *const in,
417                               const dictionary_t *const dict);
418static void free_frame_context(frame_context_t *const context);
419static void parse_frame_header(frame_header_t *const header,
420                               istream_t *const in);
421static void frame_context_apply_dict(frame_context_t *const ctx,
422                                     const dictionary_t *const dict);
423
424static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
425                            istream_t *const in);
426
427static void decode_frame(ostream_t *const out, istream_t *const in,
428                         const dictionary_t *const dict) {
429    const u32 magic_number = (u32)IO_read_bits(in, 32);
430    if (magic_number == ZSTD_MAGIC_NUMBER) {
431        // ZSTD frame
432        decode_data_frame(out, in, dict);
433
434        return;
435    }
436
437    // not a real frame or a skippable frame
438    ERROR("Tried to decode non-ZSTD frame");
439}
440
441/// Decode a frame that contains compressed data.  Not all frames do as there
442/// are skippable frames.
443/// See
444/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
445static void decode_data_frame(ostream_t *const out, istream_t *const in,
446                              const dictionary_t *const dict) {
447    frame_context_t ctx;
448
449    // Initialize the context that needs to be carried from block to block
450    init_frame_context(&ctx, in, dict);
451
452    if (ctx.header.frame_content_size != 0 &&
453        ctx.header.frame_content_size > out->len) {
454        OUT_SIZE();
455    }
456
457    decompress_data(&ctx, out, in);
458
459    free_frame_context(&ctx);
460}
461
462/// Takes the information provided in the header and dictionary, and initializes
463/// the context for this frame
464static void init_frame_context(frame_context_t *const context,
465                               istream_t *const in,
466                               const dictionary_t *const dict) {
467    // Most fields in context are correct when initialized to 0
468    memset(context, 0, sizeof(frame_context_t));
469
470    // Parse data from the frame header
471    parse_frame_header(&context->header, in);
472
473    // Set up the offset history for the repeat offset commands
474    context->previous_offsets[0] = 1;
475    context->previous_offsets[1] = 4;
476    context->previous_offsets[2] = 8;
477
478    // Apply details from the dict if it exists
479    frame_context_apply_dict(context, dict);
480}
481
482static void free_frame_context(frame_context_t *const context) {
483    HUF_free_dtable(&context->literals_dtable);
484
485    FSE_free_dtable(&context->ll_dtable);
486    FSE_free_dtable(&context->ml_dtable);
487    FSE_free_dtable(&context->of_dtable);
488
489    memset(context, 0, sizeof(frame_context_t));
490}
491
492static void parse_frame_header(frame_header_t *const header,
493                               istream_t *const in) {
494    // "The first header's byte is called the Frame_Header_Descriptor. It tells
495    // which other fields are present. Decoding this byte is enough to tell the
496    // size of Frame_Header.
497    //
498    // Bit number   Field name
499    // 7-6  Frame_Content_Size_flag
500    // 5    Single_Segment_flag
501    // 4    Unused_bit
502    // 3    Reserved_bit
503    // 2    Content_Checksum_flag
504    // 1-0  Dictionary_ID_flag"
505    const u8 descriptor = (u8)IO_read_bits(in, 8);
506
507    // decode frame header descriptor into flags
508    const u8 frame_content_size_flag = descriptor >> 6;
509    const u8 single_segment_flag = (descriptor >> 5) & 1;
510    const u8 reserved_bit = (descriptor >> 3) & 1;
511    const u8 content_checksum_flag = (descriptor >> 2) & 1;
512    const u8 dictionary_id_flag = descriptor & 3;
513
514    if (reserved_bit != 0) {
515        CORRUPTION();
516    }
517
518    header->single_segment_flag = single_segment_flag;
519    header->content_checksum_flag = content_checksum_flag;
520
521    // decode window size
522    if (!single_segment_flag) {
523        // "Provides guarantees on maximum back-reference distance that will be
524        // used within compressed data. This information is important for
525        // decoders to allocate enough memory.
526        //
527        // Bit numbers  7-3         2-0
528        // Field name   Exponent    Mantissa"
529        u8 window_descriptor = (u8)IO_read_bits(in, 8);
530        u8 exponent = window_descriptor >> 3;
531        u8 mantissa = window_descriptor & 7;
532
533        // Use the algorithm from the specification to compute window size
534        // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
535        size_t window_base = (size_t)1 << (10 + exponent);
536        size_t window_add = (window_base / 8) * mantissa;
537        header->window_size = window_base + window_add;
538    }
539
540    // decode dictionary id if it exists
541    if (dictionary_id_flag) {
542        // "This is a variable size field, which contains the ID of the
543        // dictionary required to properly decode the frame. Note that this
544        // field is optional. When it's not present, it's up to the caller to
545        // make sure it uses the correct dictionary. Format is little-endian."
546        const int bytes_array[] = {0, 1, 2, 4};
547        const int bytes = bytes_array[dictionary_id_flag];
548
549        header->dictionary_id = (u32)IO_read_bits(in, bytes * 8);
550    } else {
551        header->dictionary_id = 0;
552    }
553
554    // decode frame content size if it exists
555    if (single_segment_flag || frame_content_size_flag) {
556        // "This is the original (uncompressed) size. This information is
557        // optional. The Field_Size is provided according to value of
558        // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
559        // present), 1, 2, 4 or 8 bytes. Format is little-endian."
560        //
561        // if frame_content_size_flag == 0 but single_segment_flag is set, we
562        // still have a 1 byte field
563        const int bytes_array[] = {1, 2, 4, 8};
564        const int bytes = bytes_array[frame_content_size_flag];
565
566        header->frame_content_size = IO_read_bits(in, bytes * 8);
567        if (bytes == 2) {
568            // "When Field_Size is 2, the offset of 256 is added."
569            header->frame_content_size += 256;
570        }
571    } else {
572        header->frame_content_size = 0;
573    }
574
575    if (single_segment_flag) {
576        // "The Window_Descriptor byte is optional. It is absent when
577        // Single_Segment_flag is set. In this case, the maximum back-reference
578        // distance is the content size itself, which can be any value from 1 to
579        // 2^64-1 bytes (16 EB)."
580        header->window_size = header->frame_content_size;
581    }
582}
583
584/// Decompress the data from a frame block by block
585static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
586                            istream_t *const in) {
587    // "A frame encapsulates one or multiple blocks. Each block can be
588    // compressed or not, and has a guaranteed maximum content size, which
589    // depends on frame parameters. Unlike frames, each block depends on
590    // previous blocks for proper decoding. However, each block can be
591    // decompressed without waiting for its successor, allowing streaming
592    // operations."
593    int last_block = 0;
594    do {
595        // "Last_Block
596        //
597        // The lowest bit signals if this block is the last one. Frame ends
598        // right after this block.
599        //
600        // Block_Type and Block_Size
601        //
602        // The next 2 bits represent the Block_Type, while the remaining 21 bits
603        // represent the Block_Size. Format is little-endian."
604        last_block = (int)IO_read_bits(in, 1);
605        const int block_type = (int)IO_read_bits(in, 2);
606        const size_t block_len = IO_read_bits(in, 21);
607
608        switch (block_type) {
609        case 0: {
610            // "Raw_Block - this is an uncompressed block. Block_Size is the
611            // number of bytes to read and copy."
612            const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
613            u8 *const write_ptr = IO_get_write_ptr(out, block_len);
614
615            // Copy the raw data into the output
616            memcpy(write_ptr, read_ptr, block_len);
617
618            ctx->current_total_output += block_len;
619            break;
620        }
621        case 1: {
622            // "RLE_Block - this is a single byte, repeated N times. In which
623            // case, Block_Size is the size to regenerate, while the
624            // "compressed" block is just 1 byte (the byte to repeat)."
625            const u8 *const read_ptr = IO_get_read_ptr(in, 1);
626            u8 *const write_ptr = IO_get_write_ptr(out, block_len);
627
628            // Copy `block_len` copies of `read_ptr[0]` to the output
629            memset(write_ptr, read_ptr[0], block_len);
630
631            ctx->current_total_output += block_len;
632            break;
633        }
634        case 2: {
635            // "Compressed_Block - this is a Zstandard compressed block,
636            // detailed in another section of this specification. Block_Size is
637            // the compressed size.
638
639            // Create a sub-stream for the block
640            istream_t block_stream = IO_make_sub_istream(in, block_len);
641            decompress_block(ctx, out, &block_stream);
642            break;
643        }
644        case 3:
645            // "Reserved - this is not a block. This value cannot be used with
646            // current version of this specification."
647            CORRUPTION();
648            break;
649        default:
650            IMPOSSIBLE();
651        }
652    } while (!last_block);
653
654    if (ctx->header.content_checksum_flag) {
655        // This program does not support checking the checksum, so skip over it
656        // if it's present
657        IO_advance_input(in, 4);
658    }
659}
660/******* END FRAME DECODING ***************************************************/
661
662/******* BLOCK DECOMPRESSION **************************************************/
663static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
664                             istream_t *const in) {
665    // "A compressed block consists of 2 sections :
666    //
667    // Literals_Section
668    // Sequences_Section"
669
670
671    // Part 1: decode the literals block
672    u8 *literals = NULL;
673    const size_t literals_size = decode_literals(ctx, in, &literals);
674
675    // Part 2: decode the sequences block
676    sequence_command_t *sequences = NULL;
677    const size_t num_sequences =
678        decode_sequences(ctx, in, &sequences);
679
680    // Part 3: combine literals and sequence commands to generate output
681    execute_sequences(ctx, out, literals, literals_size, sequences,
682                      num_sequences);
683    free(literals);
684    free(sequences);
685}
686/******* END BLOCK DECOMPRESSION **********************************************/
687
688/******* LITERALS DECODING ****************************************************/
689static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
690                                     const int block_type,
691                                     const int size_format);
692static size_t decode_literals_compressed(frame_context_t *const ctx,
693                                         istream_t *const in,
694                                         u8 **const literals,
695                                         const int block_type,
696                                         const int size_format);
697static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
698static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
699                                    int *const num_symbs);
700
701static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
702                              u8 **const literals) {
703    // "Literals can be stored uncompressed or compressed using Huffman prefix
704    // codes. When compressed, an optional tree description can be present,
705    // followed by 1 or 4 streams."
706    //
707    // "Literals_Section_Header
708    //
709    // Header is in charge of describing how literals are packed. It's a
710    // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
711    // little-endian convention."
712    //
713    // "Literals_Block_Type
714    //
715    // This field uses 2 lowest bits of first byte, describing 4 different block
716    // types"
717    //
718    // size_format takes between 1 and 2 bits
719    int block_type = (int)IO_read_bits(in, 2);
720    int size_format = (int)IO_read_bits(in, 2);
721
722    if (block_type <= 1) {
723        // Raw or RLE literals block
724        return decode_literals_simple(in, literals, block_type,
725                                      size_format);
726    } else {
727        // Huffman compressed literals
728        return decode_literals_compressed(ctx, in, literals, block_type,
729                                          size_format);
730    }
731}
732
733/// Decodes literals blocks in raw or RLE form
734static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
735                                     const int block_type,
736                                     const int size_format) {
737    size_t size;
738    switch (size_format) {
739    // These cases are in the form ?0
740    // In this case, the ? bit is actually part of the size field
741    case 0:
742    case 2:
743        // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
744        IO_rewind_bits(in, 1);
745        size = IO_read_bits(in, 5);
746        break;
747    case 1:
748        // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
749        size = IO_read_bits(in, 12);
750        break;
751    case 3:
752        // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
753        size = IO_read_bits(in, 20);
754        break;
755    default:
756        // Size format is in range 0-3
757        IMPOSSIBLE();
758    }
759
760    if (size > MAX_LITERALS_SIZE) {
761        CORRUPTION();
762    }
763
764    *literals = malloc(size);
765    if (!*literals) {
766        BAD_ALLOC();
767    }
768
769    switch (block_type) {
770    case 0: {
771        // "Raw_Literals_Block - Literals are stored uncompressed."
772        const u8 *const read_ptr = IO_get_read_ptr(in, size);
773        memcpy(*literals, read_ptr, size);
774        break;
775    }
776    case 1: {
777        // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
778        const u8 *const read_ptr = IO_get_read_ptr(in, 1);
779        memset(*literals, read_ptr[0], size);
780        break;
781    }
782    default:
783        IMPOSSIBLE();
784    }
785
786    return size;
787}
788
789/// Decodes Huffman compressed literals
790static size_t decode_literals_compressed(frame_context_t *const ctx,
791                                         istream_t *const in,
792                                         u8 **const literals,
793                                         const int block_type,
794                                         const int size_format) {
795    size_t regenerated_size, compressed_size;
796    // Only size_format=0 has 1 stream, so default to 4
797    int num_streams = 4;
798    switch (size_format) {
799    case 0:
800        // "A single stream. Both Compressed_Size and Regenerated_Size use 10
801        // bits (0-1023)."
802        num_streams = 1;
803    // Fall through as it has the same size format
804        /* fallthrough */
805    case 1:
806        // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
807        // (0-1023)."
808        regenerated_size = IO_read_bits(in, 10);
809        compressed_size = IO_read_bits(in, 10);
810        break;
811    case 2:
812        // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
813        // (0-16383)."
814        regenerated_size = IO_read_bits(in, 14);
815        compressed_size = IO_read_bits(in, 14);
816        break;
817    case 3:
818        // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
819        // (0-262143)."
820        regenerated_size = IO_read_bits(in, 18);
821        compressed_size = IO_read_bits(in, 18);
822        break;
823    default:
824        // Impossible
825        IMPOSSIBLE();
826    }
827    if (regenerated_size > MAX_LITERALS_SIZE) {
828        CORRUPTION();
829    }
830
831    *literals = malloc(regenerated_size);
832    if (!*literals) {
833        BAD_ALLOC();
834    }
835
836    ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
837    istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
838
839    if (block_type == 2) {
840        // Decode the provided Huffman table
841        // "This section is only present when Literals_Block_Type type is
842        // Compressed_Literals_Block (2)."
843
844        HUF_free_dtable(&ctx->literals_dtable);
845        decode_huf_table(&ctx->literals_dtable, &huf_stream);
846    } else {
847        // If the previous Huffman table is being repeated, ensure it exists
848        if (!ctx->literals_dtable.symbols) {
849            CORRUPTION();
850        }
851    }
852
853    size_t symbols_decoded;
854    if (num_streams == 1) {
855        symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
856    } else {
857        symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
858    }
859
860    if (symbols_decoded != regenerated_size) {
861        CORRUPTION();
862    }
863
864    return regenerated_size;
865}
866
867// Decode the Huffman table description
868static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
869    // "All literal values from zero (included) to last present one (excluded)
870    // are represented by Weight with values from 0 to Max_Number_of_Bits."
871
872    // "This is a single byte value (0-255), which describes how to decode the list of weights."
873    const u8 header = IO_read_bits(in, 8);
874
875    u8 weights[HUF_MAX_SYMBS];
876    memset(weights, 0, sizeof(weights));
877
878    int num_symbs;
879
880    if (header >= 128) {
881        // "This is a direct representation, where each Weight is written
882        // directly as a 4 bits field (0-15). The full representation occupies
883        // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
884        // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
885        // 127"
886        num_symbs = header - 127;
887        const size_t bytes = (num_symbs + 1) / 2;
888
889        const u8 *const weight_src = IO_get_read_ptr(in, bytes);
890
891        for (int i = 0; i < num_symbs; i++) {
892            // "They are encoded forward, 2
893            // weights to a byte with the first weight taking the top four bits
894            // and the second taking the bottom four (e.g. the following
895            // operations could be used to read the weights: Weight[0] =
896            // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
897            if (i % 2 == 0) {
898                weights[i] = weight_src[i / 2] >> 4;
899            } else {
900                weights[i] = weight_src[i / 2] & 0xf;
901            }
902        }
903    } else {
904        // The weights are FSE encoded, decode them before we can construct the
905        // table
906        istream_t fse_stream = IO_make_sub_istream(in, header);
907        ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
908        fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
909    }
910
911    // Construct the table using the decoded weights
912    HUF_init_dtable_usingweights(dtable, weights, num_symbs);
913}
914
915static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
916                                    int *const num_symbs) {
917    const int MAX_ACCURACY_LOG = 7;
918
919    FSE_dtable dtable;
920
921    // "An FSE bitstream starts by a header, describing probabilities
922    // distribution. It will create a Decoding Table. For a list of Huffman
923    // weights, maximum accuracy is 7 bits."
924    FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
925
926    // Decode the weights
927    *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
928
929    FSE_free_dtable(&dtable);
930}
931/******* END LITERALS DECODING ************************************************/
932
933/******* SEQUENCE DECODING ****************************************************/
/// The combination of FSE states needed to decode sequences: one decoding
/// table and one current state per symbol type (literal length, offset,
/// match length).
typedef struct {
    // Decoding tables; decompress_sequences fills these by struct assignment
    // from the tables stored in the frame context
    FSE_dtable ll_table;
    FSE_dtable of_table;
    FSE_dtable ml_table;

    // Current FSE state for each table, set by FSE_init_state and advanced by
    // FSE_update_state
    u16 ll_state;
    u16 of_state;
    u16 ml_state;
} sequence_states_t;
944
/// Identifies which of the three sequence symbol types a table belongs to.
/// The numeric values are used as indices into the per-type constant arrays
/// in decode_seq_table and into SEQ_MAX_CODES.
typedef enum {
    seq_literal_length = 0,
    seq_offset = 1,
    seq_match_length = 2,
} seq_part_t;
951
/// How the FSE table for one symbol type is obtained. The values match the
/// 2-bit mode fields that decompress_sequences extracts from the
/// "Symbol compression modes" byte.
typedef enum {
    seq_predefined = 0, // use the predefined default distribution
    seq_rle = 1,        // a single symbol, read as one byte from the input
    seq_fse = 2,        // an FSE table description is read from the input
    seq_repeat = 3,     // keep the table already stored from an earlier block
} seq_mode_t;
958
/// The predefined FSE distribution tables for `seq_predefined` mode.
/// decode_seq_table passes them to FSE_init_dtable together with their length
/// (36, 29, 53) and accuracy log (6, 5, 6 respectively).
// NOTE(review): the -1 entries presumably mark "less than 1" probabilities as
// in the zstd format's normalized distributions; confirm against
// FSE_init_dtable.
static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1,  1,  2,  2,
    2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1,  1,  1,  1,  1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1,  1,  1,  1,  1,  1,  1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
970
/// The sequence decoding baseline and number of additional bits to read/add
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
///
/// Both pairs of tables are indexed by the decoded code: decode_sequence adds
/// `EXTRA_BITS[code]` raw bits read from the stream to `BASELINES[code]`.

// Literal length codes 0-35
static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
    0,  1,  2,   3,   4,   5,    6,    7,    8,    9,     10,    11,
    12, 13, 14,  15,  16,  18,   20,   22,   24,   28,    32,    40,
    48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  1,  1,
    1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

// Match length codes 0-52
static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
    3,  4,   5,   6,   7,    8,    9,    10,   11,    12,    13,   14, 15, 16,
    17, 18,  19,  20,  21,   22,   23,   24,   25,    26,    27,   28, 29, 30,
    31, 32,  33,  34,  35,   37,   39,   41,   43,    47,    51,   59, 67, 83,
    99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  1,  1,  1, 1,
    2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
990
/// Offset decoding is simpler so we just need a maximum code value.
/// Indexed by seq_part_t: 35 and 52 are the last valid indices of the literal
/// length and match length tables above. The offset slot holds (u8)-1, the
/// largest u8 value, so comparing a u8 offset code against it can never fail.
static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52};
993
// Forward declarations for the sequence-decoding helpers defined below.

/// Decodes `num_sequences` sequence commands from `in` into `sequences`
static void decompress_sequences(frame_context_t *const ctx,
                                 istream_t *const in,
                                 sequence_command_t *const sequences,
                                 const size_t num_sequences);
/// Decodes one sequence command from the backward bitstream `src`, using and
/// updating the FSE states and the bit position `*offset`
static sequence_command_t decode_sequence(sequence_states_t *const state,
                                          const u8 *const src,
                                          i64 *const offset);
/// Sets up the FSE table for symbol type `type` according to `mode`
static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
                               const seq_part_t type, const seq_mode_t mode);
1003
1004static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
1005                               sequence_command_t **const sequences) {
1006    // "A compressed block is a succession of sequences . A sequence is a
1007    // literal copy command, followed by a match copy command. A literal copy
1008    // command specifies a length. It is the number of bytes to be copied (or
1009    // extracted) from the literal section. A match copy command specifies an
1010    // offset and a length. The offset gives the position to copy from, which
1011    // can be within a previous block."
1012
1013    size_t num_sequences;
1014
1015    // "Number_of_Sequences
1016    //
1017    // This is a variable size field using between 1 and 3 bytes. Let's call its
1018    // first byte byte0."
1019    u8 header = IO_read_bits(in, 8);
1020    if (header == 0) {
1021        // "There are no sequences. The sequence section stops there.
1022        // Regenerated content is defined entirely by literals section."
1023        *sequences = NULL;
1024        return 0;
1025    } else if (header < 128) {
1026        // "Number_of_Sequences = byte0 . Uses 1 byte."
1027        num_sequences = header;
1028    } else if (header < 255) {
1029        // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
1030        num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
1031    } else {
1032        // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
1033        num_sequences = IO_read_bits(in, 16) + 0x7F00;
1034    }
1035
1036    *sequences = malloc(num_sequences * sizeof(sequence_command_t));
1037    if (!*sequences) {
1038        BAD_ALLOC();
1039    }
1040
1041    decompress_sequences(ctx, in, *sequences, num_sequences);
1042    return num_sequences;
1043}
1044
1045/// Decompress the FSE encoded sequence commands
1046static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
1047                                 sequence_command_t *const sequences,
1048                                 const size_t num_sequences) {
1049    // "The Sequences_Section regroup all symbols required to decode commands.
1050    // There are 3 symbol types : literals lengths, offsets and match lengths.
1051    // They are encoded together, interleaved, in a single bitstream."
1052
1053    // "Symbol compression modes
1054    //
1055    // This is a single byte, defining the compression mode of each symbol
1056    // type."
1057    //
1058    // Bit number : Field name
1059    // 7-6        : Literals_Lengths_Mode
1060    // 5-4        : Offsets_Mode
1061    // 3-2        : Match_Lengths_Mode
1062    // 1-0        : Reserved
1063    u8 compression_modes = IO_read_bits(in, 8);
1064
1065    if ((compression_modes & 3) != 0) {
1066        // Reserved bits set
1067        CORRUPTION();
1068    }
1069
1070    // "Following the header, up to 3 distribution tables can be described. When
1071    // present, they are in this order :
1072    //
1073    // Literals lengths
1074    // Offsets
1075    // Match Lengths"
1076    // Update the tables we have stored in the context
1077    decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
1078                     (compression_modes >> 6) & 3);
1079
1080    decode_seq_table(&ctx->of_dtable, in, seq_offset,
1081                     (compression_modes >> 4) & 3);
1082
1083    decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
1084                     (compression_modes >> 2) & 3);
1085
1086
1087    sequence_states_t states;
1088
1089    // Initialize the decoding tables
1090    {
1091        states.ll_table = ctx->ll_dtable;
1092        states.of_table = ctx->of_dtable;
1093        states.ml_table = ctx->ml_dtable;
1094    }
1095
1096    const size_t len = IO_istream_len(in);
1097    const u8 *const src = IO_get_read_ptr(in, len);
1098
1099    // "After writing the last bit containing information, the compressor writes
1100    // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
1101    const int padding = 8 - highest_set_bit(src[len - 1]);
1102    // The offset starts at the end because FSE streams are read backwards
1103    i64 bit_offset = (i64)(len * 8 - (size_t)padding);
1104
1105    // "The bitstream starts with initial state values, each using the required
1106    // number of bits in their respective accuracy, decoded previously from
1107    // their normalized distribution.
1108    //
1109    // It starts by Literals_Length_State, followed by Offset_State, and finally
1110    // Match_Length_State."
1111    FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
1112    FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
1113    FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
1114
1115    for (size_t i = 0; i < num_sequences; i++) {
1116        // Decode sequences one by one
1117        sequences[i] = decode_sequence(&states, src, &bit_offset);
1118    }
1119
1120    if (bit_offset != 0) {
1121        CORRUPTION();
1122    }
1123}
1124
1125// Decode a single sequence and update the state
1126static sequence_command_t decode_sequence(sequence_states_t *const states,
1127                                          const u8 *const src,
1128                                          i64 *const offset) {
1129    // "Each symbol is a code in its own context, which specifies Baseline and
1130    // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
1131    // additional bits in the same bitstream."
1132
1133    // Decode symbols, but don't update states
1134    const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
1135    const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
1136    const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
1137
1138    // Offset doesn't need a max value as it's not decoded using a table
1139    if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
1140        ml_code > SEQ_MAX_CODES[seq_match_length]) {
1141        CORRUPTION();
1142    }
1143
1144    // Read the interleaved bits
1145    sequence_command_t seq;
1146    // "Decoding starts by reading the Number_of_Bits required to decode Offset.
1147    // It then does the same for Match_Length, and then for Literals_Length."
1148    seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
1149
1150    seq.match_length =
1151        SEQ_MATCH_LENGTH_BASELINES[ml_code] +
1152        STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
1153
1154    seq.literal_length =
1155        SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
1156        STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
1157
1158    // "If it is not the last sequence in the block, the next operation is to
1159    // update states. Using the rules pre-calculated in the decoding tables,
1160    // Literals_Length_State is updated, followed by Match_Length_State, and
1161    // then Offset_State."
1162    // If the stream is complete don't read bits to update state
1163    if (*offset != 0) {
1164        FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
1165        FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
1166        FSE_update_state(&states->of_table, &states->of_state, src, offset);
1167    }
1168
1169    return seq;
1170}
1171
1172/// Given a sequence part and table mode, decode the FSE distribution
1173/// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
1174static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1175                             const seq_part_t type, const seq_mode_t mode) {
1176    // Constant arrays indexed by seq_part_t
1177    const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
1178                                                SEQ_OFFSET_DEFAULT_DIST,
1179                                                SEQ_MATCH_LENGTH_DEFAULT_DIST};
1180    const size_t default_distribution_lengths[] = {36, 29, 53};
1181    const size_t default_distribution_accuracies[] = {6, 5, 6};
1182
1183    const size_t max_accuracies[] = {9, 8, 9};
1184
1185    if (mode != seq_repeat) {
1186        // Free old one before overwriting
1187        FSE_free_dtable(table);
1188    }
1189
1190    switch (mode) {
1191    case seq_predefined: {
1192        // "Predefined_Mode : uses a predefined distribution table."
1193        const i16 *distribution = default_distributions[type];
1194        const size_t symbs = default_distribution_lengths[type];
1195        const size_t accuracy_log = default_distribution_accuracies[type];
1196
1197        FSE_init_dtable(table, distribution, symbs, accuracy_log);
1198        break;
1199    }
1200    case seq_rle: {
1201        // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
1202        const u8 symb = IO_get_read_ptr(in, 1)[0];
1203        FSE_init_dtable_rle(table, symb);
1204        break;
1205    }
1206    case seq_fse: {
1207        // "FSE_Compressed_Mode : standard FSE compression. A distribution table
1208        // will be present "
1209        FSE_decode_header(table, in, max_accuracies[type]);
1210        break;
1211    }
1212    case seq_repeat:
1213        // "Repeat_Mode : re-use distribution table from previous compressed
1214        // block."
1215        // Nothing to do here, table will be unchanged
1216        if (!table->symbols) {
1217            // This mode is invalid if we don't already have a table
1218            CORRUPTION();
1219        }
1220        break;
1221    default:
1222        // Impossible, as mode is from 0-3
1223        IMPOSSIBLE();
1224        break;
1225    }
1226
1227}
1228/******* END SEQUENCE DECODING ************************************************/
1229
1230/******* SEQUENCE EXECUTION ***************************************************/
1231static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
1232                              const u8 *const literals,
1233                              const size_t literals_len,
1234                              const sequence_command_t *const sequences,
1235                              const size_t num_sequences) {
1236    istream_t litstream = IO_make_istream(literals, literals_len);
1237
1238    u64 *const offset_hist = ctx->previous_offsets;
1239    size_t total_output = ctx->current_total_output;
1240
1241    for (size_t i = 0; i < num_sequences; i++) {
1242        const sequence_command_t seq = sequences[i];
1243        {
1244            const u32 literals_size = copy_literals(seq.literal_length, &litstream, out);
1245            total_output += literals_size;
1246        }
1247
1248        size_t const offset = compute_offset(seq, offset_hist);
1249
1250        size_t const match_length = seq.match_length;
1251
1252        execute_match_copy(ctx, offset, match_length, total_output, out);
1253
1254        total_output += match_length;
1255    }
1256
1257    // Copy any leftover literals
1258    {
1259        size_t len = IO_istream_len(&litstream);
1260        copy_literals(len, &litstream, out);
1261        total_output += len;
1262    }
1263
1264    ctx->current_total_output = total_output;
1265}
1266
1267static u32 copy_literals(const size_t literal_length, istream_t *litstream,
1268                         ostream_t *const out) {
1269    // If the sequence asks for more literals than are left, the
1270    // sequence must be corrupted
1271    if (literal_length > IO_istream_len(litstream)) {
1272        CORRUPTION();
1273    }
1274
1275    u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
1276    const u8 *const read_ptr =
1277         IO_get_read_ptr(litstream, literal_length);
1278    // Copy literals to output
1279    memcpy(write_ptr, read_ptr, literal_length);
1280
1281    return literal_length;
1282}
1283
1284static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
1285    size_t offset;
1286    // Offsets are special, we need to handle the repeat offsets
1287    if (seq.offset <= 3) {
1288        // "The first 3 values define a repeated offset and we will call
1289        // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
1290        // They are sorted in recency order, with Repeated_Offset1 meaning
1291        // 'most recent one'".
1292
1293        // Use 0 indexing for the array
1294        u32 idx = seq.offset - 1;
1295        if (seq.literal_length == 0) {
1296            // "There is an exception though, when current sequence's
1297            // literals length is 0. In this case, repeated offsets are
1298            // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
1299            // Repeated_Offset2 becomes Repeated_Offset3, and
1300            // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
1301            idx++;
1302        }
1303
1304        if (idx == 0) {
1305            offset = offset_hist[0];
1306        } else {
1307            // If idx == 3 then literal length was 0 and the offset was 3,
1308            // as per the exception listed above
1309            offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
1310
1311            // If idx == 1 we don't need to modify offset_hist[2], since
1312            // we're using the second-most recent code
1313            if (idx > 1) {
1314                offset_hist[2] = offset_hist[1];
1315            }
1316            offset_hist[1] = offset_hist[0];
1317            offset_hist[0] = offset;
1318        }
1319    } else {
1320        // When it's not a repeat offset:
1321        // "if (Offset_Value > 3) offset = Offset_Value - 3;"
1322        offset = seq.offset - 3;
1323
1324        // Shift back history
1325        offset_hist[2] = offset_hist[1];
1326        offset_hist[1] = offset_hist[0];
1327        offset_hist[0] = offset;
1328    }
1329    return offset;
1330}
1331
1332static void execute_match_copy(frame_context_t *const ctx, size_t offset,
1333                              size_t match_length, size_t total_output,
1334                              ostream_t *const out) {
1335    u8 *write_ptr = IO_get_write_ptr(out, match_length);
1336    if (total_output <= ctx->header.window_size) {
1337        // In this case offset might go back into the dictionary
1338        if (offset > total_output + ctx->dict_content_len) {
1339            // The offset goes beyond even the dictionary
1340            CORRUPTION();
1341        }
1342
1343        if (offset > total_output) {
1344            // "The rest of the dictionary is its content. The content act
1345            // as a "past" in front of data to compress or decompress, so it
1346            // can be referenced in sequence commands."
1347            const size_t dict_copy =
1348                MIN(offset - total_output, match_length);
1349            const size_t dict_offset =
1350                ctx->dict_content_len - (offset - total_output);
1351
1352            memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
1353            write_ptr += dict_copy;
1354            match_length -= dict_copy;
1355        }
1356    } else if (offset > ctx->header.window_size) {
1357        CORRUPTION();
1358    }
1359
1360    // We must copy byte by byte because the match length might be larger
1361    // than the offset
1362    // ex: if the output so far was "abc", a command with offset=3 and
1363    // match_length=6 would produce "abcabcabc" as the new output
1364    for (size_t j = 0; j < match_length; j++) {
1365        *write_ptr = *(write_ptr - offset);
1366        write_ptr++;
1367    }
1368}
1369/******* END SEQUENCE EXECUTION ***********************************************/
1370
1371/******* OUTPUT SIZE COUNTING *************************************************/
1372/// Get the decompressed size of an input stream so memory can be allocated in
1373/// advance.
1374/// This implementation assumes `src` points to a single ZSTD-compressed frame
1375size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
1376    istream_t in = IO_make_istream(src, src_len);
1377
1378    // get decompressed size from ZSTD frame header
1379    {
1380        const u32 magic_number = (u32)IO_read_bits(&in, 32);
1381
1382        if (magic_number == ZSTD_MAGIC_NUMBER) {
1383            // ZSTD frame
1384            frame_header_t header;
1385            parse_frame_header(&header, &in);
1386
1387            if (header.frame_content_size == 0 && !header.single_segment_flag) {
1388                // Content size not provided, we can't tell
1389                return (size_t)-1;
1390            }
1391
1392            return header.frame_content_size;
1393        } else {
1394            // not a real frame or skippable frame
1395            ERROR("ZSTD frame magic number did not match");
1396        }
1397    }
1398}
1399/******* END OUTPUT SIZE COUNTING *********************************************/
1400
1401/******* DICTIONARY PARSING ***************************************************/
1402dictionary_t* create_dictionary() {
1403    dictionary_t* const dict = calloc(1, sizeof(dictionary_t));
1404    if (!dict) {
1405        BAD_ALLOC();
1406    }
1407    return dict;
1408}
1409
1410/// Free an allocated dictionary
1411void free_dictionary(dictionary_t *const dict) {
1412    HUF_free_dtable(&dict->literals_dtable);
1413    FSE_free_dtable(&dict->ll_dtable);
1414    FSE_free_dtable(&dict->of_dtable);
1415    FSE_free_dtable(&dict->ml_dtable);
1416
1417    free(dict->content);
1418
1419    memset(dict, 0, sizeof(dictionary_t));
1420
1421    free(dict);
1422}
1423
1424
1425#if !defined(ZDEC_NO_DICTIONARY)
// Error helpers for dictionary parsing. Neither expansion ends in a
// semicolon, so both can be used as ordinary statements (`NULL_SRC();`),
// including as the unbraced body of an `if` that has an `else`.
#define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
#define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src")
1428
/// Copies all bytes remaining in `in` into a newly allocated `dict->content`
/// buffer; defined below
static void init_dictionary_content(dictionary_t *const dict,
                                    istream_t *const in);
1431
1432void parse_dictionary(dictionary_t *const dict, const void *src,
1433                             size_t src_len) {
1434    const u8 *byte_src = (const u8 *)src;
1435    memset(dict, 0, sizeof(dictionary_t));
1436    if (src == NULL) { /* cannot initialize dictionary with null src */
1437        NULL_SRC();
1438    }
1439    if (src_len < 8) {
1440        DICT_SIZE_ERROR();
1441    }
1442
1443    istream_t in = IO_make_istream(byte_src, src_len);
1444
1445    const u32 magic_number = IO_read_bits(&in, 32);
1446    if (magic_number != 0xEC30A437) {
1447        // raw content dict
1448        IO_rewind_bits(&in, 32);
1449        init_dictionary_content(dict, &in);
1450        return;
1451    }
1452
1453    dict->dictionary_id = IO_read_bits(&in, 32);
1454
1455    // "Entropy_Tables : following the same format as the tables in compressed
1456    // blocks. They are stored in following order : Huffman tables for literals,
1457    // FSE table for offsets, FSE table for match lengths, and FSE table for
1458    // literals lengths. It's finally followed by 3 offset values, populating
1459    // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
1460    // little-endian each, for a total of 12 bytes. Each recent offset must have
1461    // a value < dictionary size."
1462    decode_huf_table(&dict->literals_dtable, &in);
1463    decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
1464    decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
1465    decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);
1466
1467    // Read in the previous offset history
1468    dict->previous_offsets[0] = IO_read_bits(&in, 32);
1469    dict->previous_offsets[1] = IO_read_bits(&in, 32);
1470    dict->previous_offsets[2] = IO_read_bits(&in, 32);
1471
1472    // Ensure the provided offsets aren't too large
1473    // "Each recent offset must have a value < dictionary size."
1474    for (int i = 0; i < 3; i++) {
1475        if (dict->previous_offsets[i] > src_len) {
1476            ERROR("Dictionary corrupted");
1477        }
1478    }
1479
1480    // "Content : The rest of the dictionary is its content. The content act as
1481    // a "past" in front of data to compress or decompress, so it can be
1482    // referenced in sequence commands."
1483    init_dictionary_content(dict, &in);
1484}
1485
1486static void init_dictionary_content(dictionary_t *const dict,
1487                                    istream_t *const in) {
1488    // Copy in the content
1489    dict->content_size = IO_istream_len(in);
1490    dict->content = malloc(dict->content_size);
1491    if (!dict->content) {
1492        BAD_ALLOC();
1493    }
1494
1495    const u8 *const content = IO_get_read_ptr(in, dict->content_size);
1496
1497    memcpy(dict->content, content, dict->content_size);
1498}
1499
1500static void HUF_copy_dtable(HUF_dtable *const dst,
1501                            const HUF_dtable *const src) {
1502    if (src->max_bits == 0) {
1503        memset(dst, 0, sizeof(HUF_dtable));
1504        return;
1505    }
1506
1507    const size_t size = (size_t)1 << src->max_bits;
1508    dst->max_bits = src->max_bits;
1509
1510    dst->symbols = malloc(size);
1511    dst->num_bits = malloc(size);
1512    if (!dst->symbols || !dst->num_bits) {
1513        BAD_ALLOC();
1514    }
1515
1516    memcpy(dst->symbols, src->symbols, size);
1517    memcpy(dst->num_bits, src->num_bits, size);
1518}
1519
1520static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
1521    if (src->accuracy_log == 0) {
1522        memset(dst, 0, sizeof(FSE_dtable));
1523        return;
1524    }
1525
1526    size_t size = (size_t)1 << src->accuracy_log;
1527    dst->accuracy_log = src->accuracy_log;
1528
1529    dst->symbols = malloc(size);
1530    dst->num_bits = malloc(size);
1531    dst->new_state_base = malloc(size * sizeof(u16));
1532    if (!dst->symbols || !dst->num_bits || !dst->new_state_base) {
1533        BAD_ALLOC();
1534    }
1535
1536    memcpy(dst->symbols, src->symbols, size);
1537    memcpy(dst->num_bits, src->num_bits, size);
1538    memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16));
1539}
1540
1541/// A dictionary acts as initializing values for the frame context before
1542/// decompression, so we implement it by applying it's predetermined
1543/// tables and content to the context before beginning decompression
1544static void frame_context_apply_dict(frame_context_t *const ctx,
1545                                     const dictionary_t *const dict) {
1546    // If the content pointer is NULL then it must be an empty dict
1547    if (!dict || !dict->content)
1548        return;
1549
1550    // If the requested dictionary_id is non-zero, the correct dictionary must
1551    // be present
1552    if (ctx->header.dictionary_id != 0 &&
1553        ctx->header.dictionary_id != dict->dictionary_id) {
1554        ERROR("Wrong dictionary provided");
1555    }
1556
1557    // Copy the dict content to the context for references during sequence
1558    // execution
1559    ctx->dict_content = dict->content;
1560    ctx->dict_content_len = dict->content_size;
1561
1562    // If it's a formatted dict copy the precomputed tables in so they can
1563    // be used in the table repeat modes
1564    if (dict->dictionary_id != 0) {
1565        // Deep copy the entropy tables so they can be freed independently of
1566        // the dictionary struct
1567        HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
1568        FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
1569        FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
1570        FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
1571
1572        // Copy the repeated offsets
1573        memcpy(ctx->previous_offsets, dict->previous_offsets,
1574               sizeof(ctx->previous_offsets));
1575    }
1576}
1577
1578#else  // ZDEC_NO_DICTIONARY is defined
1579
1580static void frame_context_apply_dict(frame_context_t *const ctx,
1581                                     const dictionary_t *const dict) {
1582    (void)ctx;
1583    if (dict && dict->content) ERROR("dictionary not supported");
1584}
1585
1586#endif
1587/******* END DICTIONARY PARSING ***********************************************/
1588
1589/******* IO STREAM OPERATIONS *************************************************/
1590
1591/// Reads `num` bits from a bitstream, and updates the internal offset
1592static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
1593    if (num_bits > 64 || num_bits <= 0) {
1594        ERROR("Attempt to read an invalid number of bits");
1595    }
1596
1597    const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
1598    const size_t full_bytes = (num_bits + in->bit_offset) / 8;
1599    if (bytes > in->len) {
1600        INP_SIZE();
1601    }
1602
1603    const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
1604
1605    in->bit_offset = (num_bits + in->bit_offset) % 8;
1606    in->ptr += full_bytes;
1607    in->len -= full_bytes;
1608
1609    return result;
1610}
1611
1612/// If a non-zero number of bits have been read from the current byte, advance
1613/// the offset to the next byte
1614static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
1615    if (num_bits < 0) {
1616        ERROR("Attempting to rewind stream by a negative number of bits");
1617    }
1618
1619    // move the offset back by `num_bits` bits
1620    const int new_offset = in->bit_offset - num_bits;
1621    // determine the number of whole bytes we have to rewind, rounding up to an
1622    // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
1623    const i64 bytes = -(new_offset - 7) / 8;
1624
1625    in->ptr -= bytes;
1626    in->len += bytes;
1627    // make sure the resulting `bit_offset` is positive, as mod in C does not
1628    // convert numbers from negative to positive (e.g. -22 % 8 == -6)
1629    in->bit_offset = ((new_offset % 8) + 8) % 8;
1630}
1631
1632/// If the remaining bits in a byte will be unused, advance to the end of the
1633/// byte
1634static inline void IO_align_stream(istream_t *const in) {
1635    if (in->bit_offset != 0) {
1636        if (in->len == 0) {
1637            INP_SIZE();
1638        }
1639        in->ptr++;
1640        in->len--;
1641        in->bit_offset = 0;
1642    }
1643}
1644
1645/// Write the given byte into the output stream
1646static inline void IO_write_byte(ostream_t *const out, u8 symb) {
1647    if (out->len == 0) {
1648        OUT_SIZE();
1649    }
1650
1651    out->ptr[0] = symb;
1652    out->ptr++;
1653    out->len--;
1654}
1655
1656/// Returns the number of bytes left to be read in this stream.  The stream must
1657/// be byte aligned.
1658static inline size_t IO_istream_len(const istream_t *const in) {
1659    return in->len;
1660}
1661
1662/// Returns a pointer where `len` bytes can be read, and advances the internal
1663/// state.  The stream must be byte aligned.
1664static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
1665    if (len > in->len) {
1666        INP_SIZE();
1667    }
1668    if (in->bit_offset != 0) {
1669        ERROR("Attempting to operate on a non-byte aligned stream");
1670    }
1671    const u8 *const ptr = in->ptr;
1672    in->ptr += len;
1673    in->len -= len;
1674
1675    return ptr;
1676}
1677/// Returns a pointer to write `len` bytes to, and advances the internal state
1678static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
1679    if (len > out->len) {
1680        OUT_SIZE();
1681    }
1682    u8 *const ptr = out->ptr;
1683    out->ptr += len;
1684    out->len -= len;
1685
1686    return ptr;
1687}
1688
1689/// Advance the inner state by `len` bytes
1690static inline void IO_advance_input(istream_t *const in, size_t len) {
1691    if (len > in->len) {
1692         INP_SIZE();
1693    }
1694    if (in->bit_offset != 0) {
1695        ERROR("Attempting to operate on a non-byte aligned stream");
1696    }
1697
1698    in->ptr += len;
1699    in->len -= len;
1700}
1701
1702/// Returns an `ostream_t` constructed from the given pointer and length
1703static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
1704    return (ostream_t) { out, len };
1705}
1706
1707/// Returns an `istream_t` constructed from the given pointer and length
1708static inline istream_t IO_make_istream(const u8 *in, size_t len) {
1709    return (istream_t) { in, len, 0 };
1710}
1711
1712/// Returns an `istream_t` with the same base as `in`, and length `len`
1713/// Then, advance `in` to account for the consumed bytes
1714/// `in` must be byte aligned
1715static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
1716    // Consume `len` bytes of the parent stream
1717    const u8 *const ptr = IO_get_read_ptr(in, len);
1718
1719    // Make a substream using the pointer to those `len` bytes
1720    return IO_make_istream(ptr, len);
1721}
1722/******* END IO STREAM OPERATIONS *********************************************/
1723
1724/******* BITSTREAM OPERATIONS *************************************************/
1725/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
1726static inline u64 read_bits_LE(const u8 *src, const int num_bits,
1727                               const size_t offset) {
1728    if (num_bits > 64) {
1729        ERROR("Attempt to read an invalid number of bits");
1730    }
1731
1732    // Skip over bytes that aren't in range
1733    src += offset / 8;
1734    size_t bit_offset = offset % 8;
1735    u64 res = 0;
1736
1737    int shift = 0;
1738    int left = num_bits;
1739    while (left > 0) {
1740        u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
1741        // Read the next byte, shift it to account for the offset, and then mask
1742        // out the top part if we don't need all the bits
1743        res += (((u64)*src++ >> bit_offset) & mask) << shift;
1744        shift += 8 - bit_offset;
1745        left -= 8 - bit_offset;
1746        bit_offset = 0;
1747    }
1748
1749    return res;
1750}
1751
1752/// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
1753/// it updates `offset` to `offset - bits`, and then reads `bits` bits from
1754/// `src + offset`.  If the offset becomes negative, the extra bits at the
1755/// bottom are filled in with `0` bits instead of reading from before `src`.
1756static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
1757                                   i64 *const offset) {
1758    *offset = *offset - bits;
1759    size_t actual_off = *offset;
1760    size_t actual_bits = bits;
1761    // Don't actually read bits from before the start of src, so if `*offset <
1762    // 0` fix actual_off and actual_bits to reflect the quantity to read
1763    if (*offset < 0) {
1764        actual_bits += *offset;
1765        actual_off = 0;
1766    }
1767    u64 res = read_bits_LE(src, actual_bits, actual_off);
1768
1769    if (*offset < 0) {
1770        // Fill in the bottom "overflowed" bits with 0's
1771        res = -*offset >= 64 ? 0 : (res << -*offset);
1772    }
1773    return res;
1774}
1775/******* END BITSTREAM OPERATIONS *********************************************/
1776
1777/******* BIT COUNTING OPERATIONS **********************************************/
1778/// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
1779/// `num`, or `-1` if `num == 0`.
1780static inline int highest_set_bit(const u64 num) {
1781    for (int i = 63; i >= 0; i--) {
1782        if (((u64)1 << i) <= num) {
1783            return i;
1784        }
1785    }
1786    return -1;
1787}
1788/******* END BIT COUNTING OPERATIONS ******************************************/
1789
1790/******* HUFFMAN PRIMITIVES ***************************************************/
1791static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
1792                                   u16 *const state, const u8 *const src,
1793                                   i64 *const offset) {
1794    // Look up the symbol and number of bits to read
1795    const u8 symb = dtable->symbols[*state];
1796    const u8 bits = dtable->num_bits[*state];
1797    const u16 rest = STREAM_read_bits(src, bits, offset);
1798    // Shift `bits` bits out of the state, keeping the low order bits that
1799    // weren't necessary to determine this symbol.  Then add in the new bits
1800    // read from the stream.
1801    *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1);
1802
1803    return symb;
1804}
1805
1806static inline void HUF_init_state(const HUF_dtable *const dtable,
1807                                  u16 *const state, const u8 *const src,
1808                                  i64 *const offset) {
1809    // Read in a full `dtable->max_bits` bits to initialize the state
1810    const u8 bits = dtable->max_bits;
1811    *state = STREAM_read_bits(src, bits, offset);
1812}
1813
1814static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
1815                                     ostream_t *const out,
1816                                     istream_t *const in) {
1817    const size_t len = IO_istream_len(in);
1818    if (len == 0) {
1819        INP_SIZE();
1820    }
1821    const u8 *const src = IO_get_read_ptr(in, len);
1822
1823    // "Each bitstream must be read backward, that is starting from the end down
1824    // to the beginning. Therefore it's necessary to know the size of each
1825    // bitstream.
1826    //
1827    // It's also necessary to know exactly which bit is the latest. This is
1828    // detected by a final bit flag : the highest bit of latest byte is a
1829    // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
1830    // final-bit-flag itself is not part of the useful bitstream. Hence, the
1831    // last byte contains between 0 and 7 useful bits."
1832    const int padding = 8 - highest_set_bit(src[len - 1]);
1833
1834    // Offset starts at the end because HUF streams are read backwards
1835    i64 bit_offset = len * 8 - padding;
1836    u16 state;
1837
1838    HUF_init_state(dtable, &state, src, &bit_offset);
1839
1840    size_t symbols_written = 0;
1841    while (bit_offset > -dtable->max_bits) {
1842        // Iterate over the stream, decoding one symbol at a time
1843        IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
1844        symbols_written++;
1845    }
1846    // "The process continues up to reading the required number of symbols per
1847    // stream. If a bitstream is not entirely and exactly consumed, hence
1848    // reaching exactly its beginning position with all bits consumed, the
1849    // decoding process is considered faulty."
1850
1851    // When all symbols have been decoded, the final state value shouldn't have
1852    // any data from the stream, so it should have "read" dtable->max_bits from
1853    // before the start of `src`
1854    // Therefore `offset`, the edge to start reading new bits at, should be
1855    // dtable->max_bits before the start of the stream
1856    if (bit_offset != -dtable->max_bits) {
1857        CORRUPTION();
1858    }
1859
1860    return symbols_written;
1861}
1862
1863static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
1864                                     ostream_t *const out, istream_t *const in) {
1865    // "Compressed size is provided explicitly : in the 4-streams variant,
1866    // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
1867    // value represents the compressed size of one stream, in order. The last
1868    // stream size is deducted from total compressed size and from previously
1869    // decoded stream sizes"
1870    const size_t csize1 = IO_read_bits(in, 16);
1871    const size_t csize2 = IO_read_bits(in, 16);
1872    const size_t csize3 = IO_read_bits(in, 16);
1873
1874    istream_t in1 = IO_make_sub_istream(in, csize1);
1875    istream_t in2 = IO_make_sub_istream(in, csize2);
1876    istream_t in3 = IO_make_sub_istream(in, csize3);
1877    istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
1878
1879    size_t total_output = 0;
1880    // Decode each stream independently for simplicity
1881    // If we wanted to we could decode all 4 at the same time for speed,
1882    // utilizing more execution units
1883    total_output += HUF_decompress_1stream(dtable, out, &in1);
1884    total_output += HUF_decompress_1stream(dtable, out, &in2);
1885    total_output += HUF_decompress_1stream(dtable, out, &in3);
1886    total_output += HUF_decompress_1stream(dtable, out, &in4);
1887
1888    return total_output;
1889}
1890
1891/// Initializes a Huffman table using canonical Huffman codes
1892/// For more explanation on canonical Huffman codes see
1893/// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
1894/// Codes within a level are allocated in symbol order (i.e. smaller symbols get
1895/// earlier codes)
1896static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
1897                            const int num_symbs) {
1898    memset(table, 0, sizeof(HUF_dtable));
1899    if (num_symbs > HUF_MAX_SYMBS) {
1900        ERROR("Too many symbols for Huffman");
1901    }
1902
1903    u8 max_bits = 0;
1904    u16 rank_count[HUF_MAX_BITS + 1];
1905    memset(rank_count, 0, sizeof(rank_count));
1906
1907    // Count the number of symbols for each number of bits, and determine the
1908    // depth of the tree
1909    for (int i = 0; i < num_symbs; i++) {
1910        if (bits[i] > HUF_MAX_BITS) {
1911            ERROR("Huffman table depth too large");
1912        }
1913        max_bits = MAX(max_bits, bits[i]);
1914        rank_count[bits[i]]++;
1915    }
1916
1917    const size_t table_size = 1 << max_bits;
1918    table->max_bits = max_bits;
1919    table->symbols = malloc(table_size);
1920    table->num_bits = malloc(table_size);
1921
1922    if (!table->symbols || !table->num_bits) {
1923        free(table->symbols);
1924        free(table->num_bits);
1925        BAD_ALLOC();
1926    }
1927
1928    // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
1929    // order. Symbols with a Weight of zero are removed. Then, starting from
1930    // lowest weight, prefix codes are distributed in order."
1931
1932    u32 rank_idx[HUF_MAX_BITS + 1];
1933    // Initialize the starting codes for each rank (number of bits)
1934    rank_idx[max_bits] = 0;
1935    for (int i = max_bits; i >= 1; i--) {
1936        rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
1937        // The entire range takes the same number of bits so we can memset it
1938        memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
1939    }
1940
1941    if (rank_idx[0] != table_size) {
1942        CORRUPTION();
1943    }
1944
1945    // Allocate codes and fill in the table
1946    for (int i = 0; i < num_symbs; i++) {
1947        if (bits[i] != 0) {
1948            // Allocate a code for this symbol and set its range in the table
1949            const u16 code = rank_idx[bits[i]];
1950            // Since the code doesn't care about the bottom `max_bits - bits[i]`
1951            // bits of state, it gets a range that spans all possible values of
1952            // the lower bits
1953            const u16 len = 1 << (max_bits - bits[i]);
1954            memset(&table->symbols[code], i, len);
1955            rank_idx[bits[i]] += len;
1956        }
1957    }
1958}
1959
1960static void HUF_init_dtable_usingweights(HUF_dtable *const table,
1961                                         const u8 *const weights,
1962                                         const int num_symbs) {
1963    // +1 because the last weight is not transmitted in the header
1964    if (num_symbs + 1 > HUF_MAX_SYMBS) {
1965        ERROR("Too many symbols for Huffman");
1966    }
1967
1968    u8 bits[HUF_MAX_SYMBS];
1969
1970    u64 weight_sum = 0;
1971    for (int i = 0; i < num_symbs; i++) {
1972        // Weights are in the same range as bit count
1973        if (weights[i] > HUF_MAX_BITS) {
1974            CORRUPTION();
1975        }
1976        weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
1977    }
1978
1979    // Find the first power of 2 larger than the sum
1980    const int max_bits = highest_set_bit(weight_sum) + 1;
1981    const u64 left_over = ((u64)1 << max_bits) - weight_sum;
1982    // If the left over isn't a power of 2, the weights are invalid
1983    if (left_over & (left_over - 1)) {
1984        CORRUPTION();
1985    }
1986
1987    // left_over is used to find the last weight as it's not transmitted
1988    // by inverting 2^(weight - 1) we can determine the value of last_weight
1989    const int last_weight = highest_set_bit(left_over) + 1;
1990
1991    for (int i = 0; i < num_symbs; i++) {
1992        // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0"
1993        bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0;
1994    }
1995    bits[num_symbs] =
1996        max_bits + 1 - last_weight; // Last weight is always non-zero
1997
1998    HUF_init_dtable(table, bits, num_symbs + 1);
1999}
2000
2001static void HUF_free_dtable(HUF_dtable *const dtable) {
2002    free(dtable->symbols);
2003    free(dtable->num_bits);
2004    memset(dtable, 0, sizeof(HUF_dtable));
2005}
2006/******* END HUFFMAN PRIMITIVES ***********************************************/
2007
2008/******* FSE PRIMITIVES *******************************************************/
2009/// For more description of FSE see
2010/// https://github.com/Cyan4973/FiniteStateEntropy/
2011
2012/// Allow a symbol to be decoded without updating state
2013static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
2014                                 const u16 state) {
2015    return dtable->symbols[state];
2016}
2017
2018/// Consumes bits from the input and uses the current state to determine the
2019/// next state
2020static inline void FSE_update_state(const FSE_dtable *const dtable,
2021                                    u16 *const state, const u8 *const src,
2022                                    i64 *const offset) {
2023    const u8 bits = dtable->num_bits[*state];
2024    const u16 rest = STREAM_read_bits(src, bits, offset);
2025    *state = dtable->new_state_base[*state] + rest;
2026}
2027
2028/// Decodes a single FSE symbol and updates the offset
2029static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
2030                                   u16 *const state, const u8 *const src,
2031                                   i64 *const offset) {
2032    const u8 symb = FSE_peek_symbol(dtable, *state);
2033    FSE_update_state(dtable, state, src, offset);
2034    return symb;
2035}
2036
2037static inline void FSE_init_state(const FSE_dtable *const dtable,
2038                                  u16 *const state, const u8 *const src,
2039                                  i64 *const offset) {
2040    // Read in a full `accuracy_log` bits to initialize the state
2041    const u8 bits = dtable->accuracy_log;
2042    *state = STREAM_read_bits(src, bits, offset);
2043}
2044
2045static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
2046                                          ostream_t *const out,
2047                                          istream_t *const in) {
2048    const size_t len = IO_istream_len(in);
2049    if (len == 0) {
2050        INP_SIZE();
2051    }
2052    const u8 *const src = IO_get_read_ptr(in, len);
2053
2054    // "Each bitstream must be read backward, that is starting from the end down
2055    // to the beginning. Therefore it's necessary to know the size of each
2056    // bitstream.
2057    //
2058    // It's also necessary to know exactly which bit is the latest. This is
2059    // detected by a final bit flag : the highest bit of latest byte is a
2060    // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
2061    // final-bit-flag itself is not part of the useful bitstream. Hence, the
2062    // last byte contains between 0 and 7 useful bits."
2063    const int padding = 8 - highest_set_bit(src[len - 1]);
2064    i64 offset = len * 8 - padding;
2065
2066    u16 state1, state2;
2067    // "The first state (State1) encodes the even indexed symbols, and the
2068    // second (State2) encodes the odd indexes. State1 is initialized first, and
2069    // then State2, and they take turns decoding a single symbol and updating
2070    // their state."
2071    FSE_init_state(dtable, &state1, src, &offset);
2072    FSE_init_state(dtable, &state2, src, &offset);
2073
2074    // Decode until we overflow the stream
2075    // Since we decode in reverse order, overflowing the stream is offset going
2076    // negative
2077    size_t symbols_written = 0;
2078    while (1) {
2079        // "The number of symbols to decode is determined by tracking bitStream
2080        // overflow condition: If updating state after decoding a symbol would
2081        // require more bits than remain in the stream, it is assumed the extra
2082        // bits are 0. Then, the symbols for each of the final states are
2083        // decoded and the process is complete."
2084        IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
2085        symbols_written++;
2086        if (offset < 0) {
2087            // There's still a symbol to decode in state2
2088            IO_write_byte(out, FSE_peek_symbol(dtable, state2));
2089            symbols_written++;
2090            break;
2091        }
2092
2093        IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
2094        symbols_written++;
2095        if (offset < 0) {
2096            // There's still a symbol to decode in state1
2097            IO_write_byte(out, FSE_peek_symbol(dtable, state1));
2098            symbols_written++;
2099            break;
2100        }
2101    }
2102
2103    return symbols_written;
2104}
2105
2106static void FSE_init_dtable(FSE_dtable *const dtable,
2107                            const i16 *const norm_freqs, const int num_symbs,
2108                            const int accuracy_log) {
2109    if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
2110        ERROR("FSE accuracy too large");
2111    }
2112    if (num_symbs > FSE_MAX_SYMBS) {
2113        ERROR("Too many symbols for FSE");
2114    }
2115
2116    dtable->accuracy_log = accuracy_log;
2117
2118    const size_t size = (size_t)1 << accuracy_log;
2119    dtable->symbols = malloc(size * sizeof(u8));
2120    dtable->num_bits = malloc(size * sizeof(u8));
2121    dtable->new_state_base = malloc(size * sizeof(u16));
2122
2123    if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2124        BAD_ALLOC();
2125    }
2126
2127    // Used to determine how many bits need to be read for each state,
2128    // and where the destination range should start
2129    // Needs to be u16 because max value is 2 * max number of symbols,
2130    // which can be larger than a byte can store
2131    u16 state_desc[FSE_MAX_SYMBS];
2132
2133    // "Symbols are scanned in their natural order for "less than 1"
2134    // probabilities. Symbols with this probability are being attributed a
2135    // single cell, starting from the end of the table. These symbols define a
2136    // full state reset, reading Accuracy_Log bits."
2137    int high_threshold = size;
2138    for (int s = 0; s < num_symbs; s++) {
2139        // Scan for low probability symbols to put at the top
2140        if (norm_freqs[s] == -1) {
2141            dtable->symbols[--high_threshold] = s;
2142            state_desc[s] = 1;
2143        }
2144    }
2145
2146    // "All remaining symbols are sorted in their natural order. Starting from
2147    // symbol 0 and table position 0, each symbol gets attributed as many cells
2148    // as its probability. Cell allocation is spread, not linear."
2149    // Place the rest in the table
2150    const u16 step = (size >> 1) + (size >> 3) + 3;
2151    const u16 mask = size - 1;
2152    u16 pos = 0;
2153    for (int s = 0; s < num_symbs; s++) {
2154        if (norm_freqs[s] <= 0) {
2155            continue;
2156        }
2157
2158        state_desc[s] = norm_freqs[s];
2159
2160        for (int i = 0; i < norm_freqs[s]; i++) {
2161            // Give `norm_freqs[s]` states to symbol s
2162            dtable->symbols[pos] = s;
2163            // "A position is skipped if already occupied, typically by a "less
2164            // than 1" probability symbol."
2165            do {
2166                pos = (pos + step) & mask;
2167            } while (pos >=
2168                     high_threshold);
2169            // Note: no other collision checking is necessary as `step` is
2170            // coprime to `size`, so the cycle will visit each position exactly
2171            // once
2172        }
2173    }
2174    if (pos != 0) {
2175        CORRUPTION();
2176    }
2177
2178    // Now we can fill baseline and num bits
2179    for (size_t i = 0; i < size; i++) {
2180        u8 symbol = dtable->symbols[i];
2181        u16 next_state_desc = state_desc[symbol]++;
2182        // Fills in the table appropriately, next_state_desc increases by symbol
2183        // over time, decreasing number of bits
2184        dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
2185        // Baseline increases until the bit threshold is passed, at which point
2186        // it resets to 0
2187        dtable->new_state_base[i] =
2188            ((u16)next_state_desc << dtable->num_bits[i]) - size;
2189    }
2190}
2191
2192/// Decode an FSE header as defined in the Zstandard format specification and
2193/// use the decoded frequencies to initialize a decoding table.
2194static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
2195                                const int max_accuracy_log) {
2196    // "An FSE distribution table describes the probabilities of all symbols
2197    // from 0 to the last present one (included) on a normalized scale of 1 <<
2198    // Accuracy_Log .
2199    //
2200    // It's a bitstream which is read forward, in little-endian fashion. It's
2201    // not necessary to know its exact size, since it will be discovered and
2202    // reported by the decoding process.
2203    if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
2204        ERROR("FSE accuracy too large");
2205    }
2206
2207    // The bitstream starts by reporting on which scale it operates.
2208    // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
2209    // and match lengths is 9, and for offsets is 8. Higher values are
2210    // considered errors."
2211    const int accuracy_log = 5 + IO_read_bits(in, 4);
2212    if (accuracy_log > max_accuracy_log) {
2213        ERROR("FSE accuracy too large");
2214    }
2215
2216    // "Then follows each symbol value, from 0 to last present one. The number
2217    // of bits used by each field is variable. It depends on :
2218    //
2219    // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
2220    // and presuming 100 probabilities points have already been distributed, the
2221    // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
2222    // Therefore, it must read log2sup(156) == 8 bits.
2223    //
2224    // Value decoded : small values use 1 less bit : example : Presuming values
2225    // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
2226    // in an 8-bits field. They are used this way : first 99 values (hence from
2227    // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
2228
2229    i32 remaining = 1 << accuracy_log;
2230    i16 frequencies[FSE_MAX_SYMBS];
2231
2232    int symb = 0;
2233    while (remaining > 0 && symb < FSE_MAX_SYMBS) {
2234        // Log of the number of possible values we could read
2235        int bits = highest_set_bit(remaining + 1) + 1;
2236
2237        u16 val = IO_read_bits(in, bits);
2238
2239        // Try to mask out the lower bits to see if it qualifies for the "small
2240        // value" threshold
2241        const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
2242        const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
2243
2244        if ((val & lower_mask) < threshold) {
2245            IO_rewind_bits(in, 1);
2246            val = val & lower_mask;
2247        } else if (val > lower_mask) {
2248            val = val - threshold;
2249        }
2250
2251        // "Probability is obtained from Value decoded by following formula :
2252        // Proba = value - 1"
2253        const i16 proba = (i16)val - 1;
2254
2255        // "It means value 0 becomes negative probability -1. -1 is a special
2256        // probability, which means "less than 1". Its effect on distribution
2257        // table is described in next paragraph. For the purpose of calculating
2258        // cumulated distribution, it counts as one."
2259        remaining -= proba < 0 ? -proba : proba;
2260
2261        frequencies[symb] = proba;
2262        symb++;
2263
2264        // "When a symbol has a probability of zero, it is followed by a 2-bits
2265        // repeat flag. This repeat flag tells how many probabilities of zeroes
2266        // follow the current one. It provides a number ranging from 0 to 3. If
2267        // it is a 3, another 2-bits repeat flag follows, and so on."
2268        if (proba == 0) {
2269            // Read the next two bits to see how many more 0s
2270            int repeat = IO_read_bits(in, 2);
2271
2272            while (1) {
2273                for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
2274                    frequencies[symb++] = 0;
2275                }
2276                if (repeat == 3) {
2277                    repeat = IO_read_bits(in, 2);
2278                } else {
2279                    break;
2280                }
2281            }
2282        }
2283    }
2284    IO_align_stream(in);
2285
2286    // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
2287    // is complete. If the last symbol makes cumulated total go above 1 <<
2288    // Accuracy_Log, distribution is considered corrupted."
2289    if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
2290        CORRUPTION();
2291    }
2292
2293    // Initialize the decoding table using the determined weights
2294    FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
2295}
2296
2297static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
2298    dtable->symbols = malloc(sizeof(u8));
2299    dtable->num_bits = malloc(sizeof(u8));
2300    dtable->new_state_base = malloc(sizeof(u16));
2301
2302    if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2303        BAD_ALLOC();
2304    }
2305
2306    // This setup will always have a state of 0, always return symbol `symb`,
2307    // and never consume any bits
2308    dtable->symbols[0] = symb;
2309    dtable->num_bits[0] = 0;
2310    dtable->new_state_base[0] = 0;
2311    dtable->accuracy_log = 0;
2312}
2313
2314static void FSE_free_dtable(FSE_dtable *const dtable) {
2315    free(dtable->symbols);
2316    free(dtable->num_bits);
2317    free(dtable->new_state_base);
2318    memset(dtable, 0, sizeof(FSE_dtable));
2319}
2320/******* END FSE PRIMITIVES ***************************************************/
2321