Skip to content
Snippets Groups Projects
Commit 2824060d authored by Stefan Wunsch's avatar Stefan Wunsch
Browse files

[TMVA experimental] Add new reader interface

parent 45fad0d3
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,8 @@ if(dataframe)
TMVA/RTensor.hxx
TMVA/RTensorUtils.hxx
TMVA/RStandardScaler.hxx
TMVA/RReader.hxx
TMVA/RInferenceUtils.hxx
)
set(TMVA_EXTRA_DEPENDENCIES
ROOTDataFrame
......
#ifndef TMVA_RINFERENCEUTILS
#define TMVA_RINFERENCEUTILS
#include "ROOT/RIntegerSequence.hxx" // std::index_sequence
#include <utility> // std::forward
namespace TMVA {
namespace Experimental {
namespace Internal {
/// Compute helper
///
/// Callable adapter that accepts sizeof...(N) separate arguments of type T and
/// hands them, bundled in a braced list, to the Compute method of the wrapped
/// object. Only the specialization on std::index_sequence is defined; the
/// sequence is used purely to fix the number of call arguments.
template <typename I, typename T, typename F>
class ComputeHelper;

template <std::size_t... N, typename T, typename F>
class ComputeHelper<std::index_sequence<N...>, T, F> {
   // Maps any index to T, so that expanding the pack N yields one parameter
   // of type T per index.
   template <std::size_t Index>
   using Repeated_t = T;

   // Wrapped object; if F is a reference type only a reference is held
   F fFunc;

public:
   ComputeHelper(F &&f) : fFunc(std::forward<F>(f)) {}

   /// Forward all arguments as one braced list to fFunc.Compute
   auto operator()(Repeated_t<N>... args) -> decltype(fFunc.Compute({args...}))
   {
      return fFunc.Compute({args...});
   }
};
} // namespace Internal
/// Helper to pass TMVA model to RDataFrame.Define nodes
///
/// Wraps \p f in a callable taking N arguments of type T, each call being
/// forwarded to the Compute method of \p f. F is deduced from a forwarding
/// reference: for an lvalue argument F is a reference type, so the returned
/// helper refers to \p f and does not copy it.
/// \tparam N Number of arguments of the returned callable
/// \tparam T Type of each argument
/// \param[in] f Object providing a Compute method
template <std::size_t N, typename T, typename F>
auto Compute(F &&f) -> Internal::ComputeHelper<std::make_index_sequence<N>, T, F>
{
   using Helper_t = Internal::ComputeHelper<std::make_index_sequence<N>, T, F>;
   return Helper_t(std::forward<F>(f));
}
} // namespace Experimental
} // namespace TMVA
#endif // TMVA_RINFERENCEUTILS
#ifndef TMVA_RREADER
#define TMVA_RREADER
#include "TString.h"
#include "TXMLEngine.h"
#include "ROOT/RMakeUnique.hxx"
#include "TMVA/RTensor.hxx"
#include "TMVA/Reader.h"
#include <memory> // std::unique_ptr
#include <sstream> // std::stringstream
namespace TMVA {
namespace Experimental {
namespace Internal {
/// Internal definition of analysis types
///
/// Undefined is the zero value and is what a default-constructed XMLConfig
/// holds until the type has been read from the XML file; the remaining
/// enumerators follow with consecutive values.
enum AnalysisType : unsigned int { Undefined = 0, Classification, Regression, Multiclass };
/// Container for information extracted from TMVA XML config
struct XMLConfig {
unsigned int numVariables;
std::vector<std::string> variables;
unsigned int numClasses;
std::vector<std::string> classes;
AnalysisType analysisType;
XMLConfig()
: numVariables(0), variables(std::vector<std::string>(0)), numClasses(0), classes(std::vector<std::string>(0)),
analysisType(Internal::AnalysisType::Undefined)
{
}
};
/// Parse TMVA XML config
///
/// Reads the XML file \p filename and extracts the input variables (node
/// "Variables"), the output classes (node "Classes") and the analysis type
/// (entry "AnalysisType" of node "GeneralInfo").
///
/// Attributes are looked up defensively: a missing attribute or a variable
/// index outside [0, NVar) is skipped during parsing and then reported by the
/// consistency checks at the end instead of being dereferenced.
///
/// \param[in] filename Path to the TMVA XML file
/// \return Configuration with all fields filled
/// \throws std::runtime_error if the file cannot be parsed, if not every
///         declared variable or class could be read, or if the analysis type
///         is missing or unknown
inline XMLConfig ParseXMLConfig(const std::string &filename)
{
   XMLConfig c;

   // Parse XML file and find root node
   TXMLEngine xml;
   auto xmldoc = xml.ParseFile(filename.c_str());
   if (xmldoc == 0) {
      std::stringstream ss;
      ss << "Failed to open TMVA XML file "
         << filename << ".";
      throw std::runtime_error(ss.str());
   }

   // One flag per slot of c.variables, set once the slot has been filled
   std::vector<char> variableSeen;

   // The document has to be released on every path, including exceptions
   // thrown by the containers used below.
   try {
      auto mainNode = xml.DocGetRootElement(xmldoc);
      for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
         const char *rawName = xml.GetNodeName(node);
         const std::string nodeName = rawName ? rawName : "";

         // Read out input variables
         if (nodeName == "Variables") {
            const char *nVar = xml.GetAttr(node, "NVar");
            const int declared = nVar ? std::atoi(nVar) : 0;
            c.numVariables = declared > 0 ? static_cast<unsigned int>(declared) : 0u;
            c.variables.assign(c.numVariables, std::string());
            variableSeen.assign(c.numVariables, 0);
            for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
               const char *index = xml.GetAttr(thisNode, "VarIndex");
               const char *title = xml.GetAttr(thisNode, "Title");
               if (!index || !title)
                  continue;
               const int iVariable = std::atoi(index);
               if (iVariable < 0 || static_cast<unsigned int>(iVariable) >= c.numVariables)
                  continue;
               c.variables[iVariable] = title;
               variableSeen[iVariable] = 1;
            }
         }
         // Read out output classes
         else if (nodeName == "Classes") {
            const char *nClass = xml.GetAttr(node, "NClass");
            const int declared = nClass ? std::atoi(nClass) : 0;
            c.numClasses = declared > 0 ? static_cast<unsigned int>(declared) : 0u;
            for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
               const char *className = xml.GetAttr(thisNode, "Name");
               if (className)
                  c.classes.push_back(className);
            }
         }
         // Read out analysis type
         else if (nodeName == "GeneralInfo") {
            std::string analysisType = "";
            for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
               const char *infoName = xml.GetAttr(thisNode, "name");
               const char *infoValue = xml.GetAttr(thisNode, "value");
               if (infoName && infoValue && std::string("AnalysisType") == infoName) {
                  analysisType = infoValue;
               }
            }
            if (analysisType == "Classification") {
               c.analysisType = Internal::AnalysisType::Classification;
            } else if (analysisType == "Regression") {
               c.analysisType = Internal::AnalysisType::Regression;
            } else if (analysisType == "Multiclass") {
               c.analysisType = Internal::AnalysisType::Multiclass;
            }
         }
      }
   } catch (...) {
      xml.FreeDoc(xmldoc);
      throw;
   }
   xml.FreeDoc(xmldoc);

   // Error-handling
   // c.variables always has numVariables slots, so check that each slot was
   // actually filled rather than comparing sizes.
   bool allVariablesSeen = true;
   for (const auto seen : variableSeen) {
      if (!seen)
         allVariablesSeen = false;
   }
   if (c.numVariables == 0 || !allVariablesSeen) {
      std::stringstream ss;
      ss << "Failed to parse input variables from TMVA config " << filename << ".";
      throw std::runtime_error(ss.str());
   }
   if (c.numClasses != c.classes.size() || c.numClasses == 0) {
      std::stringstream ss;
      ss << "Failed to parse output classes from TMVA config " << filename << ".";
      throw std::runtime_error(ss.str());
   }
   if (c.analysisType == Internal::AnalysisType::Undefined) {
      std::stringstream ss;
      ss << "Failed to parse analysis type from TMVA config " << filename << ".";
      throw std::runtime_error(ss.str());
   }

   return c;
}
} // namespace Internal
/// TMVA::Reader legacy interface
///
/// Books a TMVA::Reader from an XML file and offers inference on a single
/// event (std::vector<float>) or on a batch of events (RTensor<float>).
///
/// All evaluations go through one input buffer (fValues) whose element
/// addresses are registered with the reader. Both Compute overloads therefore
/// hold the lock on ROOT::gCoreMutex while they fill that buffer and evaluate
/// the model.
class RReader {
private:
   std::unique_ptr<Reader> fReader;
   std::vector<float> fValues;          // Input buffer read by fReader through the registered addresses
   std::vector<std::string> fVariables; // Names of the input variables taken from the XML config
   unsigned int fNumClasses;
   const char *name = "RReader";        // Tag used to book and to evaluate the method
   Internal::AnalysisType fAnalysisType;

public:
   /// Create TMVA model from XML file
   /// \param[in] path Path to the TMVA XML file
   /// \throws std::runtime_error if the XML config cannot be parsed
   RReader(const std::string &path)
   {
      // Load config
      auto c = Internal::ParseXMLConfig(path);
      fVariables = c.variables;
      fAnalysisType = c.analysisType;
      fNumClasses = c.numClasses;

      // Setup reader
      fReader = std::make_unique<Reader>("Silent");
      const auto numVars = fVariables.size();
      // The buffer gets its final size before any address is handed out
      fValues = std::vector<float>(numVars);
      for (std::size_t i = 0; i < numVars; i++) {
         fReader->AddVariable(TString(fVariables[i]), &fValues[i]);
      }
      fReader->BookMVA(name, path.c_str());
   }

   /// Compute model prediction on vector
   /// \param[in] x Values of the input variables for one event
   /// \return One value for classification, otherwise the vector returned by
   ///         the regression or multiclass evaluation of the reader
   /// \throws std::runtime_error if the size of \p x differs from the number
   ///         of variables or if the analysis type is undefined
   std::vector<float> Compute(const std::vector<float> &x)
   {
      if (x.size() != fVariables.size())
         throw std::runtime_error("Size of input vector is not equal to number of variables.");

      // Take the lock before writing to fValues: the buffer is shared by all
      // callers of this object, so filling it is part of the evaluation that
      // has to be protected.
      R__WRITE_LOCKGUARD(ROOT::gCoreMutex);

      // Copy over inputs to memory used by TMVA reader
      for (std::size_t i = 0; i < x.size(); i++) {
         fValues[i] = x[i];
      }

      // Evaluate TMVA model
      if (fAnalysisType == Internal::AnalysisType::Classification) {
         return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
      }
      if (fAnalysisType == Internal::AnalysisType::Regression) {
         return fReader->EvaluateRegression(name);
      }
      if (fAnalysisType == Internal::AnalysisType::Multiclass) {
         return fReader->EvaluateMulticlass(name);
      }
      throw std::runtime_error("RReader has undefined analysis type.");
   }

   /// Compute model prediction on input RTensor
   /// \param[in] x Tensor of rank 2 with shape (entries, variables)
   /// \return Tensor of shape (entries) for classification and regression and
   ///         of shape (entries, classes) for multiclass models. For
   ///         regression only the first element of the evaluation result is
   ///         stored per entry.
   /// \throws std::runtime_error if \p x is not of rank 2, if its second
   ///         dimension differs from the number of variables or if the
   ///         analysis type is undefined
   RTensor<float> Compute(RTensor<float> &x)
   {
      // Error-handling for input tensor
      const auto shape = x.GetShape();
      if (shape.size() != 2)
         throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");

      const auto numEntries = shape[0];
      const auto numVars = shape[1];
      if (numVars != fVariables.size())
         throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");

      // Define shape of output tensor based on analysis type
      unsigned int numClasses = 1;
      if (fAnalysisType == Internal::AnalysisType::Multiclass)
         numClasses = fNumClasses;
      RTensor<float> y({numEntries * numClasses});
      if (fAnalysisType == Internal::AnalysisType::Multiclass)
         y = y.Reshape({numEntries, numClasses});

      // Fill output tensor
      for (std::size_t i = 0; i < numEntries; i++) {
         // Lock per entry, and before fValues is written, so that the inputs
         // cannot be overwritten by another caller ahead of the evaluation.
         R__WRITE_LOCKGUARD(ROOT::gCoreMutex);

         for (std::size_t j = 0; j < numVars; j++) {
            fValues[j] = x(i, j);
         }

         if (fAnalysisType == Internal::AnalysisType::Classification) {
            y(i) = fReader->EvaluateMVA(name);
         } else if (fAnalysisType == Internal::AnalysisType::Regression) {
            y(i) = fReader->EvaluateRegression(name)[0];
         } else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
            const auto p = fReader->EvaluateMulticlass(name);
            for (std::size_t k = 0; k < numClasses; k++)
               y(i, k) = p[k];
         } else {
            // Same reaction as the single-event overload
            throw std::runtime_error("RReader has undefined analysis type.");
         }
      }

      return y;
   }

   /// Return a copy of the names of the input variables
   std::vector<std::string> GetVariableNames() const { return fVariables; }
};
} // namespace Experimental
} // namespace TMVA
#endif // TMVA_RREADER
......@@ -18,11 +18,14 @@ ROOT_ADD_GTEST(TestRandomGenerator
LIBRARIES ${Libraries})
if(dataframe)
# RTensor
ROOT_ADD_GTEST(rtensor rtensor.cxx LIBRARIES ROOTVecOps TMVA)
ROOT_ADD_GTEST(rtensor-iterator rtensor_iterator.cxx LIBRARIES ROOTVecOps TMVA)
ROOT_ADD_GTEST(rtensor-utils rtensor_utils.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
# RStandardScaler
ROOT_ADD_GTEST(rstandardscaler rstandardscaler.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
# RReader
ROOT_ADD_GTEST(rreader rreader.cxx LIBRARIES ROOTVecOps TMVA ROOTDataFrame)
endif()
project(tmva-tests)
......
#include <gtest/gtest.h>
#include <TFile.h>
#include <TTree.h>
#include <TSystem.h>
#include <TMVA/Factory.h>
#include <TMVA/DataLoader.h>
#include <TMVA/RReader.hxx>
#include <TMVA/RInferenceUtils.hxx>
#include <TMVA/RTensor.hxx>
#include <TMVA/RTensorUtils.hxx>
#include <stdexcept>
using namespace TMVA::Experimental;
// Classification
static const std::string modelClassification = "RReaderClassification/weights/RReaderClassification_BDT.weights.xml";
static const std::string filenameClassification = "http://root.cern.ch/files/tmva_class_example.root";
static const std::vector<std::string> variablesClassification = {"var1", "var2", "var3", "var4"};
void TrainClassificationModel()
{
// Check for existing training
if (gSystem->mkdir("RReaderClassification") == -1) return;
// Create factory
auto output = TFile::Open("TMVA.root", "RECREATE");
auto factory = new TMVA::Factory("RReaderClassification",
output, "Silent:!V:!DrawProgressBar:AnalysisType=Classification");
// Open trees with signal and background events
const std::string filename = "http://root.cern.ch/files/tmva_class_example.root";
auto data = TFile::Open(filename.c_str());
auto signal = (TTree *)data->Get("TreeS");
auto background = (TTree *)data->Get("TreeB");
// Add variables and register the trees with the dataloader
auto dataloader = new TMVA::DataLoader("RReaderClassification");
const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
for (const auto &var : variables) {
dataloader->AddVariable(var);
}
dataloader->AddSignalTree(signal, 1.0);
dataloader->AddBackgroundTree(background, 1.0);
dataloader->PrepareTrainingAndTestTree("", "");
// Train a TMVA method
factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=100:MaxDepth=2");
factory->TrainAllMethods();
output->Close();
}
// Regression
static const std::string modelRegression = "RReaderRegression/weights/RReaderRegression_BDTG.weights.xml";
static const std::string filenameRegression = "http://root.cern.ch/files/tmva_reg_example.root";
static const std::vector<std::string> variablesRegression = {"var1", "var2"};
void TrainRegressionModel()
{
// Check for existing training
if (gSystem->mkdir("RReaderRegression") == -1) return;
// Create factory
auto output = TFile::Open("TMVA.root", "RECREATE");
auto factory = new TMVA::Factory("RReaderRegression",
output, "Silent:!V:!DrawProgressBar:AnalysisType=Regression");
// Open trees with signal and background events
const std::string filename = "http://root.cern.ch/files/tmva_reg_example.root";
auto data = TFile::Open(filename.c_str());
auto tree = (TTree *)data->Get("TreeR");
// Add variables and register the trees with the dataloader
auto dataloader = new TMVA::DataLoader("RReaderRegression");
dataloader->AddVariable("var1");
dataloader->AddVariable("var2");
dataloader->AddTarget("fvalue");
dataloader->AddRegressionTree(tree, 1.0);
dataloader->PrepareTrainingAndTestTree("", "");
// Train a TMVA method
factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDTG", "!V:!H:NTrees=100:MaxDepth=2");
factory->TrainAllMethods();
output->Close();
}
// Multiclass
static const std::string modelMulticlass = "RReaderMulticlass/weights/RReaderMulticlass_BDT.weights.xml";
static const std::string filenameMulticlass = "http://root.cern.ch/files/tmva_multiclass_example.root";
static const std::vector<std::string> variablesMulticlass = variablesClassification;
void TrainMulticlassModel()
{
// Check for existing training
if (gSystem->mkdir("RReaderMulticlass") == -1) return;
// Create factory
auto output = TFile::Open("TMVA.root", "RECREATE");
auto factory = new TMVA::Factory("RReaderMulticlass",
output, "Silent:!V:!DrawProgressBar:AnalysisType=Multiclass");
// Open trees with signal and background events
const std::string filename = "http://root.cern.ch/files/tmva_multiclass_example.root";
auto data = TFile::Open(filename.c_str());
auto signal = (TTree *)data->Get("TreeS");
auto background0 = (TTree *)data->Get("TreeB0");
auto background1 = (TTree *)data->Get("TreeB1");
auto background2 = (TTree *)data->Get("TreeB2");
// Add variables and register the trees with the dataloader
auto dataloader = new TMVA::DataLoader("RReaderMulticlass");
const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
for (const auto &var : variables) {
dataloader->AddVariable(var);
}
dataloader->AddTree(signal, "Signal");
dataloader->AddTree(background0, "Background_0");
dataloader->AddTree(background1, "Background_1");
dataloader->AddTree(background2, "Background_2");
dataloader->PrepareTrainingAndTestTree("", "");
// Train a TMVA method
factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=100:MaxDepth=2:BoostType=Grad");
factory->TrainAllMethods();
output->Close();
}
// The classification model has to report the variable names used in training
TEST(RReader, ClassificationGetVariables)
{
   TrainClassificationModel();
   RReader model(modelClassification);
   auto vars = model.GetVariableNames();
   // Fatal on mismatch: the loop below uses one index for both vectors
   ASSERT_EQ(vars.size(), variablesClassification.size());
   for (std::size_t i = 0; i < vars.size(); i++) {
      EXPECT_EQ(vars[i], variablesClassification[i]);
   }
}
// Single-event inference of the classification model yields exactly one value
TEST(RReader, ClassificationComputeVector)
{
   TrainClassificationModel();
   const std::vector<float> event = {1.0, 2.0, 3.0, 4.0};
   RReader model(modelClassification);
   const auto prediction = model.Compute(event);
   EXPECT_EQ(prediction.size(), 1ul);
}
// Batch inference of the classification model yields one value per entry
TEST(RReader, ClassificationComputeTensor)
{
   TrainClassificationModel();
   ROOT::RDataFrame df("TreeS", filenameClassification);
   auto x = AsTensor<float>(df, variablesClassification);
   RReader model(modelClassification);
   auto y = model.Compute(x);
   const auto shapeX = x.GetShape();
   const auto shapeY = y.GetShape();
   // Fatal on mismatch: shapeY[0] is read right below
   ASSERT_EQ(shapeY.size(), 1ul);
   EXPECT_EQ(shapeY[0], shapeX[0]);
}
// Inference inside an RDataFrame graph: every entry must carry a prediction
// vector of size one
TEST(RReader, ClassificationComputeDataFrame)
{
   TrainClassificationModel();
   ROOT::RDataFrame df("TreeS", filenameClassification);
   RReader model(modelClassification);
   auto withPrediction = df.Define("y", Compute<4, float>(model), variablesClassification);
   // Number of entries whose prediction has the expected size
   auto numMatching = withPrediction.Filter("y.size() == 1").Count();
   auto predictions = withPrediction.Take<std::vector<float>>("y");
   EXPECT_EQ(predictions->size(), *numMatching);
}
// The regression model has to report the variable names used in training
TEST(RReader, RegressionGetVariables)
{
   TrainRegressionModel();
   RReader model(modelRegression);
   auto vars = model.GetVariableNames();
   // Fatal on mismatch: the loop below uses one index for both vectors
   ASSERT_EQ(vars.size(), variablesRegression.size());
   for (std::size_t i = 0; i < vars.size(); i++) {
      EXPECT_EQ(vars[i], variablesRegression[i]);
   }
}
// Single-event inference of the regression model yields exactly one value
TEST(RReader, RegressionComputeVector)
{
   TrainRegressionModel();
   const std::vector<float> event = {1.0, 2.0};
   RReader model(modelRegression);
   const auto prediction = model.Compute(event);
   EXPECT_EQ(prediction.size(), 1ul);
}
// Batch inference of the regression model yields one value per entry
TEST(RReader, RegressionComputeTensor)
{
   TrainRegressionModel();
   ROOT::RDataFrame df("TreeR", filenameRegression);
   auto x = AsTensor<float>(df, variablesRegression);
   RReader model(modelRegression);
   auto y = model.Compute(x);
   const auto shapeX = x.GetShape();
   const auto shapeY = y.GetShape();
   // Fatal on mismatch: shapeY[0] is read right below
   ASSERT_EQ(shapeY.size(), 1ul);
   EXPECT_EQ(shapeY[0], shapeX[0]);
}
// Inference inside an RDataFrame graph: every entry must carry a prediction
// vector of size one
TEST(RReader, RegressionComputeDataFrame)
{
   TrainRegressionModel();
   ROOT::RDataFrame df("TreeR", filenameRegression);
   RReader model(modelRegression);
   auto withPrediction = df.Define("y", Compute<2, float>(model), variablesRegression);
   // Number of entries whose prediction has the expected size
   auto numMatching = withPrediction.Filter("y.size() == 1").Count();
   auto predictions = withPrediction.Take<std::vector<float>>("y");
   EXPECT_EQ(predictions->size(), *numMatching);
}
// The multiclass model has to report the variable names used in training
TEST(RReader, MulticlassGetVariables)
{
   TrainMulticlassModel();
   RReader model(modelMulticlass);
   auto vars = model.GetVariableNames();
   // Fatal on mismatch: the loop below uses one index for both vectors
   ASSERT_EQ(vars.size(), variablesMulticlass.size());
   for (std::size_t i = 0; i < vars.size(); i++) {
      EXPECT_EQ(vars[i], variablesMulticlass[i]);
   }
}
// Single-event inference of the multiclass model yields one value per class
TEST(RReader, MulticlassComputeVector)
{
   TrainMulticlassModel();
   const std::vector<float> event = {1.0, 2.0, 3.0, 4.0};
   RReader model(modelMulticlass);
   const auto prediction = model.Compute(event);
   EXPECT_EQ(prediction.size(), 4ul);
}
// Batch inference of the multiclass model yields a (entries, classes) tensor
TEST(RReader, MulticlassComputeTensor)
{
   TrainMulticlassModel();
   ROOT::RDataFrame df("TreeS", filenameMulticlass);
   auto x = AsTensor<float>(df, variablesMulticlass);
   RReader model(modelMulticlass);
   auto y = model.Compute(x);
   const auto shapeX = x.GetShape();
   const auto shapeY = y.GetShape();
   // Fatal on mismatch: shapeY[0] and shapeY[1] are read right below
   ASSERT_EQ(shapeY.size(), 2ul);
   EXPECT_EQ(shapeY[0], shapeX[0]);
   EXPECT_EQ(shapeY[1], 4ul);
}
// Inference inside an RDataFrame graph: every entry must carry a prediction
// vector of size four
TEST(RReader, MulticlassComputeDataFrame)
{
   TrainMulticlassModel();
   ROOT::RDataFrame df("TreeS", filenameMulticlass);
   RReader model(modelMulticlass);
   auto withPrediction = df.Define("y", Compute<4, float>(model), variablesMulticlass);
   // Number of entries whose prediction has the expected size
   auto numMatching = withPrediction.Filter("y.size() == 4").Count();
   auto predictions = withPrediction.Take<std::vector<float>>("y");
   EXPECT_EQ(predictions->size(), *numMatching);
}
/// \file
/// \ingroup tutorial_tmva
/// \notebook -nodraw
/// This tutorial shows how to apply models saved in TMVA XML files using the
/// modern interfaces.
///
/// \macro_code
/// \macro_output
///
/// \date July 2019
/// \author Stefan Wunsch
using namespace TMVA::Experimental;
void train(const std::string &filename)
{
// Create factory
auto output = TFile::Open("TMVA.root", "RECREATE");
auto factory = new TMVA::Factory("tmva003",
output, "!V:!DrawProgressBar:AnalysisType=Classification");
// Open trees with signal and background events
auto data = TFile::Open(filename.c_str());
auto signal = (TTree *)data->Get("TreeS");
auto background = (TTree *)data->Get("TreeB");
// Add variables and register the trees with the dataloader
auto dataloader = new TMVA::DataLoader("tmva003_BDT");
const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
for (const auto &var : variables) {
dataloader->AddVariable(var);
}
dataloader->AddSignalTree(signal, 1.0);
dataloader->AddBackgroundTree(background, 1.0);
dataloader->PrepareTrainingAndTestTree("", "");
// Train a TMVA method
factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=300:MaxDepth=2");
factory->TrainAllMethods();
}
void tmva003_RReader()
{
// First, let's train a model with TMVA.
const std::string filename = "http://root.cern.ch/files/tmva_class_example.root";
train(filename);
// Next, we load the model from the TMVA XML file.
RReader model("tmva003_BDT/weights/tmva003_BDT.weights.xml");
// In case you need a reminder of the names and order of the variables during
// training, you can ask the model for it.
auto variables = model.GetVariableNames();
// The model can now be applied in different scenarios:
// 1) Event-by-event inference
// 2) Batch inference on data of multiple events
// 3) Inference as part of an RDataFrame graph
// 1) Event-by-event inference
// The event-by-event inference takes the values of the variables as a std::vector<float>.
// Note that the return value is as well a std::vector<float> since the reader
// is also capable to process models with multiple outputs.
auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});
std::cout << "Single-event inference: " << prediction[0] << "\n\n";
// 2) Batch inference on data of multiple events
// For batch inference, the data needs to be structured as a matrix. For this
// purpose, TMVA makes use of the RTensor class. For convenience, we use RDataFrame
// and the AsTensor utility to make the read-out from the ROOT file.
ROOT::RDataFrame df("TreeS", filename);
auto df2 = df.Range(3); // Read only a small subset of the dataset
auto x = AsTensor<float>(df2, variables);
auto y = model.Compute(x);
std::cout << "RTensor input for inference on data of multiple events:\n" << x << "\n\n";
std::cout << "Prediction performed on multiple events: " << y << "\n\n";
// 3) Perform inference as part of an RDataFrame graph
// We write a small lambda function that performs for us the inference on
// a dataframe to omit code duplication.
auto make_histo = [&](const std::string &treename) {
ROOT::RDataFrame df(treename, filename);
auto df2 = df.Define("y", Compute<4, float>(model), variables);
return df2.Histo1D({treename.c_str(), ";BDT score;N_{Events}", 30, -0.5, 0.5}, "y");
};
auto sig = make_histo("TreeS");
auto bkg = make_histo("TreeB");
// Make plot
gStyle->SetOptStat(0);
auto c = new TCanvas("", "", 800, 800);
sig->SetLineColor(kRed);
bkg->SetLineColor(kBlue);
sig->SetLineWidth(2);
bkg->SetLineWidth(2);
bkg->Draw("HIST");
sig->Draw("HIST SAME");
TLegend legend(0.7, 0.7, 0.89, 0.89);
legend.SetBorderSize(0);
legend.AddEntry("TreeS", "Signal", "l");
legend.AddEntry("TreeB", "Background", "l");
legend.Draw();
c->DrawClone();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment