From b6b8aff1557bcb4c75b54fdd97d83087d079a2ec Mon Sep 17 00:00:00 2001
From: Kim Albertsson <ketost@gmail.com>
Date: Fri, 2 Nov 2018 15:51:10 +0100
Subject: [PATCH] [TMVA] CV Stratified -- Add test for stratified splitting

---
Review revision of the original patch: testStratified now rejects
numFolds == 0 and returns the actual check result, the format string uses
%u for UInt_t, and the file ends with a newline. The blob ids on the
"index" line are those of the original patch.

 .../TestCrossValidationSplitting.cxx          | 104 +++++++++++++++++++
 1 file changed, 104 insertions(+)

diff --git a/tmva/tmva/test/crossvalidation/TestCrossValidationSplitting.cxx b/tmva/tmva/test/crossvalidation/TestCrossValidationSplitting.cxx
index f4fe9c61be9..772780dea64 100644
--- a/tmva/tmva/test/crossvalidation/TestCrossValidationSplitting.cxx
+++ b/tmva/tmva/test/crossvalidation/TestCrossValidationSplitting.cxx
@@ -226,6 +226,68 @@ bool testFold(DataLoader *d, id_vec_t ids, CvSplit &split, UInt_t iFold)
    return true;
 }
 
+/*
+ * Checks that the spread of the number of events of a particular class is at
+ * most 1 over all the folds. This is the core of the stratified splitting.
+ *
+ * For every fold the dataset is prepared with PrepareFoldDataSet and the
+ * events of the testing collection are counted: events whose class equals the
+ * signal class index count as signal, all others as background.
+ *
+ * Returns true if both per-class spreads are at most 1, false otherwise
+ * (including numFolds == 0, where there is nothing to compare).
+ */
+bool testStratified(DataLoader *d, CvSplit &split, UInt_t numFolds)
+{
+   // Without folds the count vectors stay empty and the iterators returned by
+   // std::minmax_element below must not be dereferenced.
+   EXPECT_GT(numFolds, 0u);
+   if (numFolds == 0) {
+      return false;
+   }
+
+   DataSet *ds = d->GetDataSetInfo().GetDataSet();
+   const auto signalClass = d->GetDataSetInfo().GetSignalClassIndex();
+
+   std::vector<UInt_t> nSigFolds;
+   std::vector<UInt_t> nBkgFolds;
+   nSigFolds.reserve(numFolds);
+   nBkgFolds.reserve(numFolds);
+
+   for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
+      d->PrepareFoldDataSet(split, iFold, Types::kTraining);
+
+      // Get the number of events per class in a fold
+      UInt_t nSignal = 0;
+      UInt_t nBackground = 0;
+      for (auto &ev : ds->GetEventCollection(Types::kTesting)) {
+         UInt_t classid = ev->GetClass();
+         if (classid == signalClass) {
+            ++nSignal;
+         } else {
+            ++nBackground;
+         }
+      }
+
+      nSigFolds.push_back(nSignal);
+      nBkgFolds.push_back(nBackground);
+
+      std::cout << "Stats for fold " << iFold << " sig/bkg/tot: " << nSignal
+                << "/" << nBackground << "/" << (nSignal + nBackground) << std::endl;
+   }
+
+   // Check the spread. max >= min, so the unsigned differences cannot wrap.
+   const auto sigRange = std::minmax_element(nSigFolds.begin(), nSigFolds.end());
+   const auto bkgRange = std::minmax_element(nBkgFolds.begin(), nBkgFolds.end());
+   const UInt_t sigSpread = *sigRange.second - *sigRange.first;
+   const UInt_t bkgSpread = *bkgRange.second - *bkgRange.first;
+
+   EXPECT_LE(sigSpread, 1u);
+   EXPECT_LE(bkgSpread, 1u);
+
+   return sigSpread <= 1 && bkgSpread <= 1;
+}
+
 } // End namespace TMVA
 
 TEST(CrossValidationSplitting, TrainingSetSplitOnSpectator)
@@ -265,3 +327,45 @@ TEST(CrossValidationSplitting, TrainingSetSplitOnSpectator)
    testFold(d, ids, split, 0);
    testFold(d, ids, split, 1);
 }
+
+/*
+ * Builds a deliberately unbalanced dataset (110 signal and 10 background
+ * events), splits it into 3 folds and checks the per-class spread over the
+ * folds with testStratified.
+ */
+TEST(CrossValidationSplitting, TrainingSetSplitRandomStratified)
+{
+   TMVA::Tools::Instance();
+
+   // Test for unbalanced classes
+   const UInt_t NUM_FOLDS = 3;
+   const UInt_t nPointsSig = 110;
+   const UInt_t nPointsBkg = 10;
+
+   // Create DataSet
+   TMVA::MsgLogger::InhibitOutput();
+   data_t data_class0 = TMVA::createData(nPointsSig, 0);
+   data_t data_class1 = TMVA::createData(nPointsBkg, 100);
+
+   // NOTE(review): the loader is never deleted -- presumably the same pattern
+   // as the other tests in this file; confirm before adding cleanup.
+   TMVA::DataLoader *d = new TMVA::DataLoader("dataset");
+
+   d->AddSignalTree(std::get<1>(data_class0));
+   d->AddBackgroundTree(std::get<1>(data_class1));
+
+   d->AddVariable("x", 'D');
+   d->AddSpectator("id", "id", "");
+   // The counts are UInt_t, hence %u rather than %i in the format string.
+   d->PrepareTrainingAndTestTree(
+      "", Form("SplitMode=Block:nTrain_Signal=%u:nTrain_Background=%u:!V", nPointsSig, nPointsBkg));
+
+   d->GetDataSetInfo().GetDataSet(); // Force creation of dataset.
+   TMVA::MsgLogger::EnableOutput();
+
+   TMVA::CvSplitKFolds split{NUM_FOLDS, "", kTRUE, 0};
+   d->MakeKFoldDataSet(split);
+
+   // Actual test
+   testStratified(d, split, NUM_FOLDS);
+}
-- 
GitLab