Skip to content
Snippets Groups Projects
Commit 5e8eb508 authored by Enric Tejedor Saavedra
Browse files

Support JITted expressions with branch.leaf syntax

parent deef6825
No related branches found
No related tags found
No related merge requests found
......@@ -210,6 +210,8 @@ TActionBase *BuildAndBook(const ColumnNames_t &bl, const std::shared_ptr<double>
}
/****** end BuildAndBook ******/
void Replace(std::string &s, const std::string what, const std::string withWhat);
std::vector<std::string> FindUsedColumnNames(std::string_view, TObjArray *, const std::vector<std::string> &);
using TmpBranchBasePtr_t = std::shared_ptr<TCustomColumnBase>;
......
......@@ -18,6 +18,7 @@
#include <TRegexp.h>
#include <TString.h>
#include <TTree.h>
#include <TBranchElement.h>
#include <iosfwd>
#include <stdexcept>
......@@ -48,6 +49,26 @@ namespace TDF {
// the one in the vector
class TActionBase;
// Recursively walk the sub-branches of `b` and append their names to `bNames`.
// For each sub-branch, the name `prefix + <sub-branch name>` is recorded when the
// tree `t` can resolve it via GetBranch; otherwise the bare sub-branch name is
// recorded when the tree can resolve that. If neither lookup succeeds nothing is
// recorded, but the recursion still descends into the sub-branch.
void AddToList(ColumnNames_t &bNames, TTree &t, TBranch *b, std::string prefix)
{
   // The prefix is only propagated to deeper levels when one was supplied here
   const bool hasPrefix = !prefix.empty();

   for (auto obj : *b->GetListOfBranches()) {
      auto child = static_cast<TBranch *>(obj);
      const std::string childName(child->GetName());
      const std::string qualifiedName = prefix + childName;

      if (t.GetBranch(qualifiedName.c_str())) {
         bNames.push_back(qualifiedName);
      } else if (t.GetBranch(childName.c_str())) {
         bNames.push_back(childName);
      }

      AddToList(bNames, t, child, hasPrefix ? qualifiedName + "." : std::string());
   }
}
void GetBranchNamesImpl(TTree &t, std::set<std::string> &bNamesReg, ColumnNames_t &bNames,
std::set<TTree *> &analysedTrees)
{
......@@ -56,12 +77,44 @@ void GetBranchNamesImpl(TTree &t, std::set<std::string> &bNamesReg, ColumnNames_
return;
}
auto branches = t.GetListOfBranches();
const auto branches = t.GetListOfBranches();
if (branches) {
for (auto branchObj : *branches) {
auto name = branchObj->GetName();
if (bNamesReg.insert(name).second) {
bNames.emplace_back(name);
std::string prefix = "";
for (auto b : *branches) {
TBranch *branch = static_cast<TBranch*>(b);
auto branchName = std::string(branch->GetName());
if (branch->IsA() == TBranch::Class()) {
// Leaf list
auto listOfLeaves = branch->GetListOfLeaves();
if (listOfLeaves->GetEntries() == 1) {
if (bNamesReg.insert(branchName).second)
bNames.push_back(branchName);
}
for (auto leaf : *listOfLeaves) {
auto leafName = std::string(static_cast<TLeaf*>(leaf)->GetName());
auto fullName = branchName + "." + leafName;
if (bNamesReg.insert(fullName).second)
bNames.push_back(fullName);
}
} else {
// TBranchElement
// Check if there is explicit or implicit dot in the name
bool dotIsImplied = false;
auto be = dynamic_cast<TBranchElement*>(b);
// TClonesArray (3) and STL collection (4)
if (be->GetType() == 3 || be->GetType() == 4)
dotIsImplied = true;
if (dotIsImplied || branchName.back() == '.')
AddToList(bNames, t, branch, "");
else
AddToList(bNames, t, branch, branchName + ".");
if (bNamesReg.insert(branchName).second)
bNames.push_back(branchName);
}
}
}
......@@ -155,20 +208,10 @@ SelectColumns(unsigned int nRequiredNames, const ColumnNames_t &names, const Col
// Tell whether the name `leaf` can be resolved on tree `t`.
// As written, returns true when t.GetBranch(leaf) or t.GetLeaf(leaf) is non-null,
// or when the name with its last '.' replaced by '/' is found by t.GetLeaf;
// returns false otherwise.
// NOTE(review): this listing appears to contain both the removed and the added
// lines of a diff hunk back to back -- confirm which implementation is intended.
bool IsTreeLeaf(TTree &t, const std::string &leaf)
{
// TODO understand why GetBranch is also needed (run the tests without, inspect failures)
if (t.GetBranch(leaf.c_str()) != nullptr)
return true;
if (t.GetLeaf(leaf.c_str()) != nullptr)
return true;
// Retry with the last '.' turned into '/'
auto lastDot = leaf.find_last_of('.');
if (lastDot != std::string::npos) {
std::string leafWithSlash(leaf);
leafWithSlash[lastDot] = '/';
if (t.GetLeaf(leafWithSlash.c_str()) != nullptr)
return true;
}
return false;
// NOTE(review): everything below follows an unconditional return and is
// therefore unreachable as written. It splits `leaf` at its last '/' and calls
// the two-argument GetLeaf with the two parts, or GetLeaf(nullptr, leaf) when
// there is no '/'.
const auto sep = leaf.find_last_of("/");
if (sep != std::string::npos)
return t.GetLeaf(leaf.substr(0, sep).c_str(), leaf.substr(sep+1).c_str()) != nullptr;
return t.GetLeaf(nullptr, leaf.c_str()) != nullptr;
}
ColumnNames_t FindUnknownColumns(const ColumnNames_t &requiredCols, TTree *tree, const ColumnNames_t &definedCols,
......@@ -195,6 +238,16 @@ bool IsInternalColumn(std::string_view colName)
return 0 == colName.find("tdf") && '_' == colName.back();
}
// Replace, in place, all the occurrences of a string by another string.
//   s        - string to modify
//   what     - substring to look for; if empty, `s` is left untouched
//   withWhat - text substituted for every occurrence of `what`
// Occurrences are processed left to right and the inserted text is never
// re-scanned, so `withWhat` may safely contain `what`.
void Replace(std::string &s, const std::string what, const std::string withWhat)
{
   // An empty pattern matches at every position: without this guard the loop
   // below would never terminate (and would grow `s` without bound whenever
   // `withWhat` is non-empty).
   if (what.empty())
      return;

   std::string::size_type idx = 0;
   while ((idx = s.find(what, idx)) != std::string::npos) {
      s.replace(idx, what.size(), withWhat);
      // Resume the search right after the text that was just inserted
      idx += withWhat.size();
   }
}
// Match expression against names of branches passed as parameter
// Return vector of names of the branches used in the expression
std::vector<std::string> FindUsedColumnNames(std::string_view expression, const ColumnNames_t &branches,
......@@ -219,7 +272,10 @@ std::vector<std::string> FindUsedColumnNames(std::string_view expression, const
// Check which tree branches match
for (auto &brName : branches) {
std::string bNameRegexContent = regexBit + brName + regexBit;
// Replace "." with "\." for a correct match of sub-branches/leaves
auto escapedBrName = brName;
Replace(escapedBrName, std::string("."), std::string("\\."));
std::string bNameRegexContent = regexBit + escapedBrName + regexBit;
TRegexp bNameRegex(bNameRegexContent.c_str());
if (-1 != bNameRegex.Index(paddedExpr.c_str(), &paddedExprLen)) {
usedBranches.emplace_back(brName);
......@@ -263,6 +319,9 @@ Long_t JitTransformation(void *thisPtr, std::string_view methodName, std::string
{
const auto &dsColumns = ds ? ds->GetColumnNames() : ColumnNames_t{};
auto usedBranches = FindUsedColumnNames(expression, branches, customColumns, dsColumns, aliasMap);
auto brId = 0U;
std::vector<std::string> dotlessBranches;
auto dotlessExpr = std::string(expression);
auto exprNeedsVariables = !usedBranches.empty();
// Move to the preparation of the jitting
......@@ -285,21 +344,28 @@ Long_t JitTransformation(void *thisPtr, std::string_view methodName, std::string
auto tmpBrIt = tmpBookedBranches.find(realBrName);
auto tmpBr = tmpBrIt == tmpBookedBranches.end() ? nullptr : tmpBrIt->second.get();
auto brTypeName = ColumnName2ColumnTypeName(realBrName, tree, tmpBr, ds);
dummyDecl << brTypeName << " " << brName << ";\n";
auto finalBrName = brName;
if (brName.find(".") != std::string::npos) {
// If the branch name contains dots, replace it with a temporary one
finalBrName = std::string("__tdf_arg") + std::to_string(brId++);
Replace(dotlessExpr, brName, finalBrName);
}
dummyDecl << brTypeName << " " << finalBrName << ";\n";
dotlessBranches.emplace_back(std::move(finalBrName));
usedBranchesTypes.emplace_back(brTypeName);
}
}
TRegexp re("[^a-zA-Z0-9_]return[^a-zA-Z0-9_]");
int exprSize = expression.size();
bool hasReturnStmt = re.Index(std::string(expression), &exprSize) != -1;
int exprSize = dotlessExpr.size();
bool hasReturnStmt = re.Index(dotlessExpr, &exprSize) != -1;
// Now that branches are declared as variables, put the body of the
// lambda in dummyDecl and close scopes of f and namespace __tdf_N
if (hasReturnStmt)
dummyDecl << expression << "\n;};}";
dummyDecl << dotlessExpr << "\n;};}";
else
dummyDecl << "return " << expression << "\n;};}";
dummyDecl << "return " << dotlessExpr << "\n;};}";
// Try to declare the dummy lambda, error out if it does not compile
if (!gInterpreter->Declare(dummyDecl.str().c_str())) {
......@@ -316,15 +382,15 @@ Long_t JitTransformation(void *thisPtr, std::string_view methodName, std::string
// It can't be const reference in general, as users might want/need to call non-const methods on the values
// Here we do not replace anything: the name of the parameters of the lambda does not need to be the real
// column name, and sometimes it has to be an alias to compile (e.g. "b_a" as alias for "b.a")
ss << usedBranchesTypes[i] << "& " << usedBranches[i] << ", ";
ss << usedBranchesTypes[i] << "& " << dotlessBranches[i] << ", ";
}
if (!usedBranchesTypes.empty())
ss.seekp(-2, ss.cur);
if (hasReturnStmt)
ss << "){\n" << expression << "\n}";
ss << "){\n" << dotlessExpr << "\n}";
else
ss << "){return " << expression << "\n;}";
ss << "){return " << dotlessExpr << "\n;}";
auto filterLambda = ss.str();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment