Commit 99d15ba5 authored by Kim Albertsson, committed by Lorenzo Moneta

[TMVA] root-8988 -- Validate batchsize before training

In MethodDNN one could start training with a batch size larger than the
number of events in the training set, leading to nonsensical output. TMVA now
aborts with a fatal error and suggests a fix (decrease the batch size).
parent cce50df0
@@ -659,6 +659,24 @@ void TMVA::MethodDNN::Train()
      fIPyMaxIter = 100;
   }

   for (TTrainingSettings & settings : fTrainingSettings) {
      size_t nValidationSamples = GetNumValidationSamples();
      size_t nTrainingSamples = GetEventCollection(Types::kTraining).size() - nValidationSamples;
      size_t nTestSamples = nValidationSamples;

      if (nTrainingSamples < settings.batchSize or
          nValidationSamples < settings.batchSize or
          nTestSamples < settings.batchSize) {
         Log() << kFATAL << "Number of samples in the datasets are train: "
               << nTrainingSamples << " valid: " << nValidationSamples
               << " test: " << nTestSamples << ". "
               << "One of these is smaller than the batch size of "
               << settings.batchSize << ". Please decrease the batch"
               << " size to be no larger than the smallest"
               << " of these values." << Endl;
      }
   }

   if (fArchitectureString == "GPU") {
      TrainGpu();
      if (!fExitFromTraining) fIPyMaxIter = fIPyCurrentIter;
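
For context, a minimal user-side sketch of the situation this check guards against. It assumes the standard TMVA Factory/DataLoader workflow; the macro name, variable names, trees, and the specific option values (e.g. BatchSize=256) are hypothetical and taken from typical MethodDNN usage, not from this commit. With the patch above, booking a DNN whose BatchSize exceeds the number of training or validation events makes MethodDNN::Train() stop with the kFATAL message instead of producing nonsensical output.

// Hypothetical macro: book MethodDNN with a batch size that fits the dataset.
// If BatchSize were larger than the number of training or validation events,
// Train() would now abort with the fatal error introduced in this commit.
#include "TFile.h"
#include "TTree.h"
#include "TCut.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Types.h"

void bookDnnExample(TTree *signal, TTree *background)
{
   TFile *output = TFile::Open("TMVA_DNN.root", "RECREATE");
   TMVA::Factory factory("TMVAClassification", output, "AnalysisType=Classification");
   TMVA::DataLoader loader("dataset");

   // Hypothetical variables and input trees.
   loader.AddVariable("var1");
   loader.AddVariable("var2");
   loader.AddSignalTree(signal);
   loader.AddBackgroundTree(background);
   loader.PrepareTrainingAndTestTree(TCut(""), "SplitMode=Random:NormMode=NumEvents");

   // With, say, ~1000 training events, BatchSize=256 passes the new check;
   // BatchSize=2048 would trigger the fatal error shown in the diff above.
   factory.BookMethod(&loader, TMVA::Types::kDNN, "DNN",
                      "Layout=TANH|64,TANH|32,LINEAR:"
                      "TrainingStrategy=LearningRate=1e-2,Momentum=0.9,"
                      "BatchSize=256,ConvergenceSteps=20,TestRepetitions=1:"
                      "Architecture=CPU");

   factory.TrainAllMethods();
   output->Close();
}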