Skip to content
Snippets Groups Projects
Commit 610ef010 authored by Kim Albertsson's avatar Kim Albertsson Committed by Lorenzo Moneta
Browse files

[TMVA] CV modernise -- range based for-loops

parent e53085a6
No related branches found
No related tags found
No related merge requests found
......@@ -140,7 +140,7 @@ public:
void Evaluate();
private:
CrossValidationFoldResult ProcessFold(UInt_t iFold, UInt_t iMethod);
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo);
Types::EAnalysisType fAnalysisType;
TString fAnalysisTypeStr;
......
......@@ -387,18 +387,14 @@ void TMVA::CrossValidation::SetSplitExpr(TString splitExpr)
/// @param iFold fold to evaluate
///
TMVA::CrossValidationFoldResult TMVA::CrossValidation::ProcessFold(UInt_t iFold, UInt_t iMethod)
TMVA::CrossValidationFoldResult TMVA::CrossValidation::ProcessFold(UInt_t iFold, const OptionMap & methodInfo)
{
TString methodName = fMethods[iMethod].GetValue<TString>("MethodName");
TString methodTitle = fMethods[iMethod].GetValue<TString>("MethodTitle");
TString methodOptions = fMethods[iMethod].GetValue<TString>("MethodOptions");
TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
TString methodOptions = methodInfo.GetValue<TString>("MethodOptions");
TString foldTitle = methodTitle + ("_fold") + (iFold + 1);
Log() << kDEBUG << "Fold (" << methodTitle << "): " << iFold << Endl;
// Get specific fold of dataset and setup method
TString foldTitle = methodTitle;
foldTitle += "_fold";
foldTitle += iFold + 1;
Log() << kDEBUG << "Processing " << methodTitle << " fold " << iFold << Endl;
// Only used if fFoldOutputFile == true
TFile *foldOutputFile = nullptr;
......@@ -411,7 +407,7 @@ TMVA::CrossValidationFoldResult TMVA::CrossValidation::ProcessFold(UInt_t iFold,
}
fDataLoader->PrepareFoldDataSet(*fSplit.get(), iFold, TMVA::Types::kTraining);
MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodName, foldTitle, methodOptions);
MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
// Train method (train method and eval train set)
Event::SetIsTraining(kTRUE);
......@@ -481,11 +477,11 @@ void TMVA::CrossValidation::Evaluate()
}
fResults.reserve(fMethods.size());
for (UInt_t iMethod = 0; iMethod < fMethods.size(); iMethod++) {
for (auto & methodInfo : fMethods) {
CrossValidationResult result{fNumFolds};
TString methodTypeName = fMethods[iMethod].GetValue<TString>("MethodName");
TString methodTitle = fMethods[iMethod].GetValue<TString>("MethodTitle");
TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
if (methodTypeName == "") {
Log() << kFATAL << "No method booked for cross-validation" << Endl;
......@@ -502,15 +498,15 @@ void TMVA::CrossValidation::Evaluate()
}
if (nWorkers == 1) {
for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
auto fold_result = ProcessFold(iFold, iMethod);
auto fold_result = ProcessFold(iFold, methodInfo);
result.Fill(fold_result);
}
} else {
ROOT::TProcessExecutor workers(nWorkers);
std::vector<CrossValidationFoldResult> result_vector;
auto workItem = [this, iMethod](UInt_t iFold) {
return ProcessFold(iFold, iMethod);
auto workItem = [this, methodInfo](UInt_t iFold) {
return ProcessFold(iFold, methodInfo);
};
result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
......@@ -544,12 +540,12 @@ void TMVA::CrossValidation::Evaluate()
fDataLoader->RecombineKFoldDataSet(*fSplit.get());
// "Eval" on training set
for (UInt_t iMethod = 0; iMethod < fMethods.size(); iMethod++) {
TString methodTypeName = fMethods[iMethod].GetValue<TString>("MethodName");
TString methodTitle = fMethods[iMethod].GetValue<TString>("MethodTitle");
for (auto & methodInfo : fMethods) {
TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
IMethod *method_interface = fFactory->GetMethod(fDataLoader.get()->GetName(), methodTitle);
MethodCrossValidation *method = dynamic_cast<MethodCrossValidation *>(method_interface);
auto method = dynamic_cast<MethodCrossValidation *>(method_interface);
if (fOutputFile) {
fFactory->WriteDataInformation(method->fDataSetInfo);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment