diff options
Diffstat (limited to 'doc/educational_decoder/zstd_decompress.c')
-rw-r--r-- | doc/educational_decoder/zstd_decompress.c | 415 |
1 files changed, 180 insertions, 235 deletions
diff --git a/doc/educational_decoder/zstd_decompress.c b/doc/educational_decoder/zstd_decompress.c index 7c8d8114d401..af10db528d2a 100644 --- a/doc/educational_decoder/zstd_decompress.c +++ b/doc/educational_decoder/zstd_decompress.c @@ -14,21 +14,7 @@ #include <stdio.h> #include <stdlib.h> #include <string.h> - -/// Zstandard decompression functions. -/// `dst` must point to a space at least as large as the reconstructed output. -size_t ZSTD_decompress(void *const dst, const size_t dst_len, - const void *const src, const size_t src_len); -/// If `dict != NULL` and `dict_len >= 8`, does the same thing as -/// `ZSTD_decompress` but uses the provided dict -size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, - const void *const src, const size_t src_len, - const void *const dict, const size_t dict_len); - -/// Get the decompressed size of an input stream so memory can be allocated in -/// advance -/// Returns -1 if the size can't be determined -size_t ZSTD_get_decompressed_size(const void *const src, const size_t src_len); +#include "zstd_decompress.h" /******* UTILITY MACROS AND TYPES *********************************************/ // Max block size decompressed size is 128 KB and literal blocks can't be @@ -108,10 +94,10 @@ static inline size_t IO_istream_len(const istream_t *const in); /// Advances the stream by `len` bytes, and returns a pointer to the chunk that /// was skipped. The stream must be byte aligned. -static inline const u8 *IO_read_bytes(istream_t *const in, size_t len); +static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len); /// Advances the stream by `len` bytes, and returns a pointer to the chunk that /// was skipped so it can be written to. -static inline u8 *IO_write_bytes(ostream_t *const out, size_t len); +static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len); /// Advance the inner state by `len` bytes. The stream must be byte aligned. static inline void IO_advance_input(istream_t *const in, size_t len); @@ -307,7 +293,7 @@ typedef struct { /// The decoded contents of a dictionary so that it doesn't have to be repeated /// for each frame that uses it -typedef struct { +struct dictionary_s { // Entropy tables HUF_dtable literals_dtable; FSE_dtable ll_dtable; @@ -322,7 +308,7 @@ typedef struct { u64 previous_offsets[3]; u32 dictionary_id; -} dictionary_t; +}; /// A tuple containing the parts necessary to decode and execute a ZSTD sequence /// command @@ -367,27 +353,36 @@ static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, const sequence_command_t *const sequences, const size_t num_sequences); -// Parse a provided dictionary blob for use in decompression -static void parse_dictionary(dictionary_t *const dict, const u8 *src, - size_t src_len); -static void free_dictionary(dictionary_t *const dict); +// Copies literals and returns the total literal length that was copied +static u32 copy_literals(const size_t seq, istream_t *litstream, + ostream_t *const out); + +// Given an offset code from a sequence command (either an actual offset value +// or an index for previous offset), computes the correct offset and udpates +// the offset history +static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist); + +// Given an offset, match length, and total output, as well as the frame +// context for the dictionary, determines if the dictionary is used and +// executes the copy operation +static void execute_match_copy(frame_context_t *const ctx, size_t offset, + size_t match_length, size_t total_output, + ostream_t *const out); + /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/ size_t ZSTD_decompress(void *const dst, const size_t dst_len, const void *const src, const size_t src_len) { - return ZSTD_decompress_with_dict(dst, dst_len, src, src_len, NULL, 0); + dictionary_t* uninit_dict = create_dictionary(); + size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src, + src_len, uninit_dict); + free_dictionary(uninit_dict); + return decomp_size; } size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, const void *const src, const size_t src_len, - const void *const dict, - const size_t dict_len) { - dictionary_t parsed_dict; - memset(&parsed_dict, 0, sizeof(dictionary_t)); - // dict_len < 8 is not a valid dictionary - if (dict && dict_len > 8) { - parse_dictionary(&parsed_dict, (const u8 *)dict, dict_len); - } + dictionary_t* parsed_dict) { istream_t in = IO_make_istream(src, src_len); ostream_t out = IO_make_ostream(dst, dst_len); @@ -396,11 +391,9 @@ size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, // Multiple frames can be appended into a single file or stream. A frame is // totally independent, has a defined beginning and end, and a set of // parameters which tells the decoder how to decompress it." - while (IO_istream_len(&in) > 0) { - decode_frame(&out, &in, &parsed_dict); - } - free_dictionary(&parsed_dict); + /* this decoder assumes decompression of a single frame */ + decode_frame(&out, &in, parsed_dict); return out.ptr - (u8 *)dst; } @@ -424,30 +417,6 @@ static void decompress_data(frame_context_t *const ctx, ostream_t *const out, static void decode_frame(ostream_t *const out, istream_t *const in, const dictionary_t *const dict) { const u32 magic_number = IO_read_bits(in, 32); - - // Skippable frame - // - // "Magic_Number - // - // 4 Bytes, little-endian format. Value : 0x184D2A5?, which means any value - // from 0x184D2A50 to 0x184D2A5F. All 16 values are valid to identify a - // skippable frame." - if ((magic_number & ~0xFU) == 0x184D2A50U) { - // "Skippable frames allow the insertion of user-defined data into a - // flow of concatenated frames. Its design is pretty straightforward, - // with the sole objective to allow the decoder to quickly skip over - // user-defined data and continue decoding. - // - // Skippable frames defined in this specification are compatible with - // LZ4 ones." - const size_t frame_size = IO_read_bits(in, 32); - - // skip over frame - IO_advance_input(in, frame_size); - - return; - } - // Zstandard frame // // "Magic_Number @@ -460,8 +429,8 @@ static void decode_frame(ostream_t *const out, istream_t *const in, return; } - // not a real frame - ERROR("Invalid magic number"); + // not a real frame or a skippable frame + ERROR("Tried to decode non-ZSTD frame"); } /// Decode a frame that contains compressed data. Not all frames do as there @@ -672,8 +641,8 @@ static void decompress_data(frame_context_t *const ctx, ostream_t *const out, case 0: { // "Raw_Block - this is an uncompressed block. Block_Size is the // number of bytes to read and copy." - const u8 *const read_ptr = IO_read_bytes(in, block_len); - u8 *const write_ptr = IO_write_bytes(out, block_len); + const u8 *const read_ptr = IO_get_read_ptr(in, block_len); + u8 *const write_ptr = IO_get_write_ptr(out, block_len); // Copy the raw data into the output memcpy(write_ptr, read_ptr, block_len); @@ -685,8 +654,8 @@ static void decompress_data(frame_context_t *const ctx, ostream_t *const out, // "RLE_Block - this is a single byte, repeated N times. In which // case, Block_Size is the size to regenerate, while the // "compressed" block is just 1 byte (the byte to repeat)." - const u8 *const read_ptr = IO_read_bytes(in, 1); - u8 *const write_ptr = IO_write_bytes(out, block_len); + const u8 *const read_ptr = IO_get_read_ptr(in, 1); + u8 *const write_ptr = IO_get_write_ptr(out, block_len); // Copy `block_len` copies of `read_ptr[0]` to the output memset(write_ptr, read_ptr[0], block_len); @@ -832,13 +801,13 @@ static size_t decode_literals_simple(istream_t *const in, u8 **const literals, switch (block_type) { case 0: { // "Raw_Literals_Block - Literals are stored uncompressed." - const u8 *const read_ptr = IO_read_bytes(in, size); + const u8 *const read_ptr = IO_get_read_ptr(in, size); memcpy(*literals, read_ptr, size); break; } case 1: { // "RLE_Literals_Block - Literals consist of a single byte value repeated N times." - const u8 *const read_ptr = IO_read_bytes(in, 1); + const u8 *const read_ptr = IO_get_read_ptr(in, 1); memset(*literals, read_ptr[0], size); break; } @@ -949,7 +918,7 @@ static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) { num_symbs = header - 127; const size_t bytes = (num_symbs + 1) / 2; - const u8 *const weight_src = IO_read_bytes(in, bytes); + const u8 *const weight_src = IO_get_read_ptr(in, bytes); for (int i = 0; i < num_symbs; i++) { // "They are encoded forward, 2 @@ -1157,7 +1126,7 @@ static void decompress_sequences(frame_context_t *const ctx, istream_t *in, } const size_t len = IO_istream_len(in); - const u8 *const src = IO_read_bytes(in, len); + const u8 *const src = IO_get_read_ptr(in, len); // "After writing the last bit containing information, the compressor writes // a single 1-bit and then fills the byte with 0-7 0 bits of padding." @@ -1262,7 +1231,7 @@ static void decode_seq_table(FSE_dtable *const table, istream_t *const in, } case seq_rle: { // "RLE_Mode : it's a single code, repeated Number_of_Sequences times." - const u8 symb = IO_read_bytes(in, 1)[0]; + const u8 symb = IO_get_read_ptr(in, 1)[0]; FSE_init_dtable_rle(table, symb); break; } @@ -1303,145 +1272,146 @@ static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, for (size_t i = 0; i < num_sequences; i++) { const sequence_command_t seq = sequences[i]; - { - // If the sequence asks for more literals than are left, the - // sequence must be corrupted - if (seq.literal_length > IO_istream_len(&litstream)) { - CORRUPTION(); - } + const u32 literals_size = copy_literals(seq.literal_length, &litstream, out); + total_output += literals_size; + } - u8 *const write_ptr = IO_write_bytes(out, seq.literal_length); - const u8 *const read_ptr = - IO_read_bytes(&litstream, seq.literal_length); - // Copy literals to output - memcpy(write_ptr, read_ptr, seq.literal_length); + size_t const offset = compute_offset(seq, offset_hist); - total_output += seq.literal_length; - } + size_t const match_length = seq.match_length; - size_t offset; - - // Offsets are special, we need to handle the repeat offsets - if (seq.offset <= 3) { - // "The first 3 values define a repeated offset and we will call - // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3. - // They are sorted in recency order, with Repeated_Offset1 meaning - // 'most recent one'". - - // Use 0 indexing for the array - u32 idx = seq.offset - 1; - if (seq.literal_length == 0) { - // "There is an exception though, when current sequence's - // literals length is 0. In this case, repeated offsets are - // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2, - // Repeated_Offset2 becomes Repeated_Offset3, and - // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte." - idx++; - } + execute_match_copy(ctx, offset, match_length, total_output, out); - if (idx == 0) { - offset = offset_hist[0]; - } else { - // If idx == 3 then literal length was 0 and the offset was 3, - // as per the exception listed above - offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1; - - // If idx == 1 we don't need to modify offset_hist[2], since - // we're using the second-most recent code - if (idx > 1) { - offset_hist[2] = offset_hist[1]; - } - offset_hist[1] = offset_hist[0]; - offset_hist[0] = offset; - } - } else { - // When it's not a repeat offset: - // "if (Offset_Value > 3) offset = Offset_Value - 3;" - offset = seq.offset - 3; + total_output += match_length; + } - // Shift back history - offset_hist[2] = offset_hist[1]; - offset_hist[1] = offset_hist[0]; - offset_hist[0] = offset; - } + // Copy any leftover literals + { + size_t len = IO_istream_len(&litstream); + copy_literals(len, &litstream, out); + total_output += len; + } - size_t match_length = seq.match_length; + ctx->current_total_output = total_output; +} - u8 *write_ptr = IO_write_bytes(out, match_length); - if (total_output <= ctx->header.window_size) { - // In this case offset might go back into the dictionary - if (offset > total_output + ctx->dict_content_len) { - // The offset goes beyond even the dictionary - CORRUPTION(); - } +static u32 copy_literals(const size_t literal_length, istream_t *litstream, + ostream_t *const out) { + // If the sequence asks for more literals than are left, the + // sequence must be corrupted + if (literal_length > IO_istream_len(litstream)) { + CORRUPTION(); + } - if (offset > total_output) { - // "The rest of the dictionary is its content. The content act - // as a "past" in front of data to compress or decompress, so it - // can be referenced in sequence commands." - const size_t dict_copy = - MIN(offset - total_output, match_length); - const size_t dict_offset = - ctx->dict_content_len - (offset - total_output); - - memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy); - write_ptr += dict_copy; - match_length -= dict_copy; - } - } else if (offset > ctx->header.window_size) { - CORRUPTION(); + u8 *const write_ptr = IO_get_write_ptr(out, literal_length); + const u8 *const read_ptr = + IO_get_read_ptr(litstream, literal_length); + // Copy literals to output + memcpy(write_ptr, read_ptr, literal_length); + + return literal_length; +} + +static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) { + size_t offset; + // Offsets are special, we need to handle the repeat offsets + if (seq.offset <= 3) { + // "The first 3 values define a repeated offset and we will call + // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3. + // They are sorted in recency order, with Repeated_Offset1 meaning + // 'most recent one'". + + // Use 0 indexing for the array + u32 idx = seq.offset - 1; + if (seq.literal_length == 0) { + // "There is an exception though, when current sequence's + // literals length is 0. In this case, repeated offsets are + // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2, + // Repeated_Offset2 becomes Repeated_Offset3, and + // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte." + idx++; } - // We must copy byte by byte because the match length might be larger - // than the offset - // ex: if the output so far was "abc", a command with offset=3 and - // match_length=6 would produce "abcabcabc" as the new output - for (size_t i = 0; i < match_length; i++) { - *write_ptr = *(write_ptr - offset); - write_ptr++; + if (idx == 0) { + offset = offset_hist[0]; + } else { + // If idx == 3 then literal length was 0 and the offset was 3, + // as per the exception listed above + offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1; + + // If idx == 1 we don't need to modify offset_hist[2], since + // we're using the second-most recent code + if (idx > 1) { + offset_hist[2] = offset_hist[1]; + } + offset_hist[1] = offset_hist[0]; + offset_hist[0] = offset; } + } else { + // When it's not a repeat offset: + // "if (Offset_Value > 3) offset = Offset_Value - 3;" + offset = seq.offset - 3; - total_output += seq.match_length; + // Shift back history + offset_hist[2] = offset_hist[1]; + offset_hist[1] = offset_hist[0]; + offset_hist[0] = offset; } + return offset; +} - // Copy any leftover literals - { - size_t len = IO_istream_len(&litstream); - u8 *const write_ptr = IO_write_bytes(out, len); - const u8 *const read_ptr = IO_read_bytes(&litstream, len); - memcpy(write_ptr, read_ptr, len); +static void execute_match_copy(frame_context_t *const ctx, size_t offset, + size_t match_length, size_t total_output, + ostream_t *const out) { + u8 *write_ptr = IO_get_write_ptr(out, match_length); + if (total_output <= ctx->header.window_size) { + // In this case offset might go back into the dictionary + if (offset > total_output + ctx->dict_content_len) { + // The offset goes beyond even the dictionary + CORRUPTION(); + } - total_output += len; + if (offset > total_output) { + // "The rest of the dictionary is its content. The content act + // as a "past" in front of data to compress or decompress, so it + // can be referenced in sequence commands." + const size_t dict_copy = + MIN(offset - total_output, match_length); + const size_t dict_offset = + ctx->dict_content_len - (offset - total_output); + + memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy); + write_ptr += dict_copy; + match_length -= dict_copy; + } + } else if (offset > ctx->header.window_size) { + CORRUPTION(); } - ctx->current_total_output = total_output; + // We must copy byte by byte because the match length might be larger + // than the offset + // ex: if the output so far was "abc", a command with offset=3 and + // match_length=6 would produce "abcabcabc" as the new output + for (size_t j = 0; j < match_length; j++) { + *write_ptr = *(write_ptr - offset); + write_ptr++; + } } /******* END SEQUENCE EXECUTION ***********************************************/ /******* OUTPUT SIZE COUNTING *************************************************/ -static void traverse_frame(const frame_header_t *const header, istream_t *const in); - /// Get the decompressed size of an input stream so memory can be allocated in /// advance. -/// This is more complex than the implementation in the reference -/// implementation, as this API allows for the decompression of multiple -/// concatenated frames. +/// This implementation assumes `src` points to a single ZSTD-compressed frame size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { istream_t in = IO_make_istream(src, src_len); - size_t dst_size = 0; - // Each frame header only gives us the size of its frame, so iterate over - // all - // frames - while (IO_istream_len(&in) > 0) { + // get decompressed size from ZSTD frame header + { const u32 magic_number = IO_read_bits(&in, 32); - if ((magic_number & ~0xFU) == 0x184D2A50U) { - // skippable frame, this has no impact on output size - const size_t frame_size = IO_read_bits(&in, 32); - IO_advance_input(&in, frame_size); - } else if (magic_number == 0xFD2FB528U) { + if (magic_number == 0xFD2FB528U) { // ZSTD frame frame_header_t header; parse_frame_header(&header, &in); @@ -1451,68 +1421,42 @@ size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { return -1; } - dst_size += header.frame_content_size; - - // Consume the input from the frame to reach the start of the next - traverse_frame(&header, &in); + return header.frame_content_size; } else { - // not a real frame - ERROR("Invalid magic number"); + // not a real frame or skippable frame + ERROR("ZSTD frame magic number did not match"); } } - - return dst_size; } +/******* END OUTPUT SIZE COUNTING *********************************************/ -/// Iterate over each block in a frame to find the end of it, to get to the -/// start of the next frame -static void traverse_frame(const frame_header_t *const header, istream_t *const in) { - int last_block = 0; - - do { - // Parse the block header - last_block = IO_read_bits(in, 1); - const int block_type = IO_read_bits(in, 2); - const size_t block_len = IO_read_bits(in, 21); - - switch (block_type) { - case 0: // Raw block, block_len bytes - IO_advance_input(in, block_len); - break; - case 1: // RLE block, 1 byte - IO_advance_input(in, 1); - break; - case 2: // Compressed block, compressed size is block_len - IO_advance_input(in, block_len); - break; - case 3: - // Reserved block type - CORRUPTION(); - break; - default: - IMPOSSIBLE(); - } - } while (!last_block); +/******* DICTIONARY PARSING ***************************************************/ +#define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes") +#define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src"); - if (header->content_checksum_flag) { - IO_advance_input(in, 4); +dictionary_t* create_dictionary() { + dictionary_t* dict = calloc(1, sizeof(dictionary_t)); + if (!dict) { + BAD_ALLOC(); } + return dict; } -/******* END OUTPUT SIZE COUNTING *********************************************/ - -/******* DICTIONARY PARSING ***************************************************/ static void init_dictionary_content(dictionary_t *const dict, istream_t *const in); -static void parse_dictionary(dictionary_t *const dict, const u8 *src, +void parse_dictionary(dictionary_t *const dict, const void *src, size_t src_len) { + const u8 *byte_src = (const u8 *)src; memset(dict, 0, sizeof(dictionary_t)); + if (src == NULL) { /* cannot initialize dictionary with null src */ + NULL_SRC(); + } if (src_len < 8) { - INP_SIZE(); + DICT_SIZE_ERROR(); } - istream_t in = IO_make_istream(src, src_len); + istream_t in = IO_make_istream(byte_src, src_len); const u32 magic_number = IO_read_bits(&in, 32); if (magic_number != 0xEC30A437) { @@ -1564,13 +1508,13 @@ static void init_dictionary_content(dictionary_t *const dict, BAD_ALLOC(); } - const u8 *const content = IO_read_bytes(in, dict->content_size); + const u8 *const content = IO_get_read_ptr(in, dict->content_size); memcpy(dict->content, content, dict->content_size); } /// Free an allocated dictionary -static void free_dictionary(dictionary_t *const dict) { +void free_dictionary(dictionary_t *const dict) { HUF_free_dtable(&dict->literals_dtable); FSE_free_dtable(&dict->ll_dtable); FSE_free_dtable(&dict->of_dtable); @@ -1579,6 +1523,8 @@ static void free_dictionary(dictionary_t *const dict) { free(dict->content); memset(dict, 0, sizeof(dictionary_t)); + + free(dict); } /******* END DICTIONARY PARSING ***********************************************/ @@ -1657,7 +1603,7 @@ static inline size_t IO_istream_len(const istream_t *const in) { /// Returns a pointer where `len` bytes can be read, and advances the internal /// state. The stream must be byte aligned. -static inline const u8 *IO_read_bytes(istream_t *const in, size_t len) { +static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) { if (len > in->len) { INP_SIZE(); } @@ -1671,7 +1617,7 @@ static inline const u8 *IO_read_bytes(istream_t *const in, size_t len) { return ptr; } /// Returns a pointer to write `len` bytes to, and advances the internal state -static inline u8 *IO_write_bytes(ostream_t *const out, size_t len) { +static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) { if (len > out->len) { OUT_SIZE(); } @@ -1710,7 +1656,7 @@ static inline istream_t IO_make_istream(const u8 *in, size_t len) { /// `in` must be byte aligned static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) { // Consume `len` bytes of the parent stream - const u8 *const ptr = IO_read_bytes(in, len); + const u8 *const ptr = IO_get_read_ptr(in, len); // Make a substream using the pointer to those `len` bytes return IO_make_istream(ptr, len); @@ -1814,7 +1760,7 @@ static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, if (len == 0) { INP_SIZE(); } - const u8 *const src = IO_read_bytes(in, len); + const u8 *const src = IO_get_read_ptr(in, len); // "Each bitstream must be read backward, that is starting from the end down // to the beginning. Therefore it's necessary to know the size of each @@ -2065,7 +2011,7 @@ static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, if (len == 0) { INP_SIZE(); } - const u8 *const src = IO_read_bytes(in, len); + const u8 *const src = IO_get_read_ptr(in, len); // "Each bitstream must be read backward, that is starting from the end down // to the beginning. Therefore it's necessary to know the size of each @@ -2192,7 +2138,7 @@ static void FSE_init_dtable(FSE_dtable *const dtable, } // Now we can fill baseline and num bits - for (int i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { u8 symbol = dtable->symbols[i]; u16 next_state_desc = state_desc[symbol]++; // Fills in the table appropriately, next_state_desc increases by symbol @@ -2355,4 +2301,3 @@ static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16)); } /******* END FSE PRIMITIVES ***************************************************/ - |