Limit train samples (#2809)

* Limit training samples size to 2GB

* simplified DISPLAYLEVEL() macro to use global variable instead of local.

* refactored training samples loading

* fixed compiler warning

* addressed comments from the pull request

* addressed @terrelln comments

* missed some fixes

* fixed type mismatch

* Fixed bug passing the estimated number of samples instead of the loaded number of samples.
Changed unit conversion not to use bit-shifts.

* fixed a declaration after code

* fixed type conversion compile errors

* fixed more type casting

* fixed more type mismatching

* changed sizes type to size_t

* move type casting

* more type cast fixes
diff --git a/programs/dibio.c b/programs/dibio.c
index d6c9f6d..49fa211 100644
--- a/programs/dibio.c
+++ b/programs/dibio.c
@@ -49,6 +49,7 @@
 static const size_t g_maxMemory = (sizeof(size_t) == 4) ? (2 GB - 64 MB) : ((size_t)(512 MB) << sizeof(size_t));
 
 #define NOISELENGTH 32
+#define MAX_SAMPLES_SIZE (2 GB) /* training dataset limited to 2GB */
 
 
 /*-*************************************
@@ -88,6 +89,15 @@
 #undef MIN
 #define MIN(a,b)    ((a) < (b) ? (a) : (b))
 
+/**
+  Returns the size of a file.
+  If error returns -1.
+*/
+static S64 DiB_getFileSize (const char * fileName)
+{
+    U64 const fileSize = UTIL_getFileSize(fileName);
+    return (fileSize == UTIL_FILESIZE_UNKNOWN) ? -1 : (S64)fileSize;
+}
 
 /* ********************************************************
 *  File related operations
@@ -101,47 +111,67 @@
  * *bufferSizePtr is modified, it provides the amount data loaded within buffer.
  *  sampleSizes is filled with the size of each sample.
  */
-static unsigned DiB_loadFiles(void* buffer, size_t* bufferSizePtr,
-                              size_t* sampleSizes, unsigned sstSize,
-                              const char** fileNamesTable, unsigned nbFiles, size_t targetChunkSize,
-                              unsigned displayLevel)
+static int DiB_loadFiles(
+    void* buffer, size_t* bufferSizePtr,
+    size_t* sampleSizes, int sstSize,
+    const char** fileNamesTable, int nbFiles,
+    size_t targetChunkSize, int displayLevel )
 {
     char* const buff = (char*)buffer;
-    size_t pos = 0;
-    unsigned nbLoadedChunks = 0, fileIndex;
+    size_t totalDataLoaded = 0;
+    int nbSamplesLoaded = 0;
+    int fileIndex = 0;
+    FILE * f = NULL;
 
-    for (fileIndex=0; fileIndex<nbFiles; fileIndex++) {
-        const char* const fileName = fileNamesTable[fileIndex];
-        unsigned long long const fs64 = UTIL_getFileSize(fileName);
-        unsigned long long remainingToLoad = (fs64 == UTIL_FILESIZE_UNKNOWN) ? 0 : fs64;
-        U32 const nbChunks = targetChunkSize ? (U32)((fs64 + (targetChunkSize-1)) / targetChunkSize) : 1;
-        U64 const chunkSize = targetChunkSize ? MIN(targetChunkSize, fs64) : fs64;
-        size_t const maxChunkSize = (size_t)MIN(chunkSize, SAMPLESIZE_MAX);
-        U32 cnb;
-        FILE* const f = fopen(fileName, "rb");
-        if (f==NULL) EXM_THROW(10, "zstd: dictBuilder: %s %s ", fileName, strerror(errno));
-        DISPLAYUPDATE(2, "Loading %s...       \r", fileName);
-        for (cnb=0; cnb<nbChunks; cnb++) {
-            size_t const toLoad = (size_t)MIN(maxChunkSize, remainingToLoad);
-            if (toLoad > *bufferSizePtr-pos) break;
-            {   size_t const readSize = fread(buff+pos, 1, toLoad, f);
-                if (readSize != toLoad) EXM_THROW(11, "Pb reading %s", fileName);
-                pos += readSize;
-                sampleSizes[nbLoadedChunks++] = toLoad;
-                remainingToLoad -= targetChunkSize;
-                if (nbLoadedChunks == sstSize) { /* no more space left in sampleSizes table */
-                    fileIndex = nbFiles;  /* stop there */
+    assert(targetChunkSize <= SAMPLESIZE_MAX);
+
+    while ( nbSamplesLoaded < sstSize && fileIndex < nbFiles ) {
+        size_t fileDataLoaded;
+        S64 const fileSize = DiB_getFileSize(fileNamesTable[fileIndex]);
+        if (fileSize <= 0) /* skip if zero-size or file error */
+            continue;
+
+        f = fopen( fileNamesTable[fileIndex], "rb");
+        if (f == NULL)
+            EXM_THROW(10, "zstd: dictBuilder: %s %s ", fileNamesTable[fileIndex], strerror(errno));
+        DISPLAYUPDATE(2, "Loading %s...       \r", fileNamesTable[fileIndex]);
+
+        /* Load the first chunk of data from the file */
+        fileDataLoaded = targetChunkSize > 0 ?
+                            (size_t)MIN(fileSize, (S64)targetChunkSize) :
+                            (size_t)MIN(fileSize, SAMPLESIZE_MAX );
+        if (totalDataLoaded + fileDataLoaded > *bufferSizePtr)
+            break;
+        if (fread( buff+totalDataLoaded, 1, fileDataLoaded, f ) != fileDataLoaded)
+            EXM_THROW(11, "Pb reading %s", fileNamesTable[fileIndex]);
+        sampleSizes[nbSamplesLoaded++] = fileDataLoaded;
+        totalDataLoaded += fileDataLoaded;
+
+        /* If file-chunking is enabled, load the rest of the file as more samples */
+        if (targetChunkSize > 0) {
+            while( (S64)fileDataLoaded < fileSize && nbSamplesLoaded < sstSize ) {
+                size_t const chunkSize = MIN((size_t)(fileSize-fileDataLoaded), targetChunkSize);
+                if (totalDataLoaded + chunkSize > *bufferSizePtr) /* buffer is full */
                     break;
-                }
-                if (toLoad < targetChunkSize) {
-                    fseek(f, (long)(targetChunkSize - toLoad), SEEK_CUR);
-        }   }   }
-        fclose(f);
+
+                if (fread( buff+totalDataLoaded, 1, chunkSize, f ) != chunkSize)
+                    EXM_THROW(11, "Pb reading %s", fileNamesTable[fileIndex]);
+                sampleSizes[nbSamplesLoaded++] = chunkSize;
+                totalDataLoaded += chunkSize;
+                fileDataLoaded += chunkSize;
+            }
+        }
+        fileIndex += 1;
+        fclose(f); f = NULL;
     }
+    if (f != NULL)
+        fclose(f);
+
     DISPLAYLEVEL(2, "\r%79s\r", "");
-    *bufferSizePtr = pos;
-    DISPLAYLEVEL(4, "loaded : %u KB \n", (unsigned)(pos >> 10))
-    return nbLoadedChunks;
+    DISPLAYLEVEL(4, "Loaded %d KB total training data, %d nb samples \n",
+        (int)(totalDataLoaded / (1 KB)), nbSamplesLoaded );
+    *bufferSizePtr = totalDataLoaded;
+    return nbSamplesLoaded;
 }
 
 #define DiB_rotl32(x,r) ((x << r) | (x >> (32 - r)))
@@ -223,11 +253,10 @@
       if (n!=0) EXM_THROW(5, "%s : flush error", dictFileName) }
 }
 
-
 typedef struct {
-    U64 totalSizeToLoad;
-    unsigned oneSampleTooLarge;
-    unsigned nbSamples;
+    S64 totalSizeToLoad;
+    int nbSamples;
+    int oneSampleTooLarge;
 } fileStats;
 
 /*! DiB_fileStats() :
@@ -235,46 +264,87 @@
  *  provides the amount of data to be loaded and the resulting nb of samples.
  *  This is useful primarily for allocation purpose => sample buffer, and sample sizes table.
  */
-static fileStats DiB_fileStats(const char** fileNamesTable, unsigned nbFiles, size_t chunkSize, unsigned displayLevel)
+static fileStats DiB_fileStats(const char** fileNamesTable, int nbFiles, size_t chunkSize, int displayLevel)
 {
     fileStats fs;
-    unsigned n;
+    int n;
     memset(&fs, 0, sizeof(fs));
+
+    // We assume that if chunking is requested, the chunk size is <= SAMPLESIZE_MAX
+    assert( chunkSize <= SAMPLESIZE_MAX );
+
     for (n=0; n<nbFiles; n++) {
-        U64 const fileSize = UTIL_getFileSize(fileNamesTable[n]);
-        U64 const srcSize = (fileSize == UTIL_FILESIZE_UNKNOWN) ? 0 : fileSize;
-        U32 const nbSamples = (U32)(chunkSize ? (srcSize + (chunkSize-1)) / chunkSize : 1);
-        U64 const chunkToLoad = chunkSize ? MIN(chunkSize, srcSize) : srcSize;
-        size_t const cappedChunkSize = (size_t)MIN(chunkToLoad, SAMPLESIZE_MAX);
-        fs.totalSizeToLoad += cappedChunkSize * nbSamples;
-        fs.oneSampleTooLarge |= (chunkSize > 2*SAMPLESIZE_MAX);
-        fs.nbSamples += nbSamples;
+      S64 const fileSize = DiB_getFileSize(fileNamesTable[n]);
+      // TODO: is there a minimum sample size? What if the file is 1-byte?
+      if (fileSize == 0) {
+        DISPLAYLEVEL(3, "Sample file '%s' has zero size, skipping...\n", fileNamesTable[n]);
+        continue;
+      }
+
+      /* the case where we are breaking up files in sample chunks */
+      if (chunkSize > 0)
+      {
+        // TODO: is there a minimum sample size? Can we have a 1-byte sample?
+        fs.nbSamples += (int)((fileSize + chunkSize-1) / chunkSize);
+        fs.totalSizeToLoad += fileSize;
+      }
+      else {
+      /* the case where one file is one sample */
+        if (fileSize > SAMPLESIZE_MAX) {
+          /* flag excessively large sample files */
+          fs.oneSampleTooLarge |= (fileSize > 2*SAMPLESIZE_MAX);
+
+          /* Limit to the first SAMPLESIZE_MAX (128kB) of the file */
+          DISPLAYLEVEL(3, "Sample file '%s' is too large, limiting to %d KB",
+              fileNamesTable[n], SAMPLESIZE_MAX / (1 KB));
+        }
+        fs.nbSamples += 1;
+        fs.totalSizeToLoad += MIN(fileSize, SAMPLESIZE_MAX);
+      }
     }
-    DISPLAYLEVEL(4, "Preparing to load : %u KB \n", (unsigned)(fs.totalSizeToLoad >> 10));
+    DISPLAYLEVEL(4, "Found training data %d files, %d KB, %d samples\n", nbFiles, (int)(fs.totalSizeToLoad / (1 KB)), fs.nbSamples);
     return fs;
 }
 
-
-int DiB_trainFromFiles(const char* dictFileName, unsigned maxDictSize,
-                       const char** fileNamesTable, unsigned nbFiles, size_t chunkSize,
+int DiB_trainFromFiles(const char* dictFileName, size_t maxDictSize,
+                       const char** fileNamesTable, int nbFiles, size_t chunkSize,
                        ZDICT_legacy_params_t* params, ZDICT_cover_params_t* coverParams,
                        ZDICT_fastCover_params_t* fastCoverParams, int optimize)
 {
-    unsigned const displayLevel = params ? params->zParams.notificationLevel :
-                        coverParams ? coverParams->zParams.notificationLevel :
-                        fastCoverParams ? fastCoverParams->zParams.notificationLevel :
-                        0;   /* should never happen */
+    fileStats fs;
+    size_t* sampleSizes; /* vector of sample sizes. Each sample can be up to SAMPLESIZE_MAX */
+    int nbSamplesLoaded; /* nb of samples effectively loaded in srcBuffer */
+    size_t loadedSize; /* total data loaded in srcBuffer for all samples */
+    void* srcBuffer /* contiguous buffer with training data/samples */;
     void* const dictBuffer = malloc(maxDictSize);
-    fileStats const fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel);
-    size_t* const sampleSizes = (size_t*)malloc(fs.nbSamples * sizeof(size_t));
-    size_t const memMult = params ? MEMMULT :
-                           coverParams ? COVER_MEMMULT:
-                           FASTCOVER_MEMMULT;
-    size_t const maxMem =  DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult;
-    size_t loadedSize = (size_t) MIN ((unsigned long long)maxMem, fs.totalSizeToLoad);
-    void* const srcBuffer = malloc(loadedSize+NOISELENGTH);
     int result = 0;
 
+    int const displayLevel = params ? params->zParams.notificationLevel :
+        coverParams ? coverParams->zParams.notificationLevel :
+        fastCoverParams ? fastCoverParams->zParams.notificationLevel : 0;
+
+    /* Shuffle input files before we start assessing how much sample data to load.
+       The purpose of the shuffle is to pick random samples when the sample
+       set is larger than what we can load in memory. */
+    DISPLAYLEVEL(3, "Shuffling input files\n");
+    DiB_shuffle(fileNamesTable, nbFiles);
+
+    /* Figure out how much sample data to load with how many samples */
+    fs = DiB_fileStats(fileNamesTable, nbFiles, chunkSize, displayLevel);
+
+    {
+        int const memMult = params ? MEMMULT :
+                            coverParams ? COVER_MEMMULT:
+                            FASTCOVER_MEMMULT;
+        size_t const maxMem =  DiB_findMaxMem(fs.totalSizeToLoad * memMult) / memMult;
+        /* Limit the size of the training data to the free memory */
+        /* Limit the size of the training data to 2GB */
+        /* TODO: there is opportunity to stop DiB_fileStats() early when the data limit is reached */
+        loadedSize = (size_t)MIN( MIN((S64)maxMem, fs.totalSizeToLoad), MAX_SAMPLES_SIZE );
+        srcBuffer = malloc(loadedSize+NOISELENGTH);
+        sampleSizes = (size_t*)malloc(fs.nbSamples * sizeof(size_t));
+    }
+
     /* Checks */
     if ((!sampleSizes) || (!srcBuffer) || (!dictBuffer))
         EXM_THROW(12, "not enough memory for DiB_trainFiles");   /* should not happen */
@@ -289,31 +359,32 @@
         DISPLAYLEVEL(2, "!  Alternatively, split files into fixed-size blocks representative of samples, with -B# \n");
         EXM_THROW(14, "nb of samples too low");   /* we now clearly forbid this case */
     }
-    if (fs.totalSizeToLoad < (unsigned long long)maxDictSize * 8) {
+    if (fs.totalSizeToLoad < (S64)maxDictSize * 8) {
         DISPLAYLEVEL(2, "!  Warning : data size of samples too small for target dictionary size \n");
         DISPLAYLEVEL(2, "!  Samples should be about 100x larger than target dictionary size \n");
     }
 
     /* init */
-    if (loadedSize < fs.totalSizeToLoad)
-        DISPLAYLEVEL(1, "Not enough memory; training on %u MB only...\n", (unsigned)(loadedSize >> 20));
+    if ((S64)loadedSize < fs.totalSizeToLoad)
+        DISPLAYLEVEL(1, "Training samples set too large (%u MB); training on %u MB only...\n",
+            (unsigned)(fs.totalSizeToLoad / (1 MB)),
+            (unsigned)(loadedSize / (1 MB)));
 
     /* Load input buffer */
-    DISPLAYLEVEL(3, "Shuffling input files\n");
-    DiB_shuffle(fileNamesTable, nbFiles);
-
-    DiB_loadFiles(srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable, nbFiles, chunkSize, displayLevel);
+    nbSamplesLoaded = DiB_loadFiles(
+        srcBuffer, &loadedSize, sampleSizes, fs.nbSamples, fileNamesTable,
+        nbFiles, chunkSize, displayLevel);
 
     {   size_t dictSize;
         if (params) {
             DiB_fillNoise((char*)srcBuffer + loadedSize, NOISELENGTH);   /* guard band, for end of buffer condition */
             dictSize = ZDICT_trainFromBuffer_legacy(dictBuffer, maxDictSize,
-                                                    srcBuffer, sampleSizes, fs.nbSamples,
+                                                    srcBuffer, sampleSizes, nbSamplesLoaded,
                                                     *params);
         } else if (coverParams) {
             if (optimize) {
               dictSize = ZDICT_optimizeTrainFromBuffer_cover(dictBuffer, maxDictSize,
-                                                             srcBuffer, sampleSizes, fs.nbSamples,
+                                                             srcBuffer, sampleSizes, nbSamplesLoaded,
                                                              coverParams);
               if (!ZDICT_isError(dictSize)) {
                   unsigned splitPercentage = (unsigned)(coverParams->splitPoint * 100);
@@ -322,13 +393,13 @@
               }
             } else {
               dictSize = ZDICT_trainFromBuffer_cover(dictBuffer, maxDictSize, srcBuffer,
-                                                     sampleSizes, fs.nbSamples, *coverParams);
+                                                     sampleSizes, nbSamplesLoaded, *coverParams);
             }
         } else {
             assert(fastCoverParams != NULL);
             if (optimize) {
               dictSize = ZDICT_optimizeTrainFromBuffer_fastCover(dictBuffer, maxDictSize,
-                                                              srcBuffer, sampleSizes, fs.nbSamples,
+                                                              srcBuffer, sampleSizes, nbSamplesLoaded,
                                                               fastCoverParams);
               if (!ZDICT_isError(dictSize)) {
                 unsigned splitPercentage = (unsigned)(fastCoverParams->splitPoint * 100);
@@ -338,7 +409,7 @@
               }
             } else {
               dictSize = ZDICT_trainFromBuffer_fastCover(dictBuffer, maxDictSize, srcBuffer,
-                                                        sampleSizes, fs.nbSamples, *fastCoverParams);
+                                                        sampleSizes, nbSamplesLoaded, *fastCoverParams);
             }
         }
         if (ZDICT_isError(dictSize)) {