/*
 * Copyright (c) Yann Collet, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
 */



/* **************************************
*  Compiler Warnings
****************************************/
#ifdef _MSC_VER
#  pragma warning(disable : 4127)    /* disable: C4127: conditional expression is constant */
#endif


/*-*************************************
*  Includes
***************************************/
#include "platform.h"       /* Large Files support */
#include "util.h"           /* UTIL_getFileSize, UTIL_getTotalFileSize */
#include <stdlib.h>         /* malloc, free */
#include <string.h>         /* memset */
#include <stdio.h>          /* fprintf, fopen, ftello64 */
#include <errno.h>          /* errno */
#include <assert.h>

#include "timefn.h"         /* UTIL_time_t, UTIL_clockSpanMicro, UTIL_getTime */
#include "../lib/common/mem.h"  /* read */
#include "dibio.h"


/*-*************************************
*  Constants
***************************************/
#define KB *(1 <<10)
#define MB *(1 <<20)
#define GB *(1U<<30)

#define SAMPLESIZE_MAX (128 KB)
#define MEMMULT 11    /* rough estimation : memory cost to analyze 1 byte of sample */
#define COVER_MEMMULT 9    /* rough estimation : memory cost to analyze 1 byte of sample */
#define FASTCOVER_MEMMULT 1    /* rough estimation : memory cost to analyze 1 byte of sample */
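/* g_maxMemory :
 * ceiling for DiB_findMaxMem() probing :
 * 2 GB - 64 MB on 32-bit builds ; (512 MB << 8) = 128 GB when sizeof(size_t)==8 */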
static const size_t g_maxMemory = (sizeof(size_t) == 4) ? (2 GB - 64 MB) : ((size_t)(512 MB) << sizeof(size_t));

#define NOISELENGTH 32
#define MAX_SAMPLES_SIZE (2 GB) /* training dataset limited to 2GB */


/*-*************************************
*  Console display
***************************************/
#define DISPLAY(...)         fprintf(stderr, __VA_ARGS__)
#define DISPLAYLEVEL(l, ...) if (displayLevel>=l) { DISPLAY(__VA_ARGS__); }

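/* DISPLAYUPDATE() refreshes progress at most ~6 times per second,
 * except at displayLevel >= 4 where every message is printed and flushed */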
static const U64 g_refreshRate = SEC_TO_MICRO / 6;
static UTIL_time_t g_displayClock = UTIL_TIME_INITIALIZER;

#define DISPLAYUPDATE(l, ...) { if (displayLevel>=l) { \
            if ((UTIL_clockSpanMicro(g_displayClock) > g_refreshRate) || (displayLevel>=4)) \
            { g_displayClock = UTIL_getTime(); DISPLAY(__VA_ARGS__); \
            if (displayLevel>=4) fflush(stderr); } } }

/*-*************************************
*  Exceptions
***************************************/
#ifndef DEBUG
#  define DEBUG 0
#endif
#define DEBUGOUTPUT(...) if (DEBUG) DISPLAY(__VA_ARGS__);
#define EXM_THROW(error, ...)                                             \
{                                                                         \
    DEBUGOUTPUT("Error defined at %s, line %i : \n", __FILE__, __LINE__); \
    DISPLAY("Error %i : ", error);                                        \
    DISPLAY(__VA_ARGS__);                                                 \
    DISPLAY("\n");                                                        \
    exit(error);                                                          \
}


/* ********************************************************
*  Helper functions
**********************************************************/
#undef MIN
#define MIN(a,b)    ((a) < (b) ? (a) : (b))

/**
 *  Returns the size of a file, or -1 on error.
 */
static S64 DiB_getFileSize (const char * fileName)
{
    U64 const fileSize = UTIL_getFileSize(fileName);
    return (fileSize == UTIL_FILESIZE_UNKNOWN) ? -1 : (S64)fileSize;
}

/* ********************************************************
*  File related operations
**********************************************************/
/** DiB_loadFiles() :
 *  load samples from files listed in fileNamesTable into buffer.
 *  works even if buffer is too small to load all samples.
 *  Also provides the size of each sample into sampleSizes table
 *  which must be sized correctly, using DiB_fileStats().
 * @return : nb of samples effectively loaded into `buffer`
 * *bufferSizePtr is modified; it provides the amount of data loaded within buffer.
 *  sampleSizes is filled with the size of each sample.
 */
static int DiB_loadFiles(
    void* buffer, size_t* bufferSizePtr,
    size_t* sampleSizes, int sstSize,
    const char** fileNamesTable, int nbFiles,
    size_t targetChunkSize, int displayLevel )
{
    char* const buff = (char*)buffer;
    size_t totalDataLoaded = 0;
    int nbSamplesLoaded = 0;
    int fileIndex = 0;
    FILE * f = NULL;

    assert(targetChunkSize <= SAMPLESIZE_MAX);

    while ( nbSamplesLoaded < sstSize && fileIndex < nbFiles ) {
        size_t fileDataLoaded;
        S64 const fileSize = DiB_getFileSize(fileNamesTable[fileIndex]);
        if (fileSize <= 0) {   /* skip if zero-size or file error */
            fileIndex += 1;    /* move on to the next file, otherwise this loop never terminates */
            continue;
        }

        f = fopen( fileNamesTable[fileIndex], "rb");
        if (f == NULL)
            EXM_THROW(10, "zstd: dictBuilder: %s %s ", fileNamesTable[fileIndex], strerror(errno));
        DISPLAYUPDATE(2, "Loading %s...       \r", fileNamesTable[fileIndex]);

        /* Load the first chunk of data from the file */
        fileDataLoaded = targetChunkSize > 0 ?
                            (size_t)MIN(fileSize, (S64)targetChunkSize) :
                            (size_t)MIN(fileSize, SAMPLESIZE_MAX );
        if (totalDataLoaded + fileDataLoaded > *bufferSizePtr)
            break;
        if (fread( buff+totalDataLoaded, 1, fileDataLoaded, f ) != fileDataLoaded)
            EXM_THROW(11, "Pb reading %s", fileNamesTable[fileIndex]);
        sampleSizes[nbSamplesLoaded++] = fileDataLoaded;
        totalDataLoaded += fileDataLoaded;

        /* If file-chunking is enabled, load the rest of the file as more samples */
        if (targetChunkSize > 0) {
            while( (S64)fileDataLoaded < fileSize && nbSamplesLoaded < sstSize ) {
                size_t const chunkSize = MIN((size_t)(fileSize-fileDataLoaded), targetChunkSize);
                if (totalDataLoaded + chunkSize > *bufferSizePtr) /* buffer is full */
                    break;

                if (fread( buff+totalDataLoaded, 1, chunkSize, f ) != chunkSize)
                    EXM_THROW(11, "Pb reading %s", fileNamesTable[fileIndex]);
                sampleSizes[nbSamplesLoaded++] = chunkSize;
                totalDataLoaded += chunkSize;
                fileDataLoaded += chunkSize;
            }
        }
        fileIndex += 1;
        fclose(f); f = NULL;
    }
    if (f != NULL)
        fclose(f);

    DISPLAYLEVEL(2, "\r%79s\r", "");
    DISPLAYLEVEL(4, "Loaded %d KB total training data, %d nb samples \n",
        (int)(totalDataLoaded / (1 KB)), nbSamplesLoaded );
    *bufferSizePtr = totalDataLoaded;
    return nbSamplesLoaded;
}

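/* DiB_rand() :
 * simple multiply/xor/rotate pseudo-random generator ;
 * updates *src and returns a pseudo-random 27-bit value.
 * Only used by DiB_shuffle() : quality requirements are modest. */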
#define DiB_rotl32(x,r) ((x << r) | (x >> (32 - r)))
static U32 DiB_rand(U32* src)
{
    static const U32 prime1 = 2654435761U;
    static const U32 prime2 = 2246822519U;
    U32 rand32 = *src;
    rand32 *= prime1;
    rand32 ^= prime2;
    rand32  = DiB_rotl32(rand32, 13);
    *src = rand32;
    return rand32 >> 5;
}

/* DiB_shuffle() :
 * shuffles a table of file names in a semi-random way.
 * It improves dictionary quality by reducing "locality" impact :
 * when the sample set is larger than what can be loaded in memory,
 * random elements are picked from it, instead of just the first ones. */
static void DiB_shuffle(const char** fileNamesTable, unsigned nbFiles) {
    U32 seed = 0xFD2FB528;
    unsigned i;
    assert(nbFiles >= 1);
    for (i = nbFiles - 1; i > 0; --i) {
        unsigned const j = DiB_rand(&seed) % (i + 1);
        const char* const tmp = fileNamesTable[j];
        fileNamesTable[j] = fileNamesTable[i];
        fileNamesTable[i] = tmp;
    }
}


/*-********************************************************
*  Dictionary training functions
**********************************************************/
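/* DiB_findMaxMem() :
 * probes how much memory can effectively be allocated :
 * starts slightly above requiredMem (rounded up to an 8 MB multiple, capped at g_maxMemory),
 * then decreases by 8 MB steps until malloc() succeeds ;
 * returns the first successful allocation size minus one extra step, as a safety margin */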
static size_t DiB_findMaxMem(unsigned long long requiredMem)
{
    size_t const step = 8 MB;
    void* testmem = NULL;

    requiredMem = (((requiredMem >> 23) + 1) << 23);
    requiredMem += step;
    if (requiredMem > g_maxMemory) requiredMem = g_maxMemory;

    while (!testmem) {
        testmem = malloc((size_t)requiredMem);
        requiredMem -= step;
    }

    free(testmem);
    return (size_t)requiredMem;
}


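/* DiB_fillNoise() :
 * fills `buffer` with `length` bytes of deterministic pseudo-random noise.
 * Used to pad the end of the sample buffer (guard band) for the legacy trainer. */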
static void DiB_fillNoise(void* buffer, size_t length)
{
    unsigned const prime1 = 2654435761U;
    unsigned const prime2 = 2246822519U;
    unsigned acc = prime1;
    size_t p=0;

    for (p=0; p<length; p++) {
        acc *= prime2;
        ((unsigned char*)buffer)[p] = (unsigned char)(acc >> 21);
    }
}


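/* DiB_saveDict() :
 * writes dictionary content `buff` of size `buffSize` into file `dictFileName` ;
 * exits the program (EXM_THROW) on any I/O error */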
static void DiB_saveDict(const char* dictFileName,
                         const void* buff, size_t buffSize)
{
    FILE* const f = fopen(dictFileName, "wb");
    if (f==NULL) EXM_THROW(3, "cannot open %s ", dictFileName);

    { size_t const n = fwrite(buff, 1, buffSize, f);
      if (n!=buffSize) EXM_THROW(4, "%s : write error", dictFileName) }

    { size_t const n = (size_t)fclose(f);
      if (n!=0) EXM_THROW(5, "%s : flush error", dictFileName) }
}

typedef struct {
    S64 totalSizeToLoad;
    int nbSamples;
    int oneSampleTooLarge;
} fileStats;

/*! DiB_fileStats() :
 *  Given a list of files, and a chunkSize (0 == no chunk, whole files),
 *  provides the amount of data to be loaded and the resulting nb of samples.
 *  This is useful primarily for allocation purposes : sample buffer, and sample sizes table.
 */
static fileStats DiB_fileStats(const char** fileNamesTable, int nbFiles, size_t chunkSize, int displayLevel)
{
    fileStats fs;
    int n;
    memset(&fs, 0, sizeof(fs));

    /* We assume that if chunking is requested, the chunk size is <= SAMPLESIZE_MAX */
    assert( chunkSize <= SAMPLESIZE_MAX );

    for (n=0; n<nbFiles; n++) {
      S64 const fileSize = DiB_getFileSize(fileNamesTable[n]);
      /* TODO: is there a minimum sample size? What if the file is 1-byte? */
      if (fileSize <= 0) {   /* skip if zero-size or file error */
        DISPLAYLEVEL(3, "Sample file '%s' has zero size or cannot be read, skipping...\n", fileNamesTable[n]);
        continue;
      }

      /* the case where we are breaking up files in sample chunks */
      if (chunkSize > 0)
      {
        /* TODO: is there a minimum sample size? Can we have a 1-byte sample? */
        fs.nbSamples += (int)((fileSize + chunkSize-1) / chunkSize);   /* round up : the last chunk may be shorter */
        fs.totalSizeToLoad += fileSize;
      }
      else {
      /* the case where one file is one sample */
        if (fileSize > SAMPLESIZE_MAX) {
          /* flag excessively large sample files */
          fs.oneSampleTooLarge |= (fileSize > 2*SAMPLESIZE_MAX);

          /* Limit to the first SAMPLESIZE_MAX (128kB) of the file */
          DISPLAYLEVEL(3, "Sample file '%s' is too large, limiting to %d KB \n",
              fileNamesTable[n], SAMPLESIZE_MAX / (1 KB));
        }
        fs.nbSamples += 1;
        fs.totalSizeToLoad += MIN(fileSize, SAMPLESIZE_MAX);
      }
    }
    DISPLAYLEVEL(4, "Found training data %d files, %d KB, %d samples\n", nbFiles, (int)(fs.totalSizeToLoad / (1 KB)), fs.nbSamples);
    return fs;
}

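/* DiB_trainFromFiles() :
 * trains a dictionary of maximum size maxDictSize from the samples listed in fileNamesTable,
 * using the legacy, cover or fastCover trainer depending on which params pointer is non-NULL,
 * then saves the result into dictFileName.
 * @return : 0 on success, 1 if dictionary training failed ;
 *           exits the program (EXM_THROW) on fatal errors (memory, I/O, too few samples) */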
int DiB_trainFromFiles(const char* dictFileName, size_t maxDictSize,
                       const char** fileNamesTable, int nbFiles, size_t chunkSize,
                       ZDICT_legacy_params_t* params, ZDICT_cover_params_t* coverParams,
                       ZDICT_fastCover_params_t* fastCoverParams, int optimize, unsigned memLimit)
{
    fileStats fs;
    size_t* sampleSizes;   /* vector of sample sizes. Each sample can be up to SAMPLESIZE_MAX */
    int nbSamplesLoaded;   /* nb of samples effectively loaded in srcBuffer */
    size_t loadedSize;     /* total data loaded in srcBuffer for all samples */
    void* srcBuffer;       /* contiguous buffer with training data/samples */
    void* const dictBuffer = malloc(maxDictSize);
    int result = 0;

    int const displayLevel = params ? params->zParams.notificationLevel :
        coverParams ? coverParams->zParams.notificationLevel :
        fastCoverParams ? fastCoverParams->zParams.notificationLevel : 0;

    /* Shuffle input files before we start assessing how much sample data to load.
       The purpose of the shuffle is to pick random samples when the sample
       set is larger than what we can load in memory. */
    DISPLAYLEVEL(3, "Shuffling input files\n");
    DiB_shuffle(fileNamesTable, nbFiles);

    /* Figure out how much sample data to load with how many samples */
    fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel);

    {
        int const memMult = params ? MEMMULT :
                            coverParams ? COVER_MEMMULT :
                            FASTCOVER_MEMMULT;
        size_t const maxMem = DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult;
        /* Limit the size of the training data to the free memory */
        /* Limit the size of the training data to 2GB */
        /* TODO: there is opportunity to stop DiB_fileStats() early when the data limit is reached */
        loadedSize = (size_t)MIN( MIN((S64)maxMem, fs.totalSizeToLoad), MAX_SAMPLES_SIZE );
        if (memLimit != 0) {
            DISPLAYLEVEL(2, "!  Warning : setting manual memory limit for dictionary training data at %u MB \n",
                (unsigned)(memLimit / (1 MB)));
            loadedSize = (size_t)MIN(loadedSize, memLimit);
        }
        srcBuffer = malloc(loadedSize+NOISELENGTH);
        sampleSizes = (size_t*)malloc(fs.nbSamples * sizeof(size_t));
    }

    /* Checks */
    if ((!sampleSizes) || (!srcBuffer) || (!dictBuffer))
        EXM_THROW(12, "not enough memory for DiB_trainFiles");   /* should not happen */
    if (fs.oneSampleTooLarge) {
        DISPLAYLEVEL(2, "!  Warning : some sample(s) are very large \n");
        DISPLAYLEVEL(2, "!  Note that dictionary is only useful for small samples. \n");
        DISPLAYLEVEL(2, "!  As a consequence, only the first %u bytes of each sample are loaded \n", (unsigned)SAMPLESIZE_MAX);
    }
    if (fs.nbSamples < 5) {
        DISPLAYLEVEL(2, "!  Warning : nb of samples too low for proper processing ! \n");
        DISPLAYLEVEL(2, "!  Please provide _one file per sample_. \n");
        DISPLAYLEVEL(2, "!  Alternatively, split files into fixed-size blocks representative of samples, with -B# \n");
        EXM_THROW(14, "nb of samples too low");   /* we now clearly forbid this case */
    }
    if (fs.totalSizeToLoad < (S64)maxDictSize * 8) {
        DISPLAYLEVEL(2, "!  Warning : data size of samples too small for target dictionary size \n");
        DISPLAYLEVEL(2, "!  Samples should be about 100x larger than target dictionary size \n");
    }

    /* init */
    if ((S64)loadedSize < fs.totalSizeToLoad)
        DISPLAYLEVEL(1, "Training samples set too large (%u MB); training on %u MB only...\n",
            (unsigned)(fs.totalSizeToLoad / (1 MB)),
            (unsigned)(loadedSize / (1 MB)));

    /* Load input buffer */
    nbSamplesLoaded = DiB_loadFiles(
        srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable,
        nbFiles, chunkSize, displayLevel);

    {   size_t dictSize;
        if (params) {
            DiB_fillNoise((char*)srcBuffer + loadedSize, NOISELENGTH);   /* guard band, for end of buffer condition */
            dictSize = ZDICT_trainFromBuffer_legacy(dictBuffer, maxDictSize,
                                                    srcBuffer, sampleSizes, nbSamplesLoaded,
                                                    *params);
        } else if (coverParams) {
            if (optimize) {
              dictSize = ZDICT_optimizeTrainFromBuffer_cover(dictBuffer, maxDictSize,
                                                             srcBuffer, sampleSizes, nbSamplesLoaded,
                                                             coverParams);
              if (!ZDICT_isError(dictSize)) {
                  unsigned splitPercentage = (unsigned)(coverParams->splitPoint * 100);
                  DISPLAYLEVEL(2, "k=%u\nd=%u\nsteps=%u\nsplit=%u\n", coverParams->k, coverParams->d,
                              coverParams->steps, splitPercentage);
              }
            } else {
              dictSize = ZDICT_trainFromBuffer_cover(dictBuffer, maxDictSize, srcBuffer,
                                                     sampleSizes, nbSamplesLoaded, *coverParams);
            }
        } else {
            assert(fastCoverParams != NULL);
            if (optimize) {
              dictSize = ZDICT_optimizeTrainFromBuffer_fastCover(dictBuffer, maxDictSize,
                                                              srcBuffer, sampleSizes, nbSamplesLoaded,
                                                              fastCoverParams);
              if (!ZDICT_isError(dictSize)) {
                unsigned splitPercentage = (unsigned)(fastCoverParams->splitPoint * 100);
                DISPLAYLEVEL(2, "k=%u\nd=%u\nf=%u\nsteps=%u\nsplit=%u\naccel=%u\n", fastCoverParams->k,
                            fastCoverParams->d, fastCoverParams->f, fastCoverParams->steps, splitPercentage,
                            fastCoverParams->accel);
              }
            } else {
              dictSize = ZDICT_trainFromBuffer_fastCover(dictBuffer, maxDictSize, srcBuffer,
                                                        sampleSizes, nbSamplesLoaded, *fastCoverParams);
            }
        }
        if (ZDICT_isError(dictSize)) {
            DISPLAYLEVEL(1, "dictionary training failed : %s \n", ZDICT_getErrorName(dictSize));   /* should not happen */
            result = 1;
            goto _cleanup;
        }
        /* save dict */
        DISPLAYLEVEL(2, "Save dictionary of size %u into file %s \n", (unsigned)dictSize, dictFileName);
        DiB_saveDict(dictFileName, dictBuffer, dictSize);
    }

    /* clean up */
_cleanup:
    free(srcBuffer);
    free(sampleSizes);
    free(dictBuffer);
    return result;
}