From 2948dd85f9847ce30bd50a2bd2680c3d850f093a Mon Sep 17 00:00:00 2001 From: Jonas Rembser <jonas.rembser@cern.ch> Date: Fri, 9 Apr 2021 23:38:43 +0200 Subject: [PATCH] [RF] Check if ranges are not overlapping in multi-range fit Events are double counted if one accidentally defines overlapping ranges and uses them in a multi-range fit. This happened for example in Jira issue ROOT-9548 where the whole dataset was double counted, leading to the parameter uncertainties being underestimated by a factor of sqrt(2). That situation should be avoided. This commit introduces a check for overlapping ranges before the multi-range likelihood is created. --- roofit/roofitcore/inc/RooHelpers.h | 6 +- roofit/roofitcore/src/RooAbsPdf.cxx | 38 ++++++------ roofit/roofitcore/src/RooHelpers.cxx | 90 +++++++++++++++++++++++++--- 3 files changed, 105 insertions(+), 29 deletions(-) diff --git a/roofit/roofitcore/inc/RooHelpers.h b/roofit/roofitcore/inc/RooHelpers.h index aa5576e6419..1879f4e84cb 100644 --- a/roofit/roofitcore/inc/RooHelpers.h +++ b/roofit/roofitcore/inc/RooHelpers.h @@ -26,6 +26,9 @@ #include <string> #include <utility> +class RooAbsPdf; +class RooAbsData; + namespace RooHelpers { @@ -94,7 +97,7 @@ std::vector<std::string> tokenise(const std::string &str, const std::string &del /// Check if the parameters have a range, and warn if the range extends below / above the set limits. void checkRangeOfParameters(const RooAbsReal* callingClass, std::initializer_list<const RooAbsReal*> pars, double min = -std::numeric_limits<double>::max(), double max = std::numeric_limits<double>::max(), - bool limitsInAllowedRange = false, std::string extraMessage = ""); + bool limitsInAllowedRange = false, std::string const& extraMessage = ""); /// Disable all caches for sub-branches in an expression tree. @@ -116,6 +119,7 @@ struct DisableCachingRAII { std::pair<double, double> getRangeOrBinningInterval(RooAbsArg const* arg, const char* rangeName); +bool checkIfRangesOverlap(RooAbsPdf const& pdf, RooAbsData const& data, std::vector<std::string> const& rangeNames); } diff --git a/roofit/roofitcore/src/RooAbsPdf.cxx b/roofit/roofitcore/src/RooAbsPdf.cxx index 8058e1bd70c..9688b8970de 100644 --- a/roofit/roofitcore/src/RooAbsPdf.cxx +++ b/roofit/roofitcore/src/RooAbsPdf.cxx @@ -187,6 +187,7 @@ called for each data event. #include <iostream> #include <string> #include <cmath> +#include <stdexcept> using namespace std; @@ -1099,40 +1100,35 @@ RooAbsReal* RooAbsPdf::createNLL(RooAbsData& data, const RooLinkedList& cmdList) RooAbsReal::setEvalErrorLoggingMode(RooAbsReal::CollectErrors) ; RooAbsReal* nll ; string baseName = Form("nll_%s_%s",GetName(),data.GetName()) ; + RooAbsTestStatistic::Configuration cfg; + cfg.addCoefRangeName = addCoefRangeName ? addCoefRangeName : ""; + cfg.nCPU = numcpu; + cfg.interleave = interl; + cfg.verbose = verbose; + cfg.splitCutRange = static_cast<bool>(splitr); + cfg.cloneInputData = static_cast<bool>(cloneData); + cfg.integrateOverBinsPrecision = pc.getDouble("IntegrateBins"); + cfg.binnedL = false; if (!rangeName || strchr(rangeName,',')==0) { // Simple case: default range, or single restricted range //cout<<"FK: Data test 1: "<<data.sumEntries()<<endl; - RooAbsTestStatistic::Configuration cfg; cfg.rangeName = rangeName ? rangeName : ""; - cfg.addCoefRangeName = addCoefRangeName ? addCoefRangeName : ""; - cfg.nCPU = numcpu; - cfg.interleave = interl; - cfg.verbose = verbose; - cfg.splitCutRange = static_cast<bool>(splitr); - cfg.cloneInputData = static_cast<bool>(cloneData); - cfg.integrateOverBinsPrecision = pc.getDouble("IntegrateBins"); - cfg.binnedL = false; - auto theNLL = new RooNLLVar(baseName.c_str(),"-log(likelihood)",*this,data,projDeps,std::move(cfg), ext); + auto theNLL = new RooNLLVar(baseName.c_str(),"-log(likelihood)",*this,data,projDeps,cfg, ext); theNLL->batchMode(pc.getInt("BatchMode")); nll = theNLL; } else { // Composite case: multiple ranges RooArgList nllList ; auto tokens = RooHelpers::tokenise(rangeName, ","); + if (RooHelpers::checkIfRangesOverlap(*this, data, tokens)) { + throw std::runtime_error( + std::string("Error in RooAbsPdf::createNLL! The ranges ") + rangeName + " are overlapping!"); + } for (const auto& token : tokens) { - RooAbsTestStatistic::Configuration cfg; cfg.rangeName = token; - cfg.addCoefRangeName = addCoefRangeName ? addCoefRangeName : ""; - cfg.nCPU = numcpu; - cfg.interleave = interl; - cfg.verbose = verbose; - cfg.splitCutRange = static_cast<bool>(splitr); - cfg.cloneInputData = static_cast<bool>(cloneData); - cfg.integrateOverBinsPrecision = pc.getDouble("IntegrateBins"); - cfg.binnedL = false; - auto nllComp = new RooNLLVar(Form("%s_%s",baseName.c_str(),token.c_str()),"-log(likelihood)", - *this,data,projDeps,std::move(cfg),ext); + auto nllComp = new RooNLLVar((baseName + "_" + token).c_str(),"-log(likelihood)", + *this,data,projDeps,cfg,ext); nllComp->batchMode(pc.getInt("BatchMode")); nllList.add(*nllComp) ; } diff --git a/roofit/roofitcore/src/RooHelpers.cxx b/roofit/roofitcore/src/RooHelpers.cxx index fb80417e720..fad07f03ec7 100644 --- a/roofit/roofitcore/src/RooHelpers.cxx +++ b/roofit/roofitcore/src/RooHelpers.cxx @@ -15,6 +15,10 @@ *****************************************************************************/ #include "RooHelpers.h" +#include "RooAbsPdf.h" +#include "RooAbsData.h" +#include "RooDataHist.h" +#include "RooDataSet.h" #include "RooAbsRealLValue.h" #include "TClass.h" @@ -115,7 +119,7 @@ HijackMessageStream::~HijackMessageStream() { /// \param[in] limitsInAllowedRange If true, the limits passed as parameters are part of the allowed range. /// \param[in] extraMessage Message that should be appended to the warning. void checkRangeOfParameters(const RooAbsReal* callingClass, std::initializer_list<const RooAbsReal*> pars, - double min, double max, bool limitsInAllowedRange, std::string extraMessage) { + double min, double max, bool limitsInAllowedRange, std::string const& extraMessage) { const char openBr = limitsInAllowedRange ? '[' : '('; const char closeBr = limitsInAllowedRange ? ']' : ')'; @@ -145,6 +149,17 @@ void checkRangeOfParameters(const RooAbsReal* callingClass, std::initializer_lis } +namespace { + std::pair<double, double> getBinningInterval(RooAbsBinning const& binning) { + if (!binning.isParameterized()) { + return {binning.lowBound(), binning.highBound()}; + } else { + return {binning.lowBoundFunc()->getVal(), binning.highBoundFunc()->getVal()}; + } + } +} // namespace + + /// Get the lower and upper bound of parameter range if arg can be casted to RooAbsRealLValue. /// If no range with rangeName is defined for the argument, this will check if a binning of the /// same name exists and return the interval covered by the binning. @@ -155,18 +170,79 @@ void checkRangeOfParameters(const RooAbsReal* callingClass, std::initializer_lis std::pair<double, double> getRangeOrBinningInterval(RooAbsArg const* arg, const char* rangeName) { auto rlv = dynamic_cast<RooAbsRealLValue const*>(arg); if (rlv) { - const RooAbsBinning* binning = rlv->getBinningPtr(rangeName); if (rangeName && rlv->hasRange(rangeName)) { return {rlv->getMin(rangeName), rlv->getMax(rangeName)}; - } else if (binning) { - if (!binning->isParameterized()) { - return {binning->lowBound(), binning->highBound()}; + } else if (auto binning = rlv->getBinningPtr(rangeName)) { + return getBinningInterval(*binning); + } + } + return {-std::numeric_limits<double>::infinity(), +std::numeric_limits<double>::infinity()}; +} + + +/// Check if there is any overlap when a list of ranges is applied to a set of observables. +/// \param[in] arg RooAbsCollection with the observables to check for overlap. +/// \param[in] rangeName The names of the ranges. +bool checkIfRangesOverlap(RooAbsPdf const& pdf, RooAbsData const& data, std::vector<std::string> const& rangeNames) { + + auto observables = *pdf.getObservables(data); + + auto getLimits = [&](RooAbsRealLValue const& rlv, const char* rangeName) { + + // RooDataHistCase + if(dynamic_cast<RooDataHist const*>(&data)) { + if (auto binning = rlv.getBinningPtr(rangeName)) { + return getBinningInterval(*binning); } else { - return {binning->lowBoundFunc()->getVal(), binning->highBoundFunc()->getVal()}; + // default binning if range is not defined + return getBinningInterval(*rlv.getBinningPtr(nullptr)); + } + } + + // RooDataSet and other cases + if (rlv.hasRange(rangeName)) { + return std::pair<double, double>{rlv.getMin(rangeName), rlv.getMax(rangeName)}; + } + // default range if range with given name is not defined + return std::pair<double, double>{rlv.getMin(), rlv.getMax()}; + }; + + auto nObs = observables.size(); + auto nRanges = rangeNames.size(); + + // cache the range limits in a flat vector + std::vector<std::pair<double,double>> limits; + limits.reserve(nRanges * nObs); + + for (auto const& range : rangeNames) { + for (auto const& obs : observables) { + auto rlv = dynamic_cast<RooAbsRealLValue const*>(obs); + if(!rlv) { + throw std::logic_error("Classes that represent observables are expected to inherit from RooAbsRealLValue!"); } + limits.push_back(getLimits(*rlv, range.c_str())); } } - return {-std::numeric_limits<double>::infinity(), +std::numeric_limits<double>::infinity()}; + + // loop over pairs of ranges + for(size_t ir1 = 0; ir1 < nRanges; ++ir1) { + for(size_t ir2 = ir1 + 1; ir2 < nRanges; ++ir2) { + + // Loop over observables. If all observables have overlapping limits for + // these ranges, the hypercubes defining the range are overlapping and we + // can return `true`. + size_t overlaps = 0; + for(size_t io1 = 0; io1 < nObs; ++io1) { + auto r1 = limits[ir1 * nObs + io1]; + auto r2 = limits[ir2 * nObs + io1]; + overlaps += (r1.second > r2.first && r1.first < r2.second) + || (r2.second > r1.first && r2.first < r1.second); + } + if(overlaps == nObs) return true; + } + } + + return false; } -- GitLab