Skip to content
Snippets Groups Projects
Commit 7743cc74 authored by Kim Albertsson's avatar Kim Albertsson Committed by Axel Naumann
Browse files

[TMVA] CV Stratified -- Fix whitespace issues

parent b3c2ae93
Branches
Tags
No related merge requests found
......@@ -100,7 +100,7 @@ public:
private:
std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses);
std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
private:
UInt_t fSeed;
TString fSplitExprString; //! Expression used to split data into folds. Should output values between 0 and numFolds.
......
......@@ -331,7 +331,6 @@ TMVA::CvSplitKFolds::SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFol
Bool_t useSplitExpr = not(fSplitExpr == nullptr or fSplitExprString == "");
if (useSplitExpr) {
// Deterministic split
for (ULong64_t i = 0; i < nEntries; i++) {
TMVA::Event *ev = oldSet[i];
......@@ -339,51 +338,53 @@ TMVA::CvSplitKFolds::SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFol
tempSets.at((UInt_t)iFold).push_back(ev);
}
} else {
std::vector<UInt_t> fOrigToFoldMapping;
if(not fStratified){
// Random split
std::vector<UInt_t> fOrigToFoldMapping;
fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
if(fStratified == kFALSE){
// Random split
fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
UInt_t iFold = fOrigToFoldMapping[iEvent];
TMVA::Event *ev = oldSet[iEvent];
tempSets.at(iFold).push_back(ev);
fEventToFoldMapping[ev] = iFold;
}
}
else{
}
} else {
// Stratified Split
std::vector<std::vector<TMVA::Event *>> oldSets;
oldSets.reserve(numClasses);
for(UInt_t iClass = 0; iClass < numClasses; iClass++){
oldSets.emplace_back();
//find a way to get number of events in each class
oldSets.reserve(nEntries);
oldSets.reserve(nEntries);
}
for(UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
// check the class of event and add to its vector of events
TMVA::Event *ev = oldSet[iEvent];
UInt_t iClass = ev->GetClass();
oldSets.at(iClass).push_back(ev);
}
for(UInt_t i = 0; i<numClasses; ++i){
// Shuffle each vector individually
TMVA::RandomGenerator<TRandom3> rng(fSeed);
std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
}
for(UInt_t i = 0; i<numClasses; ++i){
fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
for(UInt_t i = 0; i<numClasses; ++i) {
std::vector<UInt_t> fOrigToFoldMapping;
fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
for (UInt_t iEvent = 0; iEvent < oldSets.at(i).size(); ++iEvent) {
UInt_t iFold = fOrigToFoldMapping[iEvent];
TMVA::Event *ev = oldSets.at(i)[iEvent];
tempSets.at(iFold).push_back(ev);
fEventToFoldMapping[ev] = iFold;
}
}
}
}
}
return tempSets;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment