From 2824060d13bef0dee25070f61e10af7bfde4dbad Mon Sep 17 00:00:00 2001
From: Stefan Wunsch <stefan.wunsch@cern.ch>
Date: Tue, 2 Jul 2019 11:11:14 +0200
Subject: [PATCH] [TMVA experimental] Add new reader interface

---
 tmva/tmva/CMakeLists.txt               |   2 +
 tmva/tmva/inc/TMVA/RInferenceUtils.hxx |  39 ++++
 tmva/tmva/inc/TMVA/RReader.hxx         | 226 +++++++++++++++++++++
 tmva/tmva/test/CMakeLists.txt          |   3 +
 tmva/tmva/test/rreader.cxx             | 268 +++++++++++++++++++++++++
 tutorials/tmva/tmva003_RReader.C       | 109 ++++++++++
 6 files changed, 647 insertions(+)
 create mode 100644 tmva/tmva/inc/TMVA/RInferenceUtils.hxx
 create mode 100644 tmva/tmva/inc/TMVA/RReader.hxx
 create mode 100644 tmva/tmva/test/rreader.cxx
 create mode 100644 tutorials/tmva/tmva003_RReader.C

diff --git a/tmva/tmva/CMakeLists.txt b/tmva/tmva/CMakeLists.txt
index 62007169475..7f42830acf0 100644
--- a/tmva/tmva/CMakeLists.txt
+++ b/tmva/tmva/CMakeLists.txt
@@ -25,6 +25,8 @@ if(dataframe)
         TMVA/RTensor.hxx
         TMVA/RTensorUtils.hxx
         TMVA/RStandardScaler.hxx
+        TMVA/RReader.hxx
+        TMVA/RInferenceUtils.hxx
     )
     set(TMVA_EXTRA_DEPENDENCIES
         ROOTDataFrame
diff --git a/tmva/tmva/inc/TMVA/RInferenceUtils.hxx b/tmva/tmva/inc/TMVA/RInferenceUtils.hxx
new file mode 100644
index 00000000000..bd4f1dce4e7
--- /dev/null
+++ b/tmva/tmva/inc/TMVA/RInferenceUtils.hxx
@@ -0,0 +1,39 @@
+#ifndef TMVA_RINFERENCEUTILS
+#define TMVA_RINFERENCEUTILS
+
+#include "ROOT/RIntegerSequence.hxx" // std::index_sequence
+#include <utility> // std::forward
+
+namespace TMVA {
+namespace Experimental {
+
+namespace Internal {
+
+/// Compute helper
+template <typename I, typename T, typename F>
+class ComputeHelper;
+
+template <std::size_t... N, typename T, typename F>
+class ComputeHelper<std::index_sequence<N...>, T, F> {
+   template <std::size_t Idx>
+   using AlwaysT = T;
+   F fFunc;
+
+public:
+   ComputeHelper(F &&f) : fFunc(std::forward<F>(f)) {}
+   auto operator()(AlwaysT<N>... args) -> decltype(fFunc.Compute({args...})) { return fFunc.Compute({args...}); }
+};
+
+} // namespace Internal
+
+/// Helper to pass TMVA model to RDataFrame.Define nodes
+template <std::size_t N, typename T, typename F>
+auto Compute(F &&f) -> Internal::ComputeHelper<std::make_index_sequence<N>, T, F>
+{
+   return Internal::ComputeHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f));
+}
+
+} // namespace Experimental
+} // namespace TMVA
+
+#endif // TMVA_RINFERENCEUTILS
diff --git a/tmva/tmva/inc/TMVA/RReader.hxx b/tmva/tmva/inc/TMVA/RReader.hxx
new file mode 100644
index 00000000000..474456ab3f9
--- /dev/null
+++ b/tmva/tmva/inc/TMVA/RReader.hxx
@@ -0,0 +1,226 @@
+#ifndef TMVA_RREADER
+#define TMVA_RREADER
+
+#include "TString.h"
+#include "TXMLEngine.h"
+#include "ROOT/RMakeUnique.hxx"
+
+#include "TMVA/RTensor.hxx"
+#include "TMVA/Reader.h"
+
+#include <memory> // std::unique_ptr
+#include <sstream> // std::stringstream
+
+namespace TMVA {
+namespace Experimental {
+
+namespace Internal {
+
+/// Internal definition of analysis types
+enum AnalysisType : unsigned int { Undefined = 0, Classification, Regression, Multiclass };
+
+/// Container for information extracted from TMVA XML config
+struct XMLConfig {
+   unsigned int numVariables;
+   std::vector<std::string> variables;
+   unsigned int numClasses;
+   std::vector<std::string> classes;
+   AnalysisType analysisType;
+   XMLConfig()
+      : numVariables(0), variables(std::vector<std::string>(0)), numClasses(0), classes(std::vector<std::string>(0)),
+        analysisType(Internal::AnalysisType::Undefined)
+   {
+   }
+};
+
+/// Parse TMVA XML config
+inline XMLConfig ParseXMLConfig(const std::string &filename)
+{
+   XMLConfig c;
+
+   // Parse XML file and find root node
+   TXMLEngine xml;
+   auto xmldoc = xml.ParseFile(filename.c_str());
+   if (xmldoc == 0) {
+      std::stringstream ss;
+      ss << "Failed to open TMVA XML file "
+         << filename << ".";
+      throw std::runtime_error(ss.str());
+   }
+   auto mainNode = xml.DocGetRootElement(xmldoc);
+   for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
+      const auto nodeName = std::string(xml.GetNodeName(node));
+      // Read out input variables
+      if (nodeName.compare("Variables") == 0) {
+         c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
+         c.variables = std::vector<std::string>(c.numVariables);
+         for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
+            const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
+            c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
+         }
+      }
+      // Read out output classes
+      else if (nodeName.compare("Classes") == 0) {
+         c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
+         for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
+            c.classes.push_back(xml.GetAttr(thisNode, "Name"));
+         }
+      }
+      // Read out analysis type
+      else if (nodeName.compare("GeneralInfo") == 0) {
+         std::string analysisType = "";
+         for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
+            if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
+               analysisType = xml.GetAttr(thisNode, "value");
+            }
+         }
+         if (analysisType.compare("Classification") == 0) {
+            c.analysisType = Internal::AnalysisType::Classification;
+         } else if (analysisType.compare("Regression") == 0) {
+            c.analysisType = Internal::AnalysisType::Regression;
+         } else if (analysisType.compare("Multiclass") == 0) {
+            c.analysisType = Internal::AnalysisType::Multiclass;
+         }
+      }
+   }
+   xml.FreeDoc(xmldoc);
+
+   // Error-handling
+   if (c.numVariables != c.variables.size() || c.numVariables == 0) {
+      std::stringstream ss;
+      ss << "Failed to parse input variables from TMVA config " << filename << ".";
+      throw std::runtime_error(ss.str());
+   }
+   if (c.numClasses != c.classes.size() || c.numClasses == 0) {
+      std::stringstream ss;
+      ss << "Failed to parse output classes from TMVA config " << filename << ".";
+      throw std::runtime_error(ss.str());
+   }
+   if (c.analysisType == Internal::AnalysisType::Undefined) {
+      std::stringstream ss;
+      ss << "Failed to parse analysis type from TMVA config " << filename << ".";
+      throw std::runtime_error(ss.str());
+   }
+
+   return c;
+}
+
+} // namespace Internal
+
+/// TMVA::Reader legacy interface
+class RReader {
+private:
+   std::unique_ptr<Reader> fReader;
+   std::vector<float> fValues;
+   std::vector<std::string> fVariables;
+   unsigned int fNumClasses;
+   const char *name = "RReader";
+   Internal::AnalysisType fAnalysisType;
+
+public:
+   /// Create TMVA model from XML file
+   RReader(const std::string &path)
+   {
+      // Load config
+      auto c = Internal::ParseXMLConfig(path);
+      fVariables = c.variables;
+      fAnalysisType = c.analysisType;
+      fNumClasses = c.numClasses;
+
+      // Setup reader
+      fReader = std::make_unique<Reader>("Silent");
+      const auto numVars = fVariables.size();
+      fValues = std::vector<float>(numVars);
+      for (std::size_t i = 0; i < numVars; i++) {
+         fReader->AddVariable(TString(fVariables[i]), &fValues[i]);
+      }
+      fReader->BookMVA(name, path.c_str());
+   }
+
+   /// Compute model prediction on vector
+   std::vector<float> Compute(const std::vector<float> &x)
+   {
+      if (x.size() != fVariables.size())
+         throw std::runtime_error("Size of input vector is not equal to number of variables.");
+
+      // Copy over inputs to memory used by TMVA reader
+      for (std::size_t i = 0; i < x.size(); i++) {
+         fValues[i] = x[i];
+      }
+
+      // Take lock to protect model evaluation
+      R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
+
+      // Evaluate TMVA model
+      // Classification
+      if (fAnalysisType == Internal::AnalysisType::Classification) {
+         return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
+      }
+      // Regression
+      else if (fAnalysisType == Internal::AnalysisType::Regression) {
+         return fReader->EvaluateRegression(name);
+      }
+      // Multiclass
+      else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
+         return fReader->EvaluateMulticlass(name);
+      }
+      // Throw error
+      else {
+         throw std::runtime_error("RReader has undefined analysis type.");
+         return std::vector<float>();
+      }
+   }
+
+   /// Compute model prediction on input RTensor
+   RTensor<float> Compute(RTensor<float> &x)
+   {
+      // Error-handling for input tensor
+      const auto shape = x.GetShape();
+      if (shape.size() != 2)
+         throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
+
+      const auto numEntries = shape[0];
+      const auto numVars = shape[1];
+      if (numVars != fVariables.size())
+         throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
+
+      // Define shape of output tensor based on analysis type
+      unsigned int numClasses = 1;
+      if (fAnalysisType == Internal::AnalysisType::Multiclass)
+         numClasses = fNumClasses;
+      RTensor<float> y({numEntries * numClasses});
+      if (fAnalysisType == Internal::AnalysisType::Multiclass)
+         y = y.Reshape({numEntries, numClasses});
+
+      // Fill output tensor
+      for (std::size_t i = 0; i < numEntries; i++) {
+         for (std::size_t j = 0; j < numVars; j++) {
+            fValues[j] = x(i, j);
+         }
+         R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
+         // Classification
+         if (fAnalysisType == Internal::AnalysisType::Classification) {
+            y(i) = fReader->EvaluateMVA(name);
+         }
+         // Regression
+         else if (fAnalysisType == Internal::AnalysisType::Regression) {
+            y(i) = fReader->EvaluateRegression(name)[0];
+         }
+         // Multiclass
+         else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
+            const auto p = fReader->EvaluateMulticlass(name);
+            for (std::size_t k = 0; k < numClasses; k++)
+               y(i, k) = p[k];
+         }
+      }
+
+      return y;
+   }
+
+   std::vector<std::string> GetVariableNames() { return fVariables; }
+};
+
+} // namespace Experimental
+} // namespace TMVA
+
+#endif // TMVA_RREADER
diff --git a/tmva/tmva/test/CMakeLists.txt b/tmva/tmva/test/CMakeLists.txt
index 20cd441d495..835ff5fb47b 100644
--- a/tmva/tmva/test/CMakeLists.txt
+++ b/tmva/tmva/test/CMakeLists.txt
@@ -18,11 +18,14 @@ ROOT_ADD_GTEST(TestRandomGenerator
                LIBRARIES ${Libraries})
 
 if(dataframe)
+    # RTensor
     ROOT_ADD_GTEST(rtensor rtensor.cxx LIBRARIES ROOTVecOps TMVA)
     ROOT_ADD_GTEST(rtensor-iterator rtensor_iterator.cxx LIBRARIES ROOTVecOps TMVA)
     ROOT_ADD_GTEST(rtensor-utils rtensor_utils.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
     # RStandardScaler
     ROOT_ADD_GTEST(rstandardscaler rstandardscaler.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
+    # RReader
+    ROOT_ADD_GTEST(rreader rreader.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
 endif()
 
 project(tmva-tests)
diff --git a/tmva/tmva/test/rreader.cxx b/tmva/tmva/test/rreader.cxx
new file mode 100644
index 00000000000..c90cd6014b7
--- /dev/null
+++ b/tmva/tmva/test/rreader.cxx
@@ -0,0 +1,268 @@
+#include <gtest/gtest.h>
+
+#include <TFile.h>
+#include <TTree.h>
+#include <TSystem.h>
+#include <TMVA/Factory.h>
+#include <TMVA/DataLoader.h>
+
+#include <TMVA/RReader.hxx>
+#include <TMVA/RInferenceUtils.hxx>
+#include <TMVA/RTensor.hxx>
+#include <TMVA/RTensorUtils.hxx>
+
+using namespace TMVA::Experimental;
+
+// Classification
+static const std::string modelClassification = "RReaderClassification/weights/RReaderClassification_BDT.weights.xml";
+static const std::string filenameClassification = "http://root.cern.ch/files/tmva_class_example.root";
+static const std::vector<std::string> variablesClassification = {"var1", "var2", "var3", "var4"};
+
+void TrainClassificationModel()
+{
+   // Check for existing training
+   if (gSystem->mkdir("RReaderClassification") == -1) return;
+
+   // Create factory
+   auto output = TFile::Open("TMVA.root", "RECREATE");
+   auto factory = new TMVA::Factory("RReaderClassification",
+           output, "Silent:!V:!DrawProgressBar:AnalysisType=Classification");
+
+   // Open trees with signal and background events
+   const std::string filename = "http://root.cern.ch/files/tmva_class_example.root";
+   auto data = TFile::Open(filename.c_str());
+   auto signal = (TTree *)data->Get("TreeS");
+   auto background = (TTree *)data->Get("TreeB");
+
+   // Add variables and register the trees with the dataloader
+   auto dataloader = new TMVA::DataLoader("RReaderClassification");
+   const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
+   for (const auto &var : variables) {
+      dataloader->AddVariable(var);
+   }
+   dataloader->AddSignalTree(signal, 1.0);
+   dataloader->AddBackgroundTree(background, 1.0);
+   dataloader->PrepareTrainingAndTestTree("", "");
+
+   // Train a TMVA method
+   factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=100:MaxDepth=2");
+   factory->TrainAllMethods();
+   output->Close();
+}
+
+// Regression
+static const std::string modelRegression = "RReaderRegression/weights/RReaderRegression_BDTG.weights.xml";
+static const std::string filenameRegression = "http://root.cern.ch/files/tmva_reg_example.root";
+static const std::vector<std::string> variablesRegression = {"var1", "var2"};
+
+void TrainRegressionModel()
+{
+   // Check for existing training
+   if (gSystem->mkdir("RReaderRegression") == -1) return;
+
+   // Create factory
+   auto output = TFile::Open("TMVA.root", "RECREATE");
+   auto factory = new TMVA::Factory("RReaderRegression",
+           output, "Silent:!V:!DrawProgressBar:AnalysisType=Regression");
+
+   // Open trees with signal and background events
+   const std::string filename = "http://root.cern.ch/files/tmva_reg_example.root";
+   auto data = TFile::Open(filename.c_str());
+   auto tree = (TTree *)data->Get("TreeR");
+
+   // Add variables and register the trees with the dataloader
+   auto dataloader = new TMVA::DataLoader("RReaderRegression");
+   dataloader->AddVariable("var1");
+   dataloader->AddVariable("var2");
+   dataloader->AddTarget("fvalue");
+   dataloader->AddRegressionTree(tree, 1.0);
+   dataloader->PrepareTrainingAndTestTree("", "");
+
+   // Train a TMVA method
+   factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG", "!V:!H:NTrees=100:MaxDepth=2");
+   factory->TrainAllMethods();
+   output->Close();
+}
+
+// Multiclass
+static const std::string modelMulticlass = "RReaderMulticlass/weights/RReaderMulticlass_BDT.weights.xml";
+static const std::string filenameMulticlass = "http://root.cern.ch/files/tmva_multiclass_example.root";
+static const std::vector<std::string> variablesMulticlass = variablesClassification;
+
+void TrainMulticlassModel()
+{
+   // Check for existing training
+   if (gSystem->mkdir("RReaderMulticlass") == -1) return;
+
+   // Create factory
+   auto output = TFile::Open("TMVA.root", "RECREATE");
+   auto factory = new TMVA::Factory("RReaderMulticlass",
+           output, "Silent:!V:!DrawProgressBar:AnalysisType=Multiclass");
+
+   // Open trees with signal and background events
+   const std::string filename = "http://root.cern.ch/files/tmva_multiclass_example.root";
+   auto data = TFile::Open(filename.c_str());
+   auto signal = (TTree *)data->Get("TreeS");
+   auto background0 = (TTree *)data->Get("TreeB0");
+   auto background1 = (TTree *)data->Get("TreeB1");
+   auto background2 = (TTree *)data->Get("TreeB2");
+
+   // Add variables and register the trees with the dataloader
+   auto dataloader = new TMVA::DataLoader("RReaderMulticlass");
+   const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
+   for (const auto &var : variables) {
+      dataloader->AddVariable(var);
+   }
+   dataloader->AddTree(signal, "Signal");
+   dataloader->AddTree(background0, "Background_0");
+   dataloader->AddTree(background1, "Background_1");
+   dataloader->AddTree(background2, "Background_2");
+   dataloader->PrepareTrainingAndTestTree("", "");
+
+   // Train a TMVA method
+   factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=100:MaxDepth=2:BoostType=Grad");
+   factory->TrainAllMethods();
+   output->Close();
+}
+
+TEST(RReader, ClassificationGetVariables)
+{
+   TrainClassificationModel();
+   RReader model(modelClassification);
+   auto vars = model.GetVariableNames();
+   EXPECT_EQ(vars.size(), 4ul);
+   for (std::size_t i = 0; i < vars.size(); i++) {
+      EXPECT_EQ(vars[i], variablesClassification[i]);
+   }
+}
+
+TEST(RReader, ClassificationComputeVector)
+{
+   TrainClassificationModel();
+   const std::vector<float> x = {1.0, 2.0, 3.0, 4.0};
+   RReader model(modelClassification);
+   auto y = model.Compute(x);
+   EXPECT_EQ(y.size(), 1ul);
+}
+
+TEST(RReader, ClassificationComputeTensor)
+{
+   TrainClassificationModel();
+   ROOT::RDataFrame df("TreeS", filenameClassification);
+   auto x = AsTensor<float>(df, variablesClassification);
+
+   RReader model(modelClassification);
+   auto y = model.Compute(x);
+
+   const auto shapeX = x.GetShape();
+   const auto shapeY = y.GetShape();
+   EXPECT_EQ(shapeY.size(), 1ul);
+   EXPECT_EQ(shapeY[0], shapeX[0]);
+}
+
+TEST(RReader, ClassificationComputeDataFrame)
+{
+   TrainClassificationModel();
+   ROOT::RDataFrame df("TreeS", filenameClassification);
+   RReader model(modelClassification);
+   auto df2 = df.Define("y", Compute<4, float>(model), variablesClassification);
+   auto df3 = df2.Filter("y.size() == 1");
+   auto c = df3.Count();
+   auto y = df2.Take<std::vector<float>>("y");
+   EXPECT_EQ(y->size(), *c);
+}
+
+TEST(RReader, RegressionGetVariables)
+{
+   TrainRegressionModel();
+   RReader model(modelRegression);
+   auto vars = model.GetVariableNames();
+   EXPECT_EQ(vars.size(), 2ul);
+   for (std::size_t i = 0; i < vars.size(); i++) {
+      EXPECT_EQ(vars[i], variablesRegression[i]);
+   }
+}
+
+TEST(RReader, RegressionComputeVector)
+{
+   TrainRegressionModel();
+   const std::vector<float> x = {1.0, 2.0};
+   RReader model(modelRegression);
+   auto y = model.Compute(x);
+   EXPECT_EQ(y.size(), 1ul);
+}
+
+TEST(RReader, RegressionComputeTensor)
+{
+   TrainRegressionModel();
+   ROOT::RDataFrame df("TreeR", filenameRegression);
+   auto x = AsTensor<float>(df, variablesRegression);
+
+   RReader model(modelRegression);
+   auto y = model.Compute(x);
+
+   const auto shapeX = x.GetShape();
+   const auto shapeY = y.GetShape();
+   EXPECT_EQ(shapeY.size(), 1ul);
+   EXPECT_EQ(shapeY[0], shapeX[0]);
+}
+
+TEST(RReader, RegressionComputeDataFrame)
+{
+   TrainRegressionModel();
+   ROOT::RDataFrame df("TreeR", filenameRegression);
+   RReader model(modelRegression);
+   auto df2 = df.Define("y", Compute<2, float>(model), variablesRegression);
+   auto df3 = df2.Filter("y.size() == 1");
+   auto c = df3.Count();
+   auto y = df2.Take<std::vector<float>>("y");
+   EXPECT_EQ(y->size(), *c);
+}
+
+TEST(RReader, MulticlassGetVariables)
+{
+   TrainMulticlassModel();
+   RReader model(modelMulticlass);
+   auto vars = model.GetVariableNames();
+   EXPECT_EQ(vars.size(), 4ul);
+   for (std::size_t i = 0; i < vars.size(); i++) {
+      EXPECT_EQ(vars[i], variablesMulticlass[i]);
+   }
+}
+
+TEST(RReader, MulticlassComputeVector)
+{
+   TrainMulticlassModel();
+   const std::vector<float> x = {1.0, 2.0, 3.0, 4.0};
+   RReader model(modelMulticlass);
+   auto y = model.Compute(x);
+   EXPECT_EQ(y.size(), 4ul);
+}
+
+TEST(RReader, MulticlassComputeTensor)
+{
+   TrainMulticlassModel();
+   ROOT::RDataFrame df("TreeS", filenameMulticlass);
+   auto x = AsTensor<float>(df, variablesMulticlass);
+
+   RReader model(modelMulticlass);
+   auto y = model.Compute(x);
+
+   const auto shapeX = x.GetShape();
+   const auto shapeY = y.GetShape();
+   EXPECT_EQ(shapeY.size(), 2ul);
+   EXPECT_EQ(shapeY[0], shapeX[0]);
+   EXPECT_EQ(shapeY[1], 4ul);
+}
+
+TEST(RReader, MulticlassComputeDataFrame)
+{
+   TrainMulticlassModel();
+   ROOT::RDataFrame df("TreeS", filenameMulticlass);
+   RReader model(modelMulticlass);
+   auto df2 = df.Define("y", Compute<4, float>(model), variablesMulticlass);
+   auto df3 = df2.Filter("y.size() == 4");
+   auto c = df3.Count();
+   auto y = df2.Take<std::vector<float>>("y");
+   EXPECT_EQ(y->size(), *c);
+}
diff --git a/tutorials/tmva/tmva003_RReader.C b/tutorials/tmva/tmva003_RReader.C
new file mode 100644
index 00000000000..2eb9a7c3f86
--- /dev/null
+++ b/tutorials/tmva/tmva003_RReader.C
@@ -0,0 +1,109 @@
+/// \file
+/// \ingroup tutorial_tmva
+/// \notebook -nodraw
+/// This tutorial shows how to apply with the modern interfaces models saved in
+/// TMVA XML files.
+///
+/// \macro_code
+/// \macro_output
+///
+/// \date July 2019
+/// \author Stefan Wunsch
+
+using namespace TMVA::Experimental;
+
+void train(const std::string &filename)
+{
+   // Create factory
+   auto output = TFile::Open("TMVA.root", "RECREATE");
+   auto factory = new TMVA::Factory("tmva003",
+           output, "!V:!DrawProgressBar:AnalysisType=Classification");
+
+   // Open trees with signal and background events
+   auto data = TFile::Open(filename.c_str());
+   auto signal = (TTree *)data->Get("TreeS");
+   auto background = (TTree *)data->Get("TreeB");
+
+   // Add variables and register the trees with the dataloader
+   auto dataloader = new TMVA::DataLoader("tmva003_BDT");
+   const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
+   for (const auto &var : variables) {
+      dataloader->AddVariable(var);
+   }
+   dataloader->AddSignalTree(signal, 1.0);
+   dataloader->AddBackgroundTree(background, 1.0);
+   dataloader->PrepareTrainingAndTestTree("", "");
+
+   // Train a TMVA method
+   factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=300:MaxDepth=2");
+   factory->TrainAllMethods();
+}
+
+void tmva003_RReader()
+{
+   // First, let's train a model with TMVA.
+   const std::string filename = "http://root.cern.ch/files/tmva_class_example.root";
+   train(filename);
+
+   // Next, we load the model from the TMVA XML file.
+   RReader model("tmva003_BDT/weights/tmva003_BDT.weights.xml");
+
+   // In case you need a reminder of the names and order of the variables during
+   // training, you can ask the model for it.
+   auto variables = model.GetVariableNames();
+
+   // The model can now be applied in different scenarios:
+   // 1) Event-by-event inference
+   // 2) Batch inference on data of multiple events
+   // 3) Inference as part of an RDataFrame graph
+
+   // 1) Event-by-event inference
+   // The event-by-event inference takes the values of the variables as a std::vector<float>.
+   // Note that the return value is as well a std::vector<float> since the reader
+   // is also capable to process models with multiple outputs.
+   auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});
+   std::cout << "Single-event inference: " << prediction[0] << "\n\n";
+
+   // 2) Batch inference on data of multiple events
+   // For batch inference, the data needs to be structured as a matrix. For this
+   // purpose, TMVA makes use of the RTensor class. For convenience, we use RDataFrame
+   // and the AsTensor utility to make the read-out from the ROOT file.
+   ROOT::RDataFrame df("TreeS", filename);
+   auto df2 = df.Range(3); // Read only a small subset of the dataset
+   auto x = AsTensor<float>(df2, variables);
+   auto y = model.Compute(x);
+
+   std::cout << "RTensor input for inference on data of multiple events:\n" << x << "\n\n";
+   std::cout << "Prediction performed on multiple events: " << y << "\n\n";
+
+   // 3) Perform inference as part of an RDataFrame graph
+   // We write a small lambda function that performs for us the inference on
+   // a dataframe to omit code duplication.
+   auto make_histo = [&](const std::string &treename) {
+      ROOT::RDataFrame df(treename, filename);
+      auto df2 = df.Define("y", Compute<4, float>(model), variables);
+      return df2.Histo1D({treename.c_str(), ";BDT score;N_{Events}", 30, -0.5, 0.5}, "y");
+   };
+
+   auto sig = make_histo("TreeS");
+   auto bkg = make_histo("TreeB");
+
+   // Make plot
+   gStyle->SetOptStat(0);
+   auto c = new TCanvas("", "", 800, 800);
+
+   sig->SetLineColor(kRed);
+   bkg->SetLineColor(kBlue);
+   sig->SetLineWidth(2);
+   bkg->SetLineWidth(2);
+   bkg->Draw("HIST");
+   sig->Draw("HIST SAME");
+
+   TLegend legend(0.7, 0.7, 0.89, 0.89);
+   legend.SetBorderSize(0);
+   legend.AddEntry("TreeS", "Signal", "l");
+   legend.AddEntry("TreeB", "Background", "l");
+   legend.Draw();
+
+   c->DrawClone();
+}
-- 
GitLab