From d3364d6e1718920eabe547018a353c3673645d58 Mon Sep 17 00:00:00 2001
From: Jonas Rembser <jonas.rembser@cern.ch>
Date: Wed, 22 Mar 2023 12:26:33 +0100
Subject: [PATCH] [RF][HS3] Cover also `HistoSys` in HS3 HistFactory test

---
 roofit/hs3/test/testHS3HistFactory.cxx | 86 ++++++++++++++------------
 1 file changed, 48 insertions(+), 38 deletions(-)

diff --git a/roofit/hs3/test/testHS3HistFactory.cxx b/roofit/hs3/test/testHS3HistFactory.cxx
index 3493a5f5f40..9c9a52a326c 100644
--- a/roofit/hs3/test/testHS3HistFactory.cxx
+++ b/roofit/hs3/test/testHS3HistFactory.cxx
@@ -18,31 +18,18 @@
 
 namespace {
 
-void toJSON(RooStats::HistFactory::Measurement &meas, std::string const &fname)
-{
-   RooStats::HistFactory::JSONTool tool{meas};
-   tool.PrintJSON(fname);
-}
-
-std::unique_ptr<RooWorkspace> toWS(RooStats::HistFactory::Measurement &meas)
-{
-   RooHelpers::LocalChangeMsgLevel changeMsgLvl(RooFit::WARNING);
-   return std::unique_ptr<RooWorkspace>{RooStats::HistFactory::MakeModelAndMeasurementFast(meas)};
-}
-
-std::unique_ptr<RooWorkspace> importToWS(std::string const &infile, std::string const &wsname)
-{
-   auto ws = std::make_unique<RooWorkspace>(wsname.c_str());
-   RooJSONFactoryWSTool tool{*ws};
-   tool.importJSON(infile);
-   return ws;
-}
+// If the JSON files should be written out for debugging purpose.
+const bool writeJsonFiles = false;
 
 void createInputFile(std::string const &inputFileName)
 {
 
    TH1F data("data", "data", 2, 1.0, 2.0);
+
    TH1F signal("signal", "signal histogram (pb)", 2, 1.0, 2.0);
+   TH1F systUncDo("shapeUnc_sigDo", "signal shape uncert.", 2, 1.0, 2.0);
+   TH1F systUncUp("shapeUnc_sigUp", "signal shape uncert.", 2, 1.0, 2.0);
+
    TH1F background1("background1", "background 1 histogram (pb)", 2, 1.0, 2.0);
    TH1F background2("background2", "background 2 histogram (pb)", 2, 1.0, 2.0);
    TH1F background1_statUncert("background1_statUncert", "statUncert", 2, 1.0, 2.0);
@@ -53,6 +40,12 @@ void createInputFile(std::string const &inputFileName)
    signal.SetBinContent(1, 20.);
    signal.SetBinContent(2, 10.);
 
+   systUncDo.SetBinContent(1, 15.);
+   systUncDo.SetBinContent(2, 8.);
+
+   systUncUp.SetBinContent(1, 29.);
+   systUncUp.SetBinContent(2, 13.);
+
    background1.SetBinContent(1, 100.);
    background1.SetBinContent(2, 0.);
 
@@ -66,13 +59,15 @@ void createInputFile(std::string const &inputFileName)
 
    data.Write();
    signal.Write();
+   systUncDo.Write();
+   systUncUp.Write();
    background1.Write();
    background2.Write();
    background1_statUncert.Write();
 }
 
 std::unique_ptr<RooStats::HistFactory::Measurement>
-measurement(std::string const &inputFileName = "test_hs3_histfactory_json_input.root")
+measurement(const char *inputFileName = "test_hs3_histfactory_json_input.root")
 {
    createInputFile(inputFileName);
 
@@ -88,17 +83,20 @@ measurement(std::string const &inputFileName = "test_hs3_histfactory_json_input.
    meas->SetBinHigh(2);
    // meas.AddConstantParam("syst1");
    RooStats::HistFactory::Channel chan{"channel1"};
-   chan.SetData("data", inputFileName.c_str());
+   chan.SetData("data", inputFileName);
    chan.SetStatErrorConfig(0.01, "Poisson");
-   RooStats::HistFactory::Sample sig{"signal", "signal", inputFileName.c_str()};
+
+   RooStats::HistFactory::Sample sig{"signal", "signal", inputFileName};
    sig.AddOverallSys("syst1", 0.95, 1.05);
    sig.AddNormFactor("mu", 1, -3, 5);
+   sig.AddHistoSys("SignalShape", "shapeUnc_sigDo", inputFileName, "", "shapeUnc_sigUp", inputFileName, "");
    chan.AddSample(sig);
-   RooStats::HistFactory::Sample background1{"background1", "background1", inputFileName.c_str()};
+
+   RooStats::HistFactory::Sample background1{"background1", "background1", inputFileName};
    background1.ActivateStatError("background1_statUncert", inputFileName);
    background1.AddOverallSys("syst2", 0.95, 1.05);
    chan.AddSample(background1);
-   RooStats::HistFactory::Sample background2{"background2", "background2", inputFileName.c_str()};
+   RooStats::HistFactory::Sample background2{"background2", "background2", inputFileName};
    background2.ActivateStatError();
    background2.AddOverallSys("syst3", 0.95, 1.05);
    chan.AddSample(background2);
@@ -112,8 +110,6 @@ measurement(std::string const &inputFileName = "test_hs3_histfactory_json_input.
 TEST(TestHS3HistFactoryJSON, Create)
 {
    RooHelpers::LocalChangeMsgLevel changeMsgLvl(RooFit::WARNING);
-
-   toJSON(*measurement(), "hf.json");
 }
 
 TEST(TestHS3HistFactoryJSON, Closure)
@@ -121,26 +117,32 @@ TEST(TestHS3HistFactoryJSON, Closure)
    RooHelpers::LocalChangeMsgLevel changeMsgLvl(RooFit::WARNING);
 
    std::unique_ptr<RooStats::HistFactory::Measurement> meas = measurement();
-   toJSON(*meas, "hf.json");
-   std::unique_ptr<RooWorkspace> ws = toWS(*meas);
-   std::unique_ptr<RooWorkspace> wsFromJson = importToWS("hf.json", "ws1");
+   if (writeJsonFiles) {
+      RooStats::HistFactory::JSONTool{*meas}.PrintJSON("hf.json");
+   }
+   std::stringstream ss;
+   RooStats::HistFactory::JSONTool{*meas}.PrintJSON(ss);
+
+   std::unique_ptr<RooWorkspace> ws{RooStats::HistFactory::MakeModelAndMeasurementFast(*meas)};
+   RooWorkspace wsFromJson{"ws1"};
+   RooJSONFactoryWSTool{wsFromJson}.importJSONfromString(ss.str());
 
    auto *mc = dynamic_cast<RooStats::ModelConfig *>(ws->obj("ModelConfig"));
    EXPECT_TRUE(mc != nullptr);
 
-   auto *mcFromJson = dynamic_cast<RooStats::ModelConfig *>(wsFromJson->obj("ModelConfig"));
+   auto *mcFromJson = dynamic_cast<RooStats::ModelConfig *>(wsFromJson.obj("ModelConfig"));
    EXPECT_TRUE(mcFromJson != nullptr);
 
    RooAbsPdf *pdf = mc->GetPdf();
    EXPECT_TRUE(pdf != nullptr);
 
-   RooAbsPdf *pdfFromJson = wsFromJson->pdf(meas->GetName());
+   RooAbsPdf *pdfFromJson = wsFromJson.pdf(meas->GetName());
    EXPECT_TRUE(pdfFromJson != nullptr);
 
    RooAbsData *data = ws->data("obsData");
    EXPECT_TRUE(data != nullptr);
 
-   RooAbsData *dataFromJson = wsFromJson->data("obsData");
+   RooAbsData *dataFromJson = wsFromJson.data("obsData");
    EXPECT_TRUE(dataFromJson != nullptr);
 
    using namespace RooFit;
@@ -156,11 +158,11 @@ TEST(TestHS3HistFactoryJSON, Closure)
    // TODO:
    //   * fix issues that prevent us from increasing precision
    //   * do also the reverse comparison to check that the set of constant parameters matches
-   // EXPECT_TRUE(result->isIdenticalNoCov(*resultFromJSON, 1.0, 0.01));
-   EXPECT_TRUE(resultFromJson->isIdenticalNoCov(*result, 1.0, 0.01));
+   // EXPECT_TRUE(result->isIdenticalNoCov(*resultFromJSON, 10.0, 0.01));
+   EXPECT_TRUE(resultFromJson->isIdenticalNoCov(*result, 10.0, 0.01));
 
    const double muVal = ws->var("mu")->getVal();
-   const double muJsonVal = wsFromJson->var("mu")->getVal();
+   const double muJsonVal = wsFromJson.var("mu")->getVal();
 
    EXPECT_NEAR(muJsonVal, muVal, 1e-4);         // absolute tolerance
    EXPECT_NEAR(muJsonVal, muVal, 1e-4 * muVal); // relative tolerance
@@ -171,7 +173,7 @@ TEST(TestHS3HistFactoryJSON, ClosureLoop)
    RooHelpers::LocalChangeMsgLevel changeMsgLvl(RooFit::WARNING);
 
    std::unique_ptr<RooStats::HistFactory::Measurement> meas = measurement();
-   std::unique_ptr<RooWorkspace> ws = toWS(*meas);
+   std::unique_ptr<RooWorkspace> ws{RooStats::HistFactory::MakeModelAndMeasurementFast(*meas)};
 
    auto *mc = dynamic_cast<RooStats::ModelConfig *>(ws->obj("ModelConfig"));
    EXPECT_TRUE(mc != nullptr);
@@ -183,11 +185,19 @@ TEST(TestHS3HistFactoryJSON, ClosureLoop)
    pdf->setStringAttribute("combined_data_name", "obsData");
 
    std::string const &js = RooJSONFactoryWSTool{*ws}.exportJSONtoString();
+   if (writeJsonFiles) {
+      RooJSONFactoryWSTool{*ws}.exportJSON("hf2.json");
+   }
 
    RooWorkspace newws("new");
    RooJSONFactoryWSTool newtool{newws};
    newtool.importJSONfromString(js);
 
+   // To check that JSON > WS > JSON doesn't change the JSON
+   if (writeJsonFiles) {
+      RooJSONFactoryWSTool{newws}.exportJSON("hf3.json");
+   }
+
    auto *newmc = dynamic_cast<RooStats::ModelConfig *>(newws.obj("ModelConfig"));
    EXPECT_TRUE(newmc != nullptr);
 
@@ -213,8 +223,8 @@ TEST(TestHS3HistFactoryJSON, ClosureLoop)
    // TODO:
    //   * fix issues that prevent us from increasing precision
    //   * do also the reverse comparison to check that the set of constant parameters matches
-   // EXPECT_TRUE(result->isIdenticalNoCov(*newresult, 1.0, 0.01));
-   EXPECT_TRUE(newresult->isIdenticalNoCov(*result, 1.0, 0.01));
+   // EXPECT_TRUE(result->isIdenticalNoCov(*newresult, 10.0, 0.01));
+   EXPECT_TRUE(newresult->isIdenticalNoCov(*result, 10.0, 0.01));
 
    const double muVal = ws->var("mu")->getVal();
    const double muNewVal = newws.var("mu")->getVal();
-- 
GitLab