diff --git a/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h b/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h index 2321d6439b3fbf3c614a3f219e3b82ee65c2f527..4be3ae1eb4ab1fab697b5f9f48f3b6ce4d5518e3 100644 --- a/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h +++ b/roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h @@ -203,6 +203,9 @@ public: static void writeCombinedDataName(RooFit::Detail::JSONNode &rootnode, std::string const &pdfName, std::string const &dataName); + static void writeChannelNames(RooFit::Detail::JSONNode &rootnode, std::string const &simPdfName, + std::vector<std::string> const &channelNames); + private: struct Config { static bool stripObservables; diff --git a/roofit/hs3/src/HistFactoryJSONTool.cxx b/roofit/hs3/src/HistFactoryJSONTool.cxx index 649ccef3b18d29a53c5aae9c8ce66388055d9842..d1dd2230eade5d6e75cc03768d082e570a592260 100644 --- a/roofit/hs3/src/HistFactoryJSONTool.cxx +++ b/roofit/hs3/src/HistFactoryJSONTool.cxx @@ -221,6 +221,7 @@ void exportMeasurement(RooStats::HistFactory::Measurement &measurement, JSONNode } // the data + std::vector<std::string> channelNames; for (const auto &c : measurement.GetChannels()) { JSONNode &dataOutput = RooJSONFactoryWSTool::appendNamedChild(n["data"], std::string("obsData_") + c.GetName()); @@ -231,9 +232,11 @@ void exportMeasurement(RooStats::HistFactory::Measurement &measurement, JSONNode } RooJSONFactoryWSTool::exportHistogram(*c.GetData().GetHisto(), dataOutput, obsnames); + channelNames.push_back(c.GetName()); } RooJSONFactoryWSTool::writeCombinedDataName(n, measurement.GetName(), "obsData"); + RooJSONFactoryWSTool::writeChannelNames(n, measurement.GetName(), channelNames); } } // namespace diff --git a/roofit/hs3/src/RooJSONFactoryWSTool.cxx b/roofit/hs3/src/RooJSONFactoryWSTool.cxx index b3fad8785364aa7b1ab216b762a5156c3a75f771..a9bb7bd40743f9e7934583080577396e10756ac2 100644 --- a/roofit/hs3/src/RooJSONFactoryWSTool.cxx +++ b/roofit/hs3/src/RooJSONFactoryWSTool.cxx @@ -455,10 +455,12 @@ std::unique_ptr<RooAbsData> loadData(const JSONNode &p, RooWorkspace &workspace) return nullptr; } -void importAnalysis(const RooFit::Detail::JSONNode &analysisNode, const RooFit::Detail::JSONNode &likelihoodsNode, - const RooFit::Detail::JSONNode &mcAuxNode, RooWorkspace &workspace, +void importAnalysis(const RooFit::Detail::JSONNode &rootnode, const RooFit::Detail::JSONNode &analysisNode, + const RooFit::Detail::JSONNode &likelihoodsNode, RooWorkspace &workspace, std::vector<std::unique_ptr<RooAbsData>> &datas) { + const RooFit::Detail::JSONNode &mcAuxNode = + dereference(rootnode, {"misc", "ROOT_internal", "ModelConfigs", analysisNode["name"].val()}, rootnode); // if this is a toplevel pdf, also create a modelConfig for it std::string mcname = "ModelConfig"; @@ -481,10 +483,19 @@ void importAnalysis(const RooFit::Detail::JSONNode &analysisNode, const RooFit:: nllDataNames.push_back(nameNode.val()); } + std::string const &pdfName = analysisNode["name"].val(); + + std::vector<std::string> channelNames; + JSONNode const &channelNamesNode = + dereference(rootnode, {"misc", "ROOT_internal", "channel_names", pdfName}, rootnode); + for (auto &n : channelNamesNode.children()) { + channelNames.push_back(n.val()); + } + std::stringstream ss; - ss << "SIMUL::" << analysisNode["name"].val() << "(channelCat["; + ss << "SIMUL::" << pdfName << "(channelCat["; for (std::size_t iChannel = 0; iChannel < nllDistNames.size(); ++iChannel) { - ss << "channel_" << iChannel << "=" << iChannel; + ss << channelNames[iChannel] << "=" << iChannel; if (iChannel < nllDistNames.size() - 1) { ss << ","; } @@ -492,8 +503,7 @@ void importAnalysis(const RooFit::Detail::JSONNode &analysisNode, const RooFit:: ss << "]"; for (std::size_t iChannel = 0; iChannel < nllDistNames.size(); ++iChannel) { - ss << ", " - << "channel_" << iChannel << "=" << nllDistNames[iChannel]; + ss << ", " << channelNames[iChannel] << "=" << nllDistNames[iChannel]; } ss << ")"; auto pdf = static_cast<RooSimultaneous *>(workspace.factory(ss.str())); @@ -543,7 +553,7 @@ void importAnalysis(const RooFit::Detail::JSONNode &analysisNode, const RooFit:: std::unique_ptr<RooAbsData> &channelData = *std::find_if( datas.begin(), datas.end(), [&](auto &d) { return d && d->GetName() == nllDataNames[iChannel]; }); allVars.add(*channelData->get()); - dsMap.insert({"channel_" + std::to_string(iChannel), std::move(channelData)}); + dsMap.insert({channelNames[iChannel], std::move(channelData)}); } if (!mcAuxNode.has_child("combined_data_name")) { @@ -720,11 +730,35 @@ void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode & } } +void RooJSONFactoryWSTool::writeChannelNames(JSONNode &rootnode, std::string const &simPdfName, + std::vector<std::string> const &channelNames) +{ + auto &miscinfo = rootnode["misc"]; + miscinfo.set_map(); + auto &rootinfo = miscinfo["ROOT_internal"]; + rootinfo.set_map(); + auto &categoriesinfo = rootinfo["channel_names"]; + categoriesinfo.set_map(); + // Avoid repeated filling + if (!categoriesinfo.has_child(simPdfName)) { + auto &catinfo = categoriesinfo[simPdfName]; + catinfo.fill_seq(channelNames); + } +} + JSONNode *RooJSONFactoryWSTool::exportObject(const RooAbsArg *func) { - if (dynamic_cast<RooSimultaneous const *>(func)) { - // RooSimultaneous is not used in the HS3 standard, we only export the dependents + if (auto simPdf = dynamic_cast<RooSimultaneous const *>(func)) { + // RooSimultaneous is not used in the HS3 standard, we only export the + // dependents and some ROOT internal information. RooJSONFactoryWSTool::exportDependants(func); + + std::vector<std::string> channelNames; + for (auto const &item : simPdf->indexCat()) { + channelNames.push_back(item.first); + } + writeChannelNames(*_rootnodeOutput, simPdf->GetName(), channelNames); + return nullptr; } else if (dynamic_cast<RooAbsCategory const *>(func)) { // categories are created by the respective RooSimultaneous, so we're skipping the export here @@ -1425,18 +1459,16 @@ void RooJSONFactoryWSTool::importAllNodes(const RooFit::Detail::JSONNode &n) } } - _rootnodeInput = nullptr; - _domains.reset(); - // Now, read in analyses and likelihoods if there are any if (n.has_child("analyses")) { for (JSONNode const &analysisNode : n["analyses"].children()) { - importAnalysis(analysisNode, n["likelihoods"], - dereference(n, {"misc", "ROOT_internal", "ModelConfigs", analysisNode["name"].val()}, n), - _workspace, datas); + importAnalysis(*_rootnodeInput, analysisNode, n["likelihoods"], _workspace, datas); } } + _rootnodeInput = nullptr; + _domains.reset(); + for (auto const &d : datas) { if (d) _workspace.import(*d); diff --git a/tutorials/roofit/rf515_hfJSON.json b/tutorials/roofit/rf515_hfJSON.json index dc953901a47b36f544f998d5518d3acace0276ee..1509f76b961da6c370490ae2780b23e0013508a9 100644 --- a/tutorials/roofit/rf515_hfJSON.json +++ b/tutorials/roofit/rf515_hfJSON.json @@ -1,13 +1,18 @@ { - "misc" : { "ROOT_internal" : { "ModelConfigs" : { - "main": { - "combined_data_name": "observed" + "misc" : { "ROOT_internal" : { + "ModelConfigs" : { + "main": { + "combined_data_name": "observed" + } + }, + "channel_names" : { + "main": ["channel1"] } - }}}, + }}, "analyses": [ { "name": "main", - "likelihood": "channel1", + "likelihood": "main_likelihood", "pois": [ "mu" ], @@ -18,7 +23,7 @@ ], "likelihoods": [ { - "name": "channel1", + "name": "main_likelihood", "distributions": ["model_channel1"], "data": ["observed_channel1"] }