Skip to content
Snippets Groups Projects
Commit b6b8aff1 authored by Kim Albertsson, committed by Axel Naumann
Browse files

[TMVA] CV Stratified -- Add test for stratified splitting

parent 7522366d
Branches
Tags
No related merge requests found
...@@ -226,6 +226,53 @@ bool testFold(DataLoader *d, id_vec_t ids, CvSplit &split, UInt_t iFold) ...@@ -226,6 +226,53 @@ bool testFold(DataLoader *d, id_vec_t ids, CvSplit &split, UInt_t iFold)
return true; return true;
} }
/*
 * Checks that the spread of the number of events of a particular class is at
 * most 1 over all the folds. This is the core of the stratified splitting.
 *
 * For every fold index in [0, numFolds) the fold is prepared on `d` with
 * PrepareFoldDataSet(split, iFold, Types::kTraining) and the events found in
 * the Types::kTesting collection are counted per class: events whose class id
 * equals the signal class index count as signal, all others as background.
 * The per-fold counts are printed to std::cout.
 *
 * d        - data loader whose data set is inspected
 * split    - splitting handed to PrepareFoldDataSet for each fold
 * numFolds - number of folds to inspect; expected to be at least 1
 *
 * Returns true when, for both signal and background, the difference between
 * the largest and smallest per-fold count is at most 1. Returns false
 * otherwise, including when numFolds is 0. Every violated condition is also
 * reported as a non-fatal gtest failure.
 */
bool testStratified(DataLoader *d, CvSplit &split, UInt_t numFolds)
{
   // With no folds there is nothing to compare, and taking the min/max of an
   // empty range below would dereference an end iterator.
   EXPECT_LE(1u, numFolds);
   if (numFolds == 0) {
      return false;
   }

   DataSet *ds = d->GetDataSetInfo().GetDataSet();

   std::vector<UInt_t> nSigFolds;
   std::vector<UInt_t> nBkgFolds;
   nSigFolds.reserve(numFolds);
   nBkgFolds.reserve(numFolds);

   for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
      d->PrepareFoldDataSet(split, iFold, Types::kTraining);

      // The class index does not depend on the event, so look it up once per fold.
      const auto signalClass = d->GetDataSetInfo().GetSignalClassIndex();

      // Get the number events per class in a fold.
      // NOTE(review): the testing collection is read after preparing the fold
      // for training; presumably it holds the held-out fold -- confirm.
      UInt_t nSignal = 0;
      UInt_t nBackground = 0;
      for (const auto &ev : ds->GetEventCollection(Types::kTesting)) {
         const UInt_t classid = ev->GetClass();
         if (classid == signalClass) {
            ++nSignal;
         } else {
            ++nBackground;
         }
      }
      const UInt_t nTotal = nSignal + nBackground;

      nSigFolds.push_back(nSignal);
      nBkgFolds.push_back(nBackground);
      std::cout << "Stats for fold " << iFold << " sig/bkg/tot: " << nSignal
                << "/" << nBackground << "/" << nTotal << std::endl;
   }

   // Check the spread. max >= min, so the unsigned differences cannot wrap.
   const UInt_t minSig = *std::min_element(nSigFolds.begin(), nSigFolds.end());
   const UInt_t maxSig = *std::max_element(nSigFolds.begin(), nSigFolds.end());
   const UInt_t minBkg = *std::min_element(nBkgFolds.begin(), nBkgFolds.end());
   const UInt_t maxBkg = *std::max_element(nBkgFolds.begin(), nBkgFolds.end());
   const UInt_t sigSpread = maxSig - minSig;
   const UInt_t bkgSpread = maxBkg - minBkg;

   EXPECT_LE(sigSpread, 1u);
   EXPECT_LE(bkgSpread, 1u);

   return sigSpread <= 1u && bkgSpread <= 1u;
}
} // End namespace TMVA } // End namespace TMVA
TEST(CrossValidationSplitting, TrainingSetSplitOnSpectator) TEST(CrossValidationSplitting, TrainingSetSplitOnSpectator)
...@@ -265,3 +312,37 @@ TEST(CrossValidationSplitting, TrainingSetSplitOnSpectator) ...@@ -265,3 +312,37 @@ TEST(CrossValidationSplitting, TrainingSetSplitOnSpectator)
testFold(d, ids, split, 0); testFold(d, ids, split, 0);
testFold(d, ids, split, 1); testFold(d, ids, split, 1);
} }
/*
 * Exercises k-fold splitting on a data set with strongly unbalanced classes
 * (110 signal points versus 10 background points, 3 folds) and checks, via
 * testStratified, that the per-class event counts differ by at most 1 between
 * folds.
 */
TEST(CrossValidationSplitting, TrainingSetSplitRandomStratified)
{
   TMVA::Tools::Instance();

   // Test for unbalanced classes
   const UInt_t NUM_FOLDS = 3;
   const UInt_t nPointsSig = 110;
   const UInt_t nPointsBkg = 10;

   // Create DataSet (with TMVA message output silenced while it is built)
   TMVA::MsgLogger::InhibitOutput();
   // NOTE(review): the second createData argument is presumably a starting
   // offset for the generated ids -- confirm against createData.
   data_t data_class0 = TMVA::createData(nPointsSig, 0);
   data_t data_class1 = TMVA::createData(nPointsBkg, 100);

   // NOTE(review): the loader is never deleted in this test; confirm whether
   // that is intentional before adding cleanup.
   TMVA::DataLoader *d = new TMVA::DataLoader("dataset");

   d->AddSignalTree(std::get<1>(data_class0));
   d->AddBackgroundTree(std::get<1>(data_class1));

   d->AddVariable("x", 'D');
   d->AddSpectator("id", "id", "");

   // The point counts are UInt_t, so they are formatted with %u rather than
   // the signed %i conversion.
   d->PrepareTrainingAndTestTree(
      "", Form("SplitMode=Block:nTrain_Signal=%u:nTrain_Background=%u:!V", nPointsSig, nPointsBkg));
   d->GetDataSetInfo().GetDataSet(); // Force creation of dataset.
   TMVA::MsgLogger::EnableOutput();

   // NOTE(review): the third constructor argument (kTRUE) presumably enables
   // stratified splitting -- confirm against CvSplitKFolds.
   TMVA::CvSplitKFolds split{NUM_FOLDS, "", kTRUE, 0};
   d->MakeKFoldDataSet(split);

   // Actual test
   testStratified(d, split, NUM_FOLDS);
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment