From 2824060d13bef0dee25070f61e10af7bfde4dbad Mon Sep 17 00:00:00 2001
From: Stefan Wunsch <stefan.wunsch@cern.ch>
Date: Tue, 2 Jul 2019 11:11:14 +0200
Subject: [PATCH] [TMVA experimental] Add new reader interface

---
 tmva/tmva/CMakeLists.txt               |   2 +
 tmva/tmva/inc/TMVA/RInferenceUtils.hxx |  39 ++++
 tmva/tmva/inc/TMVA/RReader.hxx         | 294 +++++++++++++++++++++++++++
 tmva/tmva/test/CMakeLists.txt          |   3 +
 tmva/tmva/test/rreader.cxx             | 268 +++++++++++++++++++++++++
 tutorials/tmva/tmva003_RReader.C       | 109 ++++++++++
 6 files changed, 715 insertions(+)
 create mode 100644 tmva/tmva/inc/TMVA/RInferenceUtils.hxx
 create mode 100644 tmva/tmva/inc/TMVA/RReader.hxx
 create mode 100644 tmva/tmva/test/rreader.cxx
 create mode 100644 tutorials/tmva/tmva003_RReader.C

diff --git a/tmva/tmva/CMakeLists.txt b/tmva/tmva/CMakeLists.txt
index 62007169475..7f42830acf0 100644
--- a/tmva/tmva/CMakeLists.txt
+++ b/tmva/tmva/CMakeLists.txt
@@ -25,6 +25,8 @@ if(dataframe)
         TMVA/RTensor.hxx
         TMVA/RTensorUtils.hxx
         TMVA/RStandardScaler.hxx
+        TMVA/RReader.hxx
+        TMVA/RInferenceUtils.hxx
         )
     set(TMVA_EXTRA_DEPENDENCIES
         ROOTDataFrame
diff --git a/tmva/tmva/inc/TMVA/RInferenceUtils.hxx b/tmva/tmva/inc/TMVA/RInferenceUtils.hxx
new file mode 100644
index 00000000000..bd4f1dce4e7
--- /dev/null
+++ b/tmva/tmva/inc/TMVA/RInferenceUtils.hxx
@@ -0,0 +1,39 @@
+#ifndef TMVA_RINFERENCEUTILS
+#define TMVA_RINFERENCEUTILS
+
+#include "ROOT/RIntegerSequence.hxx" // std::index_sequence
+#include <utility> // std::forward
+
+namespace TMVA {
+namespace Experimental {
+
+namespace Internal {
+
+/// Callable taking N arguments of type T and forwarding them as a braced list to F::Compute
+template <typename I, typename T, typename F>
+class ComputeHelper;
+
+template <std::size_t... N, typename T, typename F>
+class ComputeHelper<std::index_sequence<N...>, T, F> {
+   template <std::size_t Idx>
+   using AlwaysT = T;
+   F fFunc;
+
+public:
+   explicit ComputeHelper(F &&f) : fFunc(std::forward<F>(f)) {}
+   auto operator()(AlwaysT<N>... args) -> decltype(fFunc.Compute({args...})) { return fFunc.Compute({args...}); }
+};
+
+} // namespace Internal
+
+/// Helper to pass TMVA model to RDataFrame.Define nodes (an lvalue model is held by reference, keep it alive)
+template <std::size_t N, typename T, typename F>
+auto Compute(F &&f) -> Internal::ComputeHelper<std::make_index_sequence<N>, T, F>
+{
+   return Internal::ComputeHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f));
+}
+
+} // namespace Experimental
+} // namespace TMVA
+
+#endif // TMVA_RINFERENCEUTILS
diff --git a/tmva/tmva/inc/TMVA/RReader.hxx b/tmva/tmva/inc/TMVA/RReader.hxx
new file mode 100644
index 00000000000..474456ab3f9
--- /dev/null
+++ b/tmva/tmva/inc/TMVA/RReader.hxx
@@ -0,0 +1,294 @@
+#ifndef TMVA_RREADER
+#define TMVA_RREADER
+
+#include "TString.h"
+#include "TXMLEngine.h"
+#include "ROOT/RMakeUnique.hxx"
+
+#include "TMVA/RTensor.hxx"
+#include "TMVA/Reader.h"
+
+#include <cstdlib> // std::strtol
+#include <memory> // std::unique_ptr
+#include <sstream> // std::stringstream
+#include <stdexcept> // std::runtime_error
+#include <string> // std::string
+#include <vector> // std::vector
+
+namespace TMVA {
+namespace Experimental {
+
+namespace Internal {
+
+/// Internal definition of analysis types
+enum AnalysisType : unsigned int { Undefined = 0, Classification, Regression, Multiclass };
+
+/// Container for information extracted from TMVA XML config
+struct XMLConfig {
+   unsigned int numVariables;          ///< Value of the NVar attribute
+   std::vector<std::string> variables; ///< Variable titles, indexed by VarIndex
+   unsigned int numClasses;            ///< Value of the NClass attribute
+   std::vector<std::string> classes;   ///< Class names in document order
+   AnalysisType analysisType;          ///< Analysis type read from GeneralInfo
+   XMLConfig()
+      : numVariables(0), variables(std::vector<std::string>(0)), numClasses(0), classes(std::vector<std::string>(0)),
+        analysisType(Internal::AnalysisType::Undefined)
+   {
+   }
+};
+
+/// Parse a non-negative integer from the value of an XML attribute
+///
+/// \param[in] str Attribute value, may be a null pointer for a missing attribute
+/// \param[out] value Parsed number, left untouched on failure
+/// \return False if str is null, does not start with a number or is negative
+inline bool ParseUnsigned(const char *str, unsigned int &value)
+{
+   if (str == nullptr)
+      return false;
+   char *end = nullptr;
+   const long parsed = std::strtol(str, &end, 10);
+   if (end == str || parsed < 0)
+      return false;
+   value = static_cast<unsigned int>(parsed);
+   return true;
+}
+
+/// Parse TMVA XML config
+///
+/// \param[in] filename Path to the TMVA XML weight file
+/// \return Input variables, output classes and analysis type found in the file
+/// \throws std::runtime_error if the file cannot be parsed or if the variables,
+///         the classes or the analysis type are missing or malformed
+inline XMLConfig ParseXMLConfig(const std::string &filename)
+{
+   XMLConfig c;
+
+   // Parse XML file and find root node
+   TXMLEngine xml;
+   auto xmldoc = xml.ParseFile(filename.c_str());
+   if (xmldoc == nullptr) {
+      std::stringstream ss;
+      ss << "Failed to open TMVA XML file "
+         << filename << ".";
+      throw std::runtime_error(ss.str());
+   }
+
+   // Malformed entries are only flagged inside the loop and reported after the
+   // document has been freed, so that the validation itself never leaks it.
+   bool validVariables = true;
+   bool validClasses = true;
+   auto mainNode = xml.DocGetRootElement(xmldoc);
+   for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
+      const auto nodeName = std::string(xml.GetNodeName(node));
+      // Read out input variables
+      if (nodeName.compare("Variables") == 0) {
+         validVariables = ParseUnsigned(xml.GetAttr(node, "NVar"), c.numVariables);
+         c.variables = std::vector<std::string>(c.numVariables);
+         unsigned int numParsed = 0;
+         for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
+            // GetAttr returns a null pointer for a missing attribute and the index
+            // comes from the file, so check both before writing into the vector
+            unsigned int iVariable = 0;
+            const char *title = xml.GetAttr(thisNode, "Title");
+            if (!ParseUnsigned(xml.GetAttr(thisNode, "VarIndex"), iVariable) || iVariable >= c.numVariables ||
+                title == nullptr) {
+               validVariables = false;
+               break;
+            }
+            c.variables[iVariable] = title;
+            numParsed++;
+         }
+         // The vector is pre-sized, hence count the entries to detect missing variables
+         if (numParsed != c.numVariables)
+            validVariables = false;
+      }
+      // Read out output classes
+      else if (nodeName.compare("Classes") == 0) {
+         validClasses = ParseUnsigned(xml.GetAttr(node, "NClass"), c.numClasses);
+         for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
+            const char *className = xml.GetAttr(thisNode, "Name");
+            if (className == nullptr) {
+               validClasses = false;
+               break;
+            }
+            c.classes.push_back(className);
+         }
+      }
+      // Read out analysis type
+      else if (nodeName.compare("GeneralInfo") == 0) {
+         std::string analysisType = "";
+         for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
+            // Entries without name or value cannot describe the analysis type
+            const char *infoName = xml.GetAttr(thisNode, "name");
+            const char *infoValue = xml.GetAttr(thisNode, "value");
+            if (infoName != nullptr && infoValue != nullptr && std::string("AnalysisType").compare(infoName) == 0) {
+               analysisType = infoValue;
+            }
+         }
+         if (analysisType.compare("Classification") == 0) {
+            c.analysisType = Internal::AnalysisType::Classification;
+         } else if (analysisType.compare("Regression") == 0) {
+            c.analysisType = Internal::AnalysisType::Regression;
+         } else if (analysisType.compare("Multiclass") == 0) {
+            c.analysisType = Internal::AnalysisType::Multiclass;
+         }
+      }
+   }
+   xml.FreeDoc(xmldoc);
+
+   // Error-handling
+   if (!validVariables || c.numVariables != c.variables.size() || c.numVariables == 0) {
+      std::stringstream ss;
+      ss << "Failed to parse input variables from TMVA config " << filename << ".";
+      throw std::runtime_error(ss.str());
+   }
+   if (!validClasses || c.numClasses != c.classes.size() || c.numClasses == 0) {
+      std::stringstream ss;
+      ss << "Failed to parse output classes from TMVA config " << filename << ".";
+      throw std::runtime_error(ss.str());
+   }
+   if (c.analysisType == Internal::AnalysisType::Undefined) {
+      std::stringstream ss;
+      ss << "Failed to parse analysis type from TMVA config " << filename << ".";
+      throw std::runtime_error(ss.str());
+   }
+
+   return c;
+}
+
+} // namespace Internal
+
+/// TMVA::Reader legacy interface
+///
+/// Wraps a TMVA::Reader booked from an XML weight file. The reader evaluates the
+/// model on the internal buffer fValues, which is shared by all calls to Compute.
+class RReader {
+private:
+   std::unique_ptr<Reader> fReader;
+   std::vector<float> fValues; ///< Input buffer registered with the TMVA reader
+   std::vector<std::string> fVariables;
+   unsigned int fNumClasses;
+   const char *name = "RReader";
+   Internal::AnalysisType fAnalysisType;
+
+public:
+   /// Create TMVA model from XML file
+   /// \param[in] path Path to the TMVA XML weight file
+   /// \throws std::runtime_error if the config cannot be parsed
+   RReader(const std::string &path)
+   {
+      // Load config
+      auto c = Internal::ParseXMLConfig(path);
+      fVariables = c.variables;
+      fAnalysisType = c.analysisType;
+      fNumClasses = c.numClasses;
+
+      // Setup reader
+      fReader = std::make_unique<Reader>("Silent");
+      const auto numVars = fVariables.size();
+      fValues = std::vector<float>(numVars);
+      for (std::size_t i = 0; i < numVars; i++) {
+         fReader->AddVariable(TString(fVariables[i]), &fValues[i]);
+      }
+      fReader->BookMVA(name, path.c_str());
+   }
+
+   /// Compute model prediction on vector
+   /// \param[in] x Values of the input variables in the order of GetVariableNames()
+   /// \return A single value for classification, otherwise the output of
+   ///         Reader::EvaluateRegression or Reader::EvaluateMulticlass
+   /// \throws std::runtime_error if the size of x does not match the number of variables
+   std::vector<float> Compute(const std::vector<float> &x)
+   {
+      if (x.size() != fVariables.size())
+         throw std::runtime_error("Size of input vector is not equal to number of variables.");
+
+      // Take the lock before touching fValues: the buffer is shared by all callers and
+      // must not be overwritten by another thread between the copy and the evaluation
+      R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
+
+      // Copy over inputs to memory used by TMVA reader
+      for (std::size_t i = 0; i < x.size(); i++) {
+         fValues[i] = x[i];
+      }
+
+      // Evaluate TMVA model
+      // Classification
+      if (fAnalysisType == Internal::AnalysisType::Classification) {
+         return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
+      }
+      // Regression
+      else if (fAnalysisType == Internal::AnalysisType::Regression) {
+         return fReader->EvaluateRegression(name);
+      }
+      // Multiclass
+      else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
+         return fReader->EvaluateMulticlass(name);
+      }
+      // Throw error
+      else {
+         throw std::runtime_error("RReader has undefined analysis type.");
+      }
+   }
+
+   /// Compute model prediction on input RTensor
+   /// \param[in] x Input tensor of rank 2 with shape (number of entries, number of variables)
+   /// \return Tensor of shape (number of entries) for classification and regression,
+   ///         where only the first regression target is stored, and of shape
+   ///         (number of entries, number of classes) for multiclass models
+   /// \throws std::runtime_error if the rank or the second dimension of x does not fit
+   RTensor<float> Compute(RTensor<float> &x)
+   {
+      // Error-handling for input tensor
+      const auto shape = x.GetShape();
+      if (shape.size() != 2)
+         throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
+
+      const auto numEntries = shape[0];
+      const auto numVars = shape[1];
+      if (numVars != fVariables.size())
+         throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
+
+      // Define shape of output tensor based on analysis type
+      unsigned int numClasses = 1;
+      if (fAnalysisType == Internal::AnalysisType::Multiclass)
+         numClasses = fNumClasses;
+      RTensor<float> y({numEntries * numClasses});
+      if (fAnalysisType == Internal::AnalysisType::Multiclass)
+         y = y.Reshape({numEntries, numClasses});
+
+      // Fill output tensor
+      for (std::size_t i = 0; i < numEntries; i++) {
+         // Hold the lock while the shared input buffer is filled and evaluated
+         R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
+         for (std::size_t j = 0; j < numVars; j++) {
+            fValues[j] = x(i, j);
+         }
+         // Classification
+         if (fAnalysisType == Internal::AnalysisType::Classification) {
+            y(i) = fReader->EvaluateMVA(name);
+         }
+         // Regression
+         else if (fAnalysisType == Internal::AnalysisType::Regression) {
+            y(i) = fReader->EvaluateRegression(name)[0];
+         }
+         // Multiclass
+         else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
+            const auto p = fReader->EvaluateMulticlass(name);
+            for (std::size_t k = 0; k < numClasses; k++)
+               y(i, k) = p[k];
+         }
+      }
+
+      return y;
+   }
+
+   /// Get the names of the input variables in the order expected by Compute
+   std::vector<std::string> GetVariableNames() const { return fVariables; }
+};
+
+} // namespace Experimental
+} // namespace TMVA
+
+#endif // TMVA_RREADER
diff --git a/tmva/tmva/test/CMakeLists.txt b/tmva/tmva/test/CMakeLists.txt
index 20cd441d495..835ff5fb47b 100644
--- a/tmva/tmva/test/CMakeLists.txt
+++ b/tmva/tmva/test/CMakeLists.txt
@@ -18,11 +18,14 @@ 
ROOT_ADD_GTEST(TestRandomGenerator
                LIBRARIES ${Libraries})
 
 if(dataframe)
+    # RTensor
     ROOT_ADD_GTEST(rtensor rtensor.cxx LIBRARIES ROOTVecOps TMVA)
     ROOT_ADD_GTEST(rtensor-iterator rtensor_iterator.cxx LIBRARIES ROOTVecOps TMVA)
     ROOT_ADD_GTEST(rtensor-utils rtensor_utils.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
     # RStandardScaler
     ROOT_ADD_GTEST(rstandardscaler rstandardscaler.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
+    # RReader
+    ROOT_ADD_GTEST(rreader rreader.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
 endif()
 
 project(tmva-tests)
diff --git a/tmva/tmva/test/rreader.cxx b/tmva/tmva/test/rreader.cxx
new file mode 100644
index 00000000000..c90cd6014b7
--- /dev/null
+++ b/tmva/tmva/test/rreader.cxx
@@ -0,0 +1,268 @@
+#include <gtest/gtest.h>
+
+#include <TFile.h>
+#include <TTree.h>
+#include <TSystem.h>
+#include <TMVA/Factory.h>
+#include <TMVA/DataLoader.h>
+
+#include <TMVA/RReader.hxx>
+#include <TMVA/RInferenceUtils.hxx>
+#include <TMVA/RTensor.hxx>
+#include <TMVA/RTensorUtils.hxx>
+
+using namespace TMVA::Experimental;
+
+// Classification: weight file written by the training, input file and input variables
+static const std::string modelClassification = "RReaderClassification/weights/RReaderClassification_BDT.weights.xml";
+static const std::string filenameClassification = "http://root.cern.ch/files/tmva_class_example.root";
+static const std::vector<std::string> variablesClassification = {"var1", "var2", "var3", "var4"};
+
+void TrainClassificationModel()
+{
+   // Skip the training if the output directory cannot be created, e.g. because it already exists
+   if (gSystem->mkdir("RReaderClassification") == -1) return;
+
+   // Create factory
+   auto output = TFile::Open("TMVA.root", "RECREATE");
+   auto factory = new TMVA::Factory("RReaderClassification",
+           output, "Silent:!V:!DrawProgressBar:AnalysisType=Classification");
+
+   // Open trees with signal and background events
+   auto data = TFile::Open(filenameClassification.c_str());
+   ASSERT_TRUE(data != nullptr) << "Failed to open " << filenameClassification;
+   auto signal = static_cast<TTree *>(data->Get("TreeS"));
+   auto background = static_cast<TTree *>(data->Get("TreeB"));
+
+   // Add variables and register the trees with the dataloader
+   auto dataloader = new TMVA::DataLoader("RReaderClassification");
+   // Use the variable list shared with the tests so that both cannot diverge
+   for (const auto &var : variablesClassification) {
+      dataloader->AddVariable(var);
+   }
+   dataloader->AddSignalTree(signal, 1.0);
+   dataloader->AddBackgroundTree(background, 1.0);
+   dataloader->PrepareTrainingAndTestTree("", "");
+
+   // Train a TMVA method
+   factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=100:MaxDepth=2");
+   factory->TrainAllMethods();
+   output->Close();
+}
+
+// Regression: weight file written by the training, input file and input variables
+static const std::string modelRegression = "RReaderRegression/weights/RReaderRegression_BDTG.weights.xml";
+static const std::string filenameRegression = "http://root.cern.ch/files/tmva_reg_example.root";
+static const std::vector<std::string> variablesRegression = {"var1", "var2"};
+
+void TrainRegressionModel()
+{
+   // Skip the training if the output directory cannot be created, e.g. because it already exists
+   if (gSystem->mkdir("RReaderRegression") == -1) return;
+
+   // Create factory
+   auto output = TFile::Open("TMVA.root", "RECREATE");
+   auto factory = new TMVA::Factory("RReaderRegression",
+           output, "Silent:!V:!DrawProgressBar:AnalysisType=Regression");
+
+   // Open tree with the regression events
+   auto data = TFile::Open(filenameRegression.c_str());
+   ASSERT_TRUE(data != nullptr) << "Failed to open " << filenameRegression;
+   auto tree = static_cast<TTree *>(data->Get("TreeR"));
+
+   // Add variables and target and register the tree with the dataloader
+   auto dataloader = new TMVA::DataLoader("RReaderRegression");
+   for (const auto &var : variablesRegression)
+      dataloader->AddVariable(var);
+   dataloader->AddTarget("fvalue");
+   dataloader->AddRegressionTree(tree, 1.0);
+   dataloader->PrepareTrainingAndTestTree("", "");
+
+   // Train a TMVA method
+   factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG", "!V:!H:NTrees=100:MaxDepth=2");
+   factory->TrainAllMethods();
+   output->Close();
+}
+
+// Multiclass: weight file written by the training, input file and input variables
+static const std::string modelMulticlass = "RReaderMulticlass/weights/RReaderMulticlass_BDT.weights.xml";
+static const std::string filenameMulticlass = "http://root.cern.ch/files/tmva_multiclass_example.root";
+static const std::vector<std::string> variablesMulticlass = variablesClassification;
+
+void TrainMulticlassModel()
+{
+   // Skip the training if the output directory cannot be created, e.g. because it already exists
+   if (gSystem->mkdir("RReaderMulticlass") == -1) return;
+
+   // Create factory
+   auto output = TFile::Open("TMVA.root", "RECREATE");
+   auto factory = new TMVA::Factory("RReaderMulticlass",
+           output, "Silent:!V:!DrawProgressBar:AnalysisType=Multiclass");
+
+   // Open trees with signal and background events
+   auto data = TFile::Open(filenameMulticlass.c_str());
+   ASSERT_TRUE(data != nullptr) << "Failed to open " << filenameMulticlass;
+   auto signal = static_cast<TTree *>(data->Get("TreeS"));
+   auto background0 = static_cast<TTree *>(data->Get("TreeB0"));
+   auto background1 = static_cast<TTree *>(data->Get("TreeB1"));
+   auto background2 = static_cast<TTree *>(data->Get("TreeB2"));
+
+   // Add variables and register the trees with the dataloader
+   auto dataloader = new TMVA::DataLoader("RReaderMulticlass");
+   // Use the variable list shared with the tests so that both cannot diverge
+   for (const auto &var : variablesMulticlass) {
+      dataloader->AddVariable(var);
+   }
+   dataloader->AddTree(signal, "Signal");
+   dataloader->AddTree(background0, "Background_0");
+   dataloader->AddTree(background1, "Background_1");
+   dataloader->AddTree(background2, "Background_2");
+   dataloader->PrepareTrainingAndTestTree("", "");
+
+   // Train a TMVA method
+   factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=100:MaxDepth=2:BoostType=Grad");
+   factory->TrainAllMethods();
+   output->Close();
+}
+
+TEST(RReader, ClassificationGetVariables)
+{
+   TrainClassificationModel();
+   RReader model(modelClassification);
+   auto vars = model.GetVariableNames();
+   ASSERT_EQ(vars.size(), 4ul); // fatal: guards the indexing of the reference list below
+   for (std::size_t i = 0; i < vars.size(); i++) {
+      EXPECT_EQ(vars[i], variablesClassification[i]);
+   }
+}
+
+TEST(RReader, ClassificationComputeVector)
+{
+   TrainClassificationModel();
+   const std::vector<float> x = {1.0, 2.0, 3.0, 4.0};
+   RReader model(modelClassification);
+   auto y = model.Compute(x);
+   EXPECT_EQ(y.size(), 1ul);
+}
+
+TEST(RReader, ClassificationComputeTensor)
+{
+   TrainClassificationModel();
+   ROOT::RDataFrame df("TreeS", filenameClassification);
+   auto x = AsTensor<float>(df, variablesClassification);
+
+   RReader model(modelClassification);
+   auto y = model.Compute(x);
+
+   const auto shapeX = x.GetShape();
+   const auto shapeY = y.GetShape();
+   ASSERT_EQ(shapeY.size(), 1ul); // fatal: guards the indexing of the shape below
+   EXPECT_EQ(shapeY[0], shapeX[0]);
+}
+
+TEST(RReader, ClassificationComputeDataFrame)
+{
+   TrainClassificationModel();
+   ROOT::RDataFrame df("TreeS", filenameClassification);
+   RReader model(modelClassification);
+   auto df2 = df.Define("y", Compute<4, float>(model), variablesClassification);
+   auto df3 = df2.Filter("y.size() == 1");
+   auto c = df3.Count();
+   auto y = df2.Take<std::vector<float>>("y");
+   EXPECT_EQ(y->size(), *c);
+}
+
+TEST(RReader, RegressionGetVariables)
+{
+   TrainRegressionModel();
+   RReader model(modelRegression);
+   auto vars = model.GetVariableNames();
+   ASSERT_EQ(vars.size(), 2ul); // fatal: guards the indexing of the reference list below
+   for (std::size_t i = 0; i < vars.size(); i++) {
+      EXPECT_EQ(vars[i], variablesRegression[i]);
+   }
+}
+
+TEST(RReader, RegressionComputeVector)
+{
+   TrainRegressionModel();
+   const std::vector<float> x = {1.0, 2.0};
+   RReader model(modelRegression);
+   auto y = model.Compute(x);
+   EXPECT_EQ(y.size(), 1ul);
+}
+
+TEST(RReader, RegressionComputeTensor)
+{
+   TrainRegressionModel();
+   ROOT::RDataFrame df("TreeR", filenameRegression);
+   auto x = AsTensor<float>(df, variablesRegression);
+
+   RReader model(modelRegression);
+   auto y = model.Compute(x);
+
+   const auto shapeX = x.GetShape();
+   const auto shapeY = y.GetShape();
+   ASSERT_EQ(shapeY.size(), 1ul); // fatal: guards the indexing of the shape below
+   EXPECT_EQ(shapeY[0], shapeX[0]);
+}
+
+TEST(RReader, RegressionComputeDataFrame)
+{
+   TrainRegressionModel();
+   ROOT::RDataFrame df("TreeR", filenameRegression);
+   RReader model(modelRegression);
+   auto df2 = df.Define("y", Compute<2, float>(model), variablesRegression);
+   auto df3 = df2.Filter("y.size() == 1");
+   auto c = df3.Count();
+   auto y = df2.Take<std::vector<float>>("y");
+   EXPECT_EQ(y->size(), *c);
+}
+
+TEST(RReader, MulticlassGetVariables)
+{
+   TrainMulticlassModel();
+   RReader model(modelMulticlass);
+   auto vars = model.GetVariableNames();
+   ASSERT_EQ(vars.size(), 4ul); // fatal: guards the indexing of the reference list below
+   for (std::size_t i = 0; i < vars.size(); i++) {
+      EXPECT_EQ(vars[i], variablesMulticlass[i]);
+   }
+}
+
+TEST(RReader, MulticlassComputeVector)
+{
+   TrainMulticlassModel();
+   const std::vector<float> x = {1.0, 2.0, 3.0, 4.0};
+   RReader model(modelMulticlass);
+   auto y = model.Compute(x);
+   EXPECT_EQ(y.size(), 4ul);
+}
+
+TEST(RReader, MulticlassComputeTensor)
+{
+   TrainMulticlassModel();
+   ROOT::RDataFrame df("TreeS", filenameMulticlass);
+   auto x = AsTensor<float>(df, variablesMulticlass);
+
+   RReader model(modelMulticlass);
+   auto y = model.Compute(x);
+
+   const auto shapeX = x.GetShape();
+   const auto shapeY = y.GetShape();
+   ASSERT_EQ(shapeY.size(), 2ul); // fatal: guards the indexing of the shape below
+   EXPECT_EQ(shapeY[0], shapeX[0]);
+   EXPECT_EQ(shapeY[1], 4ul);
+}
+
+TEST(RReader, MulticlassComputeDataFrame)
+{
+   TrainMulticlassModel();
+   ROOT::RDataFrame df("TreeS", filenameMulticlass);
+   RReader model(modelMulticlass);
+   auto df2 = df.Define("y", Compute<4, float>(model), variablesMulticlass);
+   auto df3 = df2.Filter("y.size() == 4");
+   auto c = df3.Count();
+   auto y = df2.Take<std::vector<float>>("y");
+   EXPECT_EQ(y->size(), *c);
+}
diff --git a/tutorials/tmva/tmva003_RReader.C b/tutorials/tmva/tmva003_RReader.C
new file mode 100644
index 00000000000..2eb9a7c3f86
--- /dev/null
+++ b/tutorials/tmva/tmva003_RReader.C
@@ -0,0 +1,109 @@
+/// \file
+/// \ingroup tutorial_tmva
+/// \notebook -nodraw
+/// This tutorial shows how to apply models saved in TMVA XML files with the
+/// modern interfaces.
+///
+/// \macro_code
+/// \macro_output
+///
+/// \date July 2019
+/// \author Stefan Wunsch
+
+using namespace TMVA::Experimental;
+
+// Train a BDT on the example dataset; the macro below loads the resulting weights from tmva003_BDT/weights
+void train(const std::string &filename)
+{
+   // Create factory
+   auto output = TFile::Open("TMVA.root", "RECREATE");
+   auto factory = new TMVA::Factory("tmva003", output, "!V:!DrawProgressBar:AnalysisType=Classification");
+
+   // Open trees with signal and background events
+   auto data = TFile::Open(filename.c_str());
+   auto signal = static_cast<TTree *>(data->Get("TreeS"));
+   auto background = static_cast<TTree *>(data->Get("TreeB"));
+
+   // Add variables and register the trees with the dataloader
+   auto dataloader = new TMVA::DataLoader("tmva003_BDT");
+   const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
+   for (const auto &var : variables) {
+      dataloader->AddVariable(var);
+   }
+   dataloader->AddSignalTree(signal, 1.0);
+   dataloader->AddBackgroundTree(background, 1.0);
+   dataloader->PrepareTrainingAndTestTree("", "");
+
+   // Train a TMVA method
+   factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=300:MaxDepth=2");
+   factory->TrainAllMethods();
+   output->Close(); // release the training output file once the training is done
+}
+
+void tmva003_RReader()
+{
+   // First, let's train a model with TMVA.
+   const std::string filename = "http://root.cern.ch/files/tmva_class_example.root";
+   train(filename);
+
+   // Next, we load the model from the TMVA XML file.
+   RReader model("tmva003_BDT/weights/tmva003_BDT.weights.xml");
+
+   // In case you need a reminder of the names and the order of the variables used
+   // during the training, you can ask the model for them.
+   auto variables = model.GetVariableNames();
+
+   // The model can now be applied in different scenarios:
+   // 1) Event-by-event inference
+   // 2) Batch inference on data of multiple events
+   // 3) Inference as part of an RDataFrame graph
+
+   // 1) Event-by-event inference
+   // The event-by-event inference takes the values of the variables as a std::vector<float>.
+   // Note that the return value is a std::vector<float> as well, since the reader
+   // can also process models with multiple outputs.
+   auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});
+   std::cout << "Single-event inference: " << prediction[0] << "\n\n";
+
+   // 2) Batch inference on data of multiple events
+   // For batch inference, the data needs to be structured as a matrix. For this
+   // purpose, TMVA makes use of the RTensor class. For convenience, we use RDataFrame
+   // and the AsTensor utility to read the data from the ROOT file.
+   ROOT::RDataFrame df("TreeS", filename);
+   auto df2 = df.Range(3); // Read only a small subset of the dataset
+   auto x = AsTensor<float>(df2, variables);
+   auto y = model.Compute(x);
+
+   std::cout << "RTensor input for inference on data of multiple events:\n" << x << "\n\n";
+   std::cout << "Prediction performed on multiple events: " << y << "\n\n";
+
+   // 3) Perform inference as part of an RDataFrame graph
+   // We write a small lambda function that performs the inference on a dataframe
+   // for us to avoid code duplication.
+   auto make_histo = [&](const std::string &treename) {
+      ROOT::RDataFrame df(treename, filename);
+      auto df2 = df.Define("y", Compute<4, float>(model), variables);
+      return df2.Histo1D({treename.c_str(), ";BDT score;N_{Events}", 30, -0.5, 0.5}, "y");
+   };
+
+   auto sig = make_histo("TreeS");
+   auto bkg = make_histo("TreeB");
+
+   // Make plot
+   gStyle->SetOptStat(0);
+   auto c = new TCanvas("", "", 800, 800);
+   sig->SetLineColor(kRed);
+   bkg->SetLineColor(kBlue);
+   sig->SetLineWidth(2);
+   bkg->SetLineWidth(2);
+   bkg->Draw("HIST");
+   sig->Draw("HIST SAME");
+
+   TLegend legend(0.7, 0.7, 0.89, 0.89);
+   legend.SetBorderSize(0);
+   legend.AddEntry("TreeS", "Signal", "l");
+   legend.AddEntry("TreeB", "Background", "l");
+   legend.Draw();
+
+   c->DrawClone();
+}
-- 
GitLab