diff --git a/tmva/tmva/src/MethodBDT.cxx b/tmva/tmva/src/MethodBDT.cxx index e059771c2daf4c8c12acad038219a62fffdab2d7..72f49525985e69f0b568496ddd7c54eb1f9f0eda 100644 --- a/tmva/tmva/src/MethodBDT.cxx +++ b/tmva/tmva/src/MethodBDT.cxx @@ -2502,6 +2502,18 @@ const std::vector<Float_t>& TMVA::MethodBDT::GetMulticlassValues() UInt_t nClasses = DataInfo().GetNClasses(); std::vector<Double_t> temp(nClasses); auto forestSize = fForest.size(); + + #ifdef R__USE_IMT + std::vector<TMVA::DecisionTree *> forest = fForest; + auto get_output = [&e, &forest, &temp, forestSize, nClasses](UInt_t iClass) { + for (UInt_t itree = iClass; itree < forestSize; itree += nClasses) { + temp[iClass] += forest[itree]->CheckEvent(e, kFALSE); + } + }; + + TMVA::Config::Instance().GetThreadExecutor() + .Foreach(get_output, ROOT::TSeqU(nClasses)); + #else // trees 0, nClasses, 2*nClasses, ... belong to class 0 // trees 1, nClasses+1, 2*nClasses+1, ... belong to class 1 and so forth UInt_t classOfTree = 0; @@ -2509,6 +2521,7 @@ const std::vector<Float_t>& TMVA::MethodBDT::GetMulticlassValues() temp[classOfTree] += fForest[itree]->CheckEvent(e, kFALSE); if (++classOfTree == nClasses) classOfTree = 0; // cheap modulo } + #endif // we want to calculate sum of exp(temp[j] - temp[i]) for all i,j (i!=j) // first calculate exp(), then replace minus with division.