From b6b8aff1557bcb4c75b54fdd97d83087d079a2ec Mon Sep 17 00:00:00 2001
From: Kim Albertsson <ketost@gmail.com>
Date: Fri, 2 Nov 2018 15:51:10 +0100
Subject: [PATCH] [TMVA] CV Stratified -- Add test for stratified splitting

---
 .../TestCrossValidationSplitting.cxx          | 81 +++++++++++++++++++
 1 file changed, 81 insertions(+)

diff --git a/tmva/tmva/test/crossvalidation/TestCrossValidationSplitting.cxx b/tmva/tmva/test/crossvalidation/TestCrossValidationSplitting.cxx
index f4fe9c61be9..772780dea64 100644
--- a/tmva/tmva/test/crossvalidation/TestCrossValidationSplitting.cxx
+++ b/tmva/tmva/test/crossvalidation/TestCrossValidationSplitting.cxx
@@ -226,6 +226,53 @@ bool testFold(DataLoader *d, id_vec_t ids, CvSplit &split, UInt_t iFold)
    return true;
 }
 
+/*
+ * Checks that the spread of the number of events of a particular class is at
+ * most 1 over all the folds. This is the core of the stratified splitting.
+ * Each fold is prepared on the DataLoader and its kTesting events are counted
+ * as signal or non-signal. Returns false (and records a gtest failure) when
+ * numFolds is 0 or when either spread exceeds 1; returns true otherwise.
+ */
+bool testStratified(DataLoader *d, CvSplit &split, UInt_t numFolds)
+{
+   if (numFolds == 0) { // no folds: min/max of an empty range cannot be read
+      ADD_FAILURE() << "testStratified called with numFolds == 0";
+      return false;
+   }
+
+   DataSet *ds = d->GetDataSetInfo().GetDataSet();
+   const auto signalClass = d->GetDataSetInfo().GetSignalClassIndex();
+   std::vector<UInt_t> nSigFolds;
+   std::vector<UInt_t> nBkgFolds;
+
+   for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
+      d->PrepareFoldDataSet(split, iFold, Types::kTraining);
+
+      // Get the number of events per class in a fold
+      UInt_t nSignal = 0;
+      UInt_t nBackground = 0;
+      for (const auto &ev : ds->GetEventCollection(Types::kTesting)) {
+         const UInt_t classid = ev->GetClass();
+         ++(classid == signalClass ? nSignal : nBackground);
+      }
+
+      nSigFolds.push_back(nSignal);
+      nBkgFolds.push_back(nBackground);
+
+      std::cout << "Stats for fold " << iFold << " sig/bkg/tot: " << nSignal
+                << "/" << nBackground << "/" << (nSignal + nBackground) << std::endl;
+   }
+
+   // Check the spread; max >= min, so the unsigned differences cannot wrap.
+   const auto sig = std::minmax_element(nSigFolds.begin(), nSigFolds.end());
+   const auto bkg = std::minmax_element(nBkgFolds.begin(), nBkgFolds.end());
+   const UInt_t spreadSig = *sig.second - *sig.first;
+   const UInt_t spreadBkg = *bkg.second - *bkg.first;
+   EXPECT_LE(spreadSig, 1u);
+   EXPECT_LE(spreadBkg, 1u);
+   return spreadSig <= 1 && spreadBkg <= 1;
+}
+
 } // End namespace TMVA
 
 TEST(CrossValidationSplitting, TrainingSetSplitOnSpectator)
@@ -265,3 +312,37 @@ TEST(CrossValidationSplitting, TrainingSetSplitOnSpectator)
    testFold(d, ids, split, 0);
    testFold(d, ids, split, 1);
 }
+
+TEST(CrossValidationSplitting, TrainingSetSplitRandomStratified)
+{
+   TMVA::Tools::Instance();
+
+   // Test for unbalanced classes: 110 signal vs. 10 background events, 3 folds
+   const UInt_t NUM_FOLDS = 3;
+   const UInt_t nPointsSig = 110;
+   const UInt_t nPointsBkg = 10;
+
+   // Create DataSet (message-logger output is inhibited during the setup)
+   TMVA::MsgLogger::InhibitOutput();
+   data_t data_class0 = TMVA::createData(nPointsSig, 0);
+   data_t data_class1 = TMVA::createData(nPointsBkg, 100);
+
+   TMVA::DataLoader *d = new TMVA::DataLoader("dataset");
+
+   d->AddSignalTree(std::get<1>(data_class0));
+   d->AddBackgroundTree(std::get<1>(data_class1));
+
+   d->AddVariable("x", 'D');
+   d->AddSpectator("id", "id", "");
+   // The event counts are UInt_t, so they are formatted with %u (not %i).
+   d->PrepareTrainingAndTestTree(
+      "", Form("SplitMode=Block:nTrain_Signal=%u:nTrain_Background=%u:!V", nPointsSig, nPointsBkg));
+   d->GetDataSetInfo().GetDataSet(); // Force creation of dataset.
+   TMVA::MsgLogger::EnableOutput();
+
+   TMVA::CvSplitKFolds split{NUM_FOLDS, "", kTRUE, 0};
+   d->MakeKFoldDataSet(split);
+
+   // Actual test
+   testStratified(d, split, NUM_FOLDS);
+}
\ No newline at end of file
-- 
GitLab