Skip to content
Snippets Groups Projects
Commit f84a58f1 authored by Kim Albertsson's avatar Kim Albertsson Committed by Lorenzo Moneta
Browse files

[TMVA] BDTG -- Parallelise `UpdateTargets` for multiclass case

parent ef9d9eb0
No related branches found
No related tags found
No related merge requests found
......@@ -1438,6 +1438,57 @@ void TMVA::MethodBDT::UpdateTargets(std::vector<const TMVA::Event*>& eventSample
if (DoMulticlass()) {
UInt_t nClasses = DataInfo().GetNClasses();
Bool_t isLastClass = (cls == nClasses - 1);
#ifdef R__USE_IMT
//
// This is the multi-threaded multiclass version
//
// Note: we only need to update the predicted probabilities every
// `nClasses` tree. Let's call a set of `nClasses` trees a "round". Thus
// the algortihm is split in two parts `update_residuals` and
// `update_residuals_last` where the latter is inteded to be run instead
// of the former for the last tree in a "round".
//
std::map<const TMVA::Event *, std::vector<double>> & residuals = this->fResiduals;
DecisionTree & lastTree = *(this->fForest.back());
auto update_residuals = [&residuals, &lastTree, cls](const TMVA::Event * e) {
residuals[e].at(cls) += lastTree.CheckEvent(e, kFALSE);
};
auto update_residuals_last = [&residuals, &lastTree, cls, nClasses](const TMVA::Event * e) {
residuals[e].at(cls) += lastTree.CheckEvent(e, kFALSE);
auto &residualsThisEvent = residuals[e];
std::vector<Double_t> expCache(nClasses, 0.0);
std::transform(residualsThisEvent.begin(),
residualsThisEvent.begin() + nClasses,
expCache.begin(), [](Double_t d) { return exp(d); });
Double_t exp_sum = std::accumulate(expCache.begin(),
expCache.begin() + nClasses,
0.0);
for (UInt_t i = 0; i < nClasses; i++) {
Double_t p_cls = expCache[i] / exp_sum;
Double_t res = (e->GetClass() == i) ? (1.0 - p_cls) : (-p_cls);
const_cast<TMVA::Event *>(e)->SetTarget(i, res);
}
};
if (isLastClass) {
TMVA::Config::Instance().GetThreadExecutor()
.Foreach(update_residuals_last, eventSample);
} else {
TMVA::Config::Instance().GetThreadExecutor()
.Foreach(update_residuals, eventSample);
}
#else
//
// Single-threaded multiclass version
//
std::vector<Double_t> expCache;
if (isLastClass) {
expCache.resize(nClasses);
......@@ -1463,6 +1514,7 @@ void TMVA::MethodBDT::UpdateTargets(std::vector<const TMVA::Event*>& eventSample
}
}
}
#endif
} else {
std::map<const TMVA::Event *, std::vector<double>> & residuals = this->fResiduals;
DecisionTree & lastTree = *(this->fForest.back());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment