From 2948dd85f9847ce30bd50a2bd2680c3d850f093a Mon Sep 17 00:00:00 2001
From: Jonas Rembser <jonas.rembser@cern.ch>
Date: Fri, 9 Apr 2021 23:38:43 +0200
Subject: [PATCH] [RF] Check if ranges are not overlapping in multi-range fit

Events are double counted if one accidentally defines overlapping ranges
and uses them in a multi-range fit. This happened for example in Jira
issue ROOT-9548 where the whole dataset was double counted, leading to
the parameter uncertainties being underestimated by a factor of sqrt(2).

To avoid this, this commit checks for overlapping ranges before the
multi-range likelihood is created and throws a std::runtime_error if any
of the ranges overlap.
---
 roofit/roofitcore/inc/RooHelpers.h   |  6 +-
 roofit/roofitcore/src/RooAbsPdf.cxx  | 38 ++++++------
 roofit/roofitcore/src/RooHelpers.cxx | 90 +++++++++++++++++++++++++---
 3 files changed, 105 insertions(+), 29 deletions(-)

diff --git a/roofit/roofitcore/inc/RooHelpers.h b/roofit/roofitcore/inc/RooHelpers.h
index aa5576e6419..1879f4e84cb 100644
--- a/roofit/roofitcore/inc/RooHelpers.h
+++ b/roofit/roofitcore/inc/RooHelpers.h
@@ -26,6 +26,9 @@
 #include <string>
 #include <utility>
 
+class RooAbsPdf;
+class RooAbsData;
+
 
 namespace RooHelpers {
 
@@ -94,7 +97,7 @@ std::vector<std::string> tokenise(const std::string &str, const std::string &del
 /// Check if the parameters have a range, and warn if the range extends below / above the set limits.
 void checkRangeOfParameters(const RooAbsReal* callingClass, std::initializer_list<const RooAbsReal*> pars,
     double min = -std::numeric_limits<double>::max(), double max = std::numeric_limits<double>::max(),
-    bool limitsInAllowedRange = false, std::string extraMessage = "");
+    bool limitsInAllowedRange = false, std::string const& extraMessage = "");
 
 
 /// Disable all caches for sub-branches in an expression tree.
@@ -116,6 +119,7 @@ struct DisableCachingRAII {
 
 std::pair<double, double> getRangeOrBinningInterval(RooAbsArg const* arg, const char* rangeName);
 
+bool checkIfRangesOverlap(RooAbsPdf const& pdf, RooAbsData const& data, std::vector<std::string> const& rangeNames);
 
 }
 
diff --git a/roofit/roofitcore/src/RooAbsPdf.cxx b/roofit/roofitcore/src/RooAbsPdf.cxx
index 8058e1bd70c..9688b8970de 100644
--- a/roofit/roofitcore/src/RooAbsPdf.cxx
+++ b/roofit/roofitcore/src/RooAbsPdf.cxx
@@ -187,6 +187,7 @@ called for each data event.
 #include <iostream>
 #include <string>
 #include <cmath>
+#include <stdexcept>
 
 using namespace std;
 
@@ -1099,40 +1100,35 @@ RooAbsReal* RooAbsPdf::createNLL(RooAbsData& data, const RooLinkedList& cmdList)
   RooAbsReal::setEvalErrorLoggingMode(RooAbsReal::CollectErrors) ;
   RooAbsReal* nll ;
   string baseName = Form("nll_%s_%s",GetName(),data.GetName()) ;
+  RooAbsTestStatistic::Configuration cfg;
+  cfg.addCoefRangeName = addCoefRangeName ? addCoefRangeName : "";
+  cfg.nCPU = numcpu;
+  cfg.interleave = interl;
+  cfg.verbose = verbose;
+  cfg.splitCutRange = static_cast<bool>(splitr);
+  cfg.cloneInputData = static_cast<bool>(cloneData);
+  cfg.integrateOverBinsPrecision = pc.getDouble("IntegrateBins");
+  cfg.binnedL = false;
   if (!rangeName || strchr(rangeName,',')==0) {
     // Simple case: default range, or single restricted range
     //cout<<"FK: Data test 1: "<<data.sumEntries()<<endl;
 
-    RooAbsTestStatistic::Configuration cfg;
     cfg.rangeName = rangeName ? rangeName : "";
-    cfg.addCoefRangeName = addCoefRangeName ? addCoefRangeName : "";
-    cfg.nCPU = numcpu;
-    cfg.interleave = interl;
-    cfg.verbose = verbose;
-    cfg.splitCutRange = static_cast<bool>(splitr);
-    cfg.cloneInputData = static_cast<bool>(cloneData);
-    cfg.integrateOverBinsPrecision = pc.getDouble("IntegrateBins");
-    cfg.binnedL = false;
-    auto theNLL = new RooNLLVar(baseName.c_str(),"-log(likelihood)",*this,data,projDeps,std::move(cfg), ext);
+    auto theNLL = new RooNLLVar(baseName.c_str(),"-log(likelihood)",*this,data,projDeps,cfg, ext);
     theNLL->batchMode(pc.getInt("BatchMode"));
     nll = theNLL;
   } else {
     // Composite case: multiple ranges
     RooArgList nllList ;
     auto tokens = RooHelpers::tokenise(rangeName, ",");
+    if (RooHelpers::checkIfRangesOverlap(*this, data, tokens)) {
+      throw std::runtime_error(
+              std::string("Error in RooAbsPdf::createNLL! The ranges ") + rangeName + " are overlapping!");
+    }
     for (const auto& token : tokens) {
-      RooAbsTestStatistic::Configuration cfg;
       cfg.rangeName = token;
-      cfg.addCoefRangeName = addCoefRangeName ? addCoefRangeName : "";
-      cfg.nCPU = numcpu;
-      cfg.interleave = interl;
-      cfg.verbose = verbose;
-      cfg.splitCutRange = static_cast<bool>(splitr);
-      cfg.cloneInputData = static_cast<bool>(cloneData);
-      cfg.integrateOverBinsPrecision = pc.getDouble("IntegrateBins");
-      cfg.binnedL = false;
-      auto nllComp = new RooNLLVar(Form("%s_%s",baseName.c_str(),token.c_str()),"-log(likelihood)",
-                                   *this,data,projDeps,std::move(cfg),ext);
+      auto nllComp = new RooNLLVar((baseName + "_" + token).c_str(),"-log(likelihood)",
+                                   *this,data,projDeps,cfg,ext);
       nllComp->batchMode(pc.getInt("BatchMode"));
       nllList.add(*nllComp) ;
     }
diff --git a/roofit/roofitcore/src/RooHelpers.cxx b/roofit/roofitcore/src/RooHelpers.cxx
index fb80417e720..fad07f03ec7 100644
--- a/roofit/roofitcore/src/RooHelpers.cxx
+++ b/roofit/roofitcore/src/RooHelpers.cxx
@@ -15,6 +15,10 @@
  *****************************************************************************/
 
 #include "RooHelpers.h"
+#include "RooAbsPdf.h"
+#include "RooAbsData.h"
+#include "RooDataHist.h"
+#include "RooDataSet.h"
 #include "RooAbsRealLValue.h"
 
 #include "TClass.h"
@@ -115,7 +119,7 @@ HijackMessageStream::~HijackMessageStream() {
 /// \param[in] limitsInAllowedRange If true, the limits passed as parameters are part of the allowed range.
 /// \param[in] extraMessage Message that should be appended to the warning.
 void checkRangeOfParameters(const RooAbsReal* callingClass, std::initializer_list<const RooAbsReal*> pars,
-    double min, double max, bool limitsInAllowedRange, std::string extraMessage) {
+    double min, double max, bool limitsInAllowedRange, std::string const& extraMessage) {
   const char openBr = limitsInAllowedRange ? '[' : '(';
   const char closeBr = limitsInAllowedRange ? ']' : ')';
 
@@ -145,6 +149,17 @@ void checkRangeOfParameters(const RooAbsReal* callingClass, std::initializer_lis
 }
 
 
+namespace {
+  std::pair<double, double> getBinningInterval(RooAbsBinning const& binning) {
+    if (!binning.isParameterized()) {
+      return {binning.lowBound(), binning.highBound()};
+    } else {
+      return {binning.lowBoundFunc()->getVal(), binning.highBoundFunc()->getVal()};
+    }
+  }
+} // namespace
+
+
 /// Get the lower and upper bound of parameter range if arg can be casted to RooAbsRealLValue.
 /// If no range with rangeName is defined for the argument, this will check if a binning of the
 /// same name exists and return the interval covered by the binning.
@@ -155,18 +170,79 @@ void checkRangeOfParameters(const RooAbsReal* callingClass, std::initializer_lis
 std::pair<double, double> getRangeOrBinningInterval(RooAbsArg const* arg, const char* rangeName) {
   auto rlv = dynamic_cast<RooAbsRealLValue const*>(arg);
   if (rlv) {
-    const RooAbsBinning* binning = rlv->getBinningPtr(rangeName);
     if (rangeName && rlv->hasRange(rangeName)) {
       return {rlv->getMin(rangeName), rlv->getMax(rangeName)};
-    } else if (binning) {
-      if (!binning->isParameterized()) {
-        return {binning->lowBound(), binning->highBound()};
+    } else if (auto binning = rlv->getBinningPtr(rangeName)) {
+      return getBinningInterval(*binning);
+    }
+  }
+  return {-std::numeric_limits<double>::infinity(), +std::numeric_limits<double>::infinity()};
+}
+
+
+/// Check if any pair of the given ranges overlaps in all observables (returns `true` if so).
+/// \param[in] pdf,data The observables are obtained via `pdf.getObservables(data)`.
+/// \param[in] rangeNames The names of the ranges.
+bool checkIfRangesOverlap(RooAbsPdf const& pdf, RooAbsData const& data, std::vector<std::string> const& rangeNames) {
+
+  auto observables = *pdf.getObservables(data);
+
+  auto getLimits = [&](RooAbsRealLValue const& rlv, const char* rangeName) {
+    
+    // RooDataHist case
+    if(dynamic_cast<RooDataHist const*>(&data)) {
+      if (auto binning = rlv.getBinningPtr(rangeName)) {
+        return getBinningInterval(*binning);
       } else {
-        return {binning->lowBoundFunc()->getVal(), binning->highBoundFunc()->getVal()};
+        // default binning if range is not defined
+        return getBinningInterval(*rlv.getBinningPtr(nullptr));
+      }
+    }
+
+    // RooDataSet and other cases
+    if (rlv.hasRange(rangeName)) {
+      return std::pair<double, double>{rlv.getMin(rangeName), rlv.getMax(rangeName)};
+    }
+    // default range if range with given name is not defined
+    return std::pair<double, double>{rlv.getMin(), rlv.getMax()};
+  };
+
+  auto nObs = observables.size();
+  auto nRanges = rangeNames.size();
+
+  // cache the range limits in a flat vector
+  std::vector<std::pair<double,double>> limits;
+  limits.reserve(nRanges * nObs);
+
+  for (auto const& range : rangeNames) {
+    for (auto const& obs : observables) {
+      auto rlv = dynamic_cast<RooAbsRealLValue const*>(obs);
+      if(!rlv) {
+        throw std::logic_error("Classes that represent observables are expected to inherit from RooAbsRealLValue!");
       }
+      limits.push_back(getLimits(*rlv, range.c_str()));
     }
   }
-  return {-std::numeric_limits<double>::infinity(), +std::numeric_limits<double>::infinity()};
+
+  // loop over pairs of ranges
+  for(size_t ir1 = 0; ir1 < nRanges; ++ir1) {
+    for(size_t ir2 = ir1 + 1; ir2 < nRanges; ++ir2) {
+
+      // Loop over observables. If all observables have overlapping limits for
+      // these ranges, the hypercubes defining the range are overlapping and we
+      // can return `true`.
+      size_t overlaps = 0;
+      for(size_t io1 = 0; io1 < nObs; ++io1) {
+        auto r1 = limits[ir1 * nObs + io1];
+        auto r2 = limits[ir2 * nObs + io1];
+        overlaps += (r1.second > r2.first && r1.first < r2.second)
+                 || (r2.second > r1.first && r2.first < r1.second);
+      }
+      if(overlaps == nObs) return true;
+    }
+  }
+
+  return false;
 }
 
 
-- 
GitLab