00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069 #include <iomanip>
00070 #include <iostream>
00071 #include <fstream>
00072 #include <sstream>
00073 #include <cstdlib>
00074 #include <algorithm>
00075 #include <limits>
00076
00077 #include "TROOT.h"
00078 #include "TSystem.h"
00079 #include "TObjString.h"
00080 #include "TQObject.h"
00081 #include "TSpline.h"
00082 #include "TMatrix.h"
00083 #include "TMath.h"
00084 #include "TFile.h"
00085 #include "TKey.h"
00086 #include "TGraph.h"
00087 #include "Riostream.h"
00088 #include "TXMLEngine.h"
00089
00090 #include "TMVA/MsgLogger.h"
00091 #include "TMVA/MethodBase.h"
00092 #include "TMVA/Config.h"
00093 #include "TMVA/Timer.h"
00094 #include "TMVA/RootFinder.h"
00095 #include "TMVA/PDF.h"
00096 #include "TMVA/VariableIdentityTransform.h"
00097 #include "TMVA/VariableDecorrTransform.h"
00098 #include "TMVA/VariablePCATransform.h"
00099 #include "TMVA/VariableGaussTransform.h"
00100 #include "TMVA/VariableNormalizeTransform.h"
00101 #include "TMVA/Version.h"
00102 #include "TMVA/TSpline1.h"
00103 #include "TMVA/Ranking.h"
00104 #include "TMVA/Factory.h"
00105 #include "TMVA/Tools.h"
00106 #include "TMVA/ResultsClassification.h"
00107 #include "TMVA/ResultsRegression.h"
00108 #include "TMVA/ResultsMulticlass.h"
00109
ClassImp(TMVA::MethodBase)

using std::endl;

// NOTE(review): presumably an iteration cap used by the root-finder machinery — confirm at use sites
const Int_t MethodBase_MaxIterations_ = 200;
// if true, spline interpolation is used for efficiency curves (vs. simple histograms)
const Bool_t Use_Splines_for_Eff_ = kTRUE;

// histogram binning: coarse binning for plots, fine binning for efficiency/ROC evaluation
const Int_t NBIN_HIST_PLOT = 100;
const Int_t NBIN_HIST_HIGH = 10000;

#ifdef _WIN32
// disable warning C4355 ('this' used in base member initializer list) on MSVC
#pragma warning ( disable : 4355 )
#endif
00124
00125
// standard constructor: used when booking a new method for training via the Factory.
// Initializes all members to their defaults; jobName/methodTitle identify the job,
// dsi provides the dataset description, theOption carries the configuration string.
TMVA::MethodBase::MethodBase( const TString& jobName,
Types::EMVA methodType,
const TString& methodTitle,
DataSetInfo& dsi,
const TString& theOption,
TDirectory* theBaseDir) :
IMethod(),
Configurable ( theOption ),
fTmpEvent ( 0 ),
fAnalysisType ( Types::kNoAnalysisType ),
fRegressionReturnVal ( 0 ),
fMulticlassReturnVal ( 0 ),
fDisableWriting ( kFALSE ),
fDataSetInfo ( dsi ),
fSignalReferenceCut ( 0.5 ),
fVariableTransformType ( Types::kSignal ),
fJobName ( jobName ),
fMethodName ( methodTitle ),
fMethodType ( methodType ),
fTestvar ( "" ),
fTMVATrainingVersion ( TMVA_VERSION_CODE ),
fROOTTrainingVersion ( ROOT_VERSION_CODE ),
fConstructedFromWeightFile ( kFALSE ),
fBaseDir ( 0 ),
fMethodBaseDir ( theBaseDir ),
fWeightFile ( "" ),
fDefaultPDF ( 0 ),
fMVAPdfS ( 0 ),
fMVAPdfB ( 0 ),
fSplS ( 0 ),
fSplB ( 0 ),
fSpleffBvsS ( 0 ),
fSplTrainS ( 0 ),
fSplTrainB ( 0 ),
fSplTrainEffBvsS ( 0 ),
fVarTransformString ( "None" ),
fTransformation ( dsi, methodTitle ),
fVerbose ( kFALSE ),
fVerbosityLevelString ( "Default" ),
fHelp ( kFALSE ),
fHasMVAPdfs ( kFALSE ),
fIgnoreNegWeightsInTraining( kFALSE ),
fSignalClass ( 0 ),
fBackgroundClass ( 0 ),
fSplRefS ( 0 ),
fSplRefB ( 0 ),
fSplTrainRefS ( 0 ),
fSplTrainRefB ( 0 ),
fSetupCompleted (kFALSE)
{
// default name for the MVA output variable in the test tree
SetTestvarName();

// set the directory for weight files and make sure it exists on disk
SetWeightFileDir( gConfig().GetIONames().fWeightFileDir );
gSystem->MakeDirectory( GetWeightFileDir() );
}
00183
00184
// constructor used when the method is re-created from a weight file
// (testing/application only — no training): note fConstructedFromWeightFile=kTRUE
// and the training-version members set to 0 (they are read back from the file).
TMVA::MethodBase::MethodBase( Types::EMVA methodType,
DataSetInfo& dsi,
const TString& weightFile,
TDirectory* theBaseDir ) :
IMethod(),
Configurable(""),
fTmpEvent ( 0 ),
fAnalysisType ( Types::kNoAnalysisType ),
fRegressionReturnVal ( 0 ),
fMulticlassReturnVal ( 0 ),
fDataSetInfo ( dsi ),
fSignalReferenceCut ( 0.5 ),
fVariableTransformType ( Types::kSignal ),
fJobName ( "" ),
fMethodName ( "MethodBase" ),
fMethodType ( methodType ),
fTestvar ( "" ),
fTMVATrainingVersion ( 0 ),
fROOTTrainingVersion ( 0 ),
fConstructedFromWeightFile ( kTRUE ),
fBaseDir ( theBaseDir ),
fMethodBaseDir ( 0 ),
fWeightFile ( weightFile ),
fDefaultPDF ( 0 ),
fMVAPdfS ( 0 ),
fMVAPdfB ( 0 ),
fSplS ( 0 ),
fSplB ( 0 ),
fSpleffBvsS ( 0 ),
fSplTrainS ( 0 ),
fSplTrainB ( 0 ),
fSplTrainEffBvsS ( 0 ),
fVarTransformString ( "None" ),
fTransformation ( dsi, "" ),
fVerbose ( kFALSE ),
fVerbosityLevelString ( "Default" ),
fHelp ( kFALSE ),
fHasMVAPdfs ( kFALSE ),
fIgnoreNegWeightsInTraining( kFALSE ),
fSignalClass ( 0 ),
fBackgroundClass ( 0 ),
fSplRefS ( 0 ),
fSplRefB ( 0 ),
fSplTrainRefS ( 0 ),
fSplTrainRefB ( 0 ),
fSetupCompleted (kFALSE)
{
// nothing further to do here: the remaining state is restored when the
// weight file is read back
}
00235
00236
00237 TMVA::MethodBase::~MethodBase( void )
00238 {
00239
00240 if (!fSetupCompleted) Log() << kFATAL << "Calling destructor of method which got never setup" << Endl;
00241
00242
00243 if (fInputVars != 0) { fInputVars->clear(); delete fInputVars; }
00244 if (fRanking != 0) delete fRanking;
00245
00246
00247 if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
00248 if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
00249 if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
00250
00251
00252 if (fSplS) { delete fSplS; fSplS = 0; }
00253 if (fSplB) { delete fSplB; fSplB = 0; }
00254 if (fSpleffBvsS) { delete fSpleffBvsS; fSpleffBvsS = 0; }
00255 if (fSplRefS) { delete fSplRefS; fSplRefS = 0; }
00256 if (fSplRefB) { delete fSplRefB; fSplRefB = 0; }
00257 if (fSplTrainRefS) { delete fSplTrainRefS; fSplTrainRefS = 0; }
00258 if (fSplTrainRefB) { delete fSplTrainRefB; fSplTrainRefB = 0; }
00259 if (fSplTrainEffBvsS) { delete fSplTrainEffBvsS; fSplTrainEffBvsS = 0; }
00260
00261 for (Int_t i = 0; i < 2; i++ ) {
00262 if (fEventCollections.at(i)) {
00263 for (std::vector<Event*>::const_iterator it = fEventCollections.at(i)->begin();
00264 it != fEventCollections.at(i)->end(); it++) {
00265 delete (*it);
00266 }
00267 delete fEventCollections.at(i);
00268 fEventCollections.at(i) = 0;
00269 }
00270 }
00271
00272 if (fRegressionReturnVal) delete fRegressionReturnVal;
00273 if (fMulticlassReturnVal) delete fMulticlassReturnVal;
00274 }
00275
00276
00277 void TMVA::MethodBase::SetupMethod()
00278 {
00279
00280
00281 if (fSetupCompleted) Log() << kFATAL << "Calling SetupMethod for the second time" << Endl;
00282 InitBase();
00283 DeclareBaseOptions();
00284 Init();
00285 DeclareOptions();
00286 fSetupCompleted = kTRUE;
00287 }
00288
00289
00290 void TMVA::MethodBase::ProcessSetup()
00291 {
00292
00293
00294
00295 ProcessBaseOptions();
00296 ProcessOptions();
00297 }
00298
00299
00300 void TMVA::MethodBase::CheckSetup()
00301 {
00302
00303
00304 CheckForUnusedOptions();
00305 }
00306
00307
00308 void TMVA::MethodBase::InitBase()
00309 {
00310
00311 SetConfigDescription( "Configuration options for classifier architecture and tuning" );
00312
00313 fNbins = gConfig().fVariablePlotting.fNbinsXOfROCCurve;
00314 fNbinsH = NBIN_HIST_HIGH;
00315
00316 fSplTrainS = 0;
00317 fSplTrainB = 0;
00318 fSplTrainEffBvsS = 0;
00319 fMeanS = -1;
00320 fMeanB = -1;
00321 fRmsS = -1;
00322 fRmsB = -1;
00323 fXmin = DBL_MAX;
00324 fXmax = -DBL_MAX;
00325 fTxtWeightsOnly = kTRUE;
00326 fSplRefS = 0;
00327 fSplRefB = 0;
00328
00329 fTrainTime = -1.;
00330 fTestTime = -1.;
00331
00332 fRanking = 0;
00333
00334
00335 fInputVars = new std::vector<TString>;
00336 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00337 fInputVars->push_back(DataInfo().GetVariableInfo(ivar).GetLabel());
00338 }
00339 fRegressionReturnVal = 0;
00340 fMulticlassReturnVal = 0;
00341
00342 fEventCollections.resize( 2 );
00343 fEventCollections.at(0) = 0;
00344 fEventCollections.at(1) = 0;
00345
00346
00347 ResetThisBase();
00348
00349
00350 if (DataInfo().GetClassInfo("Signal") != 0) {
00351 fSignalClass = DataInfo().GetClassInfo("Signal")->GetNumber();
00352 }
00353 if (DataInfo().GetClassInfo("Background") != 0) {
00354 fBackgroundClass = DataInfo().GetClassInfo("Background")->GetNumber();
00355 }
00356
00357 SetConfigDescription( "Configuration options for MVA method" );
00358 SetConfigName( TString("Method") + GetMethodTypeName() );
00359 }
00360
00361
// declare the configuration options common to all MVA methods:
//   V                         - verbose flag (shortcut, overrides VerbosityLevel)
//   VerbosityLevel            - Default/Debug/Verbose/Info/Warning/Error/Fatal
//   VarTransform              - chain of input-variable transformations
//   H                         - print method-specific help
//   CreateMVAPdfs             - build signal/background PDFs of the MVA output
//   IgnoreNegWeightsInTraining- drop negative-weight events during training
void TMVA::MethodBase::DeclareBaseOptions()
{
DeclareOptionRef( fVerbose, "V", "Verbose output (short form of \"VerbosityLevel\" below - overrides the latter one)" );

DeclareOptionRef( fVerbosityLevelString="Default", "VerbosityLevel", "Verbosity level" );
AddPreDefVal( TString("Default") );
AddPreDefVal( TString("Debug") );
AddPreDefVal( TString("Verbose") );
AddPreDefVal( TString("Info") );
AddPreDefVal( TString("Warning") );
AddPreDefVal( TString("Error") );
AddPreDefVal( TString("Fatal") );

// these two are not options but defaults set here for all methods
fTxtWeightsOnly = kTRUE;
fNormalise = kFALSE;

DeclareOptionRef( fVarTransformString, "VarTransform", "List of variable transformations performed before training, e.g., \"D_Background,P_Signal,G,N_AllClasses\" for: \"Decorrelation, PCA-transformation, Gaussianisation, Normalisation, each for the given class of events ('AllClasses' denotes all events of all classes, if no class indication is given, 'All' is assumed)\"" );

DeclareOptionRef( fHelp, "H", "Print method-specific help message" );

DeclareOptionRef( fHasMVAPdfs, "CreateMVAPdfs", "Create PDFs for classifier outputs (signal and background)" );

DeclareOptionRef( fIgnoreNegWeightsInTraining, "IgnoreNegWeightsInTraining",
"Events with negative weights are ignored in the training (but are included for testing and performance evaluation)" );
}
00403
00404
// process the base-class options declared in DeclareBaseOptions():
// builds the MVA-output PDFs (if requested), the variable transformations,
// and configures the logger verbosity
void TMVA::MethodBase::ProcessBaseOptions()
{
if (HasMVAPdfs()) {
// the PDF options are parsed in a chain: the "default" PDF consumes the
// MVAPdf options first, then the background PDF inherits its options, and
// the signal PDF inherits from the background PDF in turn
fDefaultPDF = new PDF( TString(GetName())+"_PDF", GetOptions(), "MVAPdf" );
fDefaultPDF->DeclareOptions();
fDefaultPDF->ParseOptions();
fDefaultPDF->ProcessOptions();
fMVAPdfB = new PDF( TString(GetName())+"_PDFBkg", fDefaultPDF->GetOptions(), "MVAPdfBkg", fDefaultPDF );
fMVAPdfB->DeclareOptions();
fMVAPdfB->ParseOptions();
fMVAPdfB->ProcessOptions();
fMVAPdfS = new PDF( TString(GetName())+"_PDFSig", fMVAPdfB->GetOptions(), "MVAPdfSig", fDefaultPDF );
fMVAPdfS->DeclareOptions();
fMVAPdfS->ParseOptions();
fMVAPdfS->ProcessOptions();

// the option string remaining after the PDF chain becomes the method's options
SetOptions( fMVAPdfS->GetOptions() );
}

CreateVariableTransforms( fVarTransformString );

// if PDFs were not requested, drop any previously created ones
if (!HasMVAPdfs()) {
if (fDefaultPDF!= 0) { delete fDefaultPDF; fDefaultPDF = 0; }
if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
}

// "V" shortcut wins over the explicit VerbosityLevel setting
if (fVerbose) {
fVerbosityLevelString = TString("Verbose");
Log().SetMinType( kVERBOSE );
}
else if (fVerbosityLevelString == "Debug" ) Log().SetMinType( kDEBUG );
else if (fVerbosityLevelString == "Verbose" ) Log().SetMinType( kVERBOSE );
else if (fVerbosityLevelString == "Info" ) Log().SetMinType( kINFO );
else if (fVerbosityLevelString == "Warning" ) Log().SetMinType( kWARNING );
else if (fVerbosityLevelString == "Error" ) Log().SetMinType( kERROR );
else if (fVerbosityLevelString == "Fatal" ) Log().SetMinType( kFATAL );
else if (fVerbosityLevelString != "Default" ) {
Log() << kFATAL << "<ProcessOptions> Verbosity level type '"
<< fVerbosityLevelString << "' unknown." << Endl;
}
}
00453
00454
00455 void TMVA::MethodBase::CreateVariableTransforms(const TString& trafoDefinition )
00456 {
00457 if (trafoDefinition != "None") {
00458 TList* trList = gTools().ParseFormatLine( trafoDefinition, "," );
00459 TListIter trIt(trList);
00460 while (TObjString* os = (TObjString*)trIt()) {
00461 Int_t idxCls = -1;
00462
00463 TList* trClsList = gTools().ParseFormatLine( os->GetString(), "_" );
00464 TListIter trClsIt(trClsList);
00465 const TString& trName = ((TObjString*)trClsList->At(0))->GetString();
00466
00467 if (trClsList->GetEntries() > 1) {
00468 TString trCls = "AllClasses";
00469 ClassInfo *ci = NULL;
00470 trCls = ((TObjString*)trClsList->At(1))->GetString();
00471 if (trCls != "AllClasses") {
00472 ci = DataInfo().GetClassInfo( trCls );
00473 if (ci == NULL)
00474 Log() << kFATAL << "Class " << trCls << " not known for variable transformation "
00475 << trName << ", please check." << Endl;
00476 else
00477 idxCls = ci->GetNumber();
00478 }
00479 }
00480
00481 if (trName == "D" || trName == "Deco" || trName == "Decorrelate")
00482 GetTransformationHandler().AddTransformation( new VariableDecorrTransform ( DataInfo()) , idxCls );
00483 else if (trName == "P" || trName == "PCA")
00484 GetTransformationHandler().AddTransformation( new VariablePCATransform ( DataInfo()), idxCls );
00485 else if (trName == "U" || trName == "Uniform")
00486 GetTransformationHandler().AddTransformation( new VariableGaussTransform ( DataInfo(),"Uniform"), idxCls );
00487 else if (trName == "G" || trName == "Gauss")
00488 GetTransformationHandler().AddTransformation( new VariableGaussTransform ( DataInfo()), idxCls );
00489 else if (trName == "N" || trName == "Norm" || trName == "Normalise" || trName == "Normalize")
00490 GetTransformationHandler().AddTransformation( new VariableNormalizeTransform( DataInfo()), idxCls );
00491 else
00492 Log() << kFATAL << "<ProcessOptions> Variable transform '"
00493 << trName << "' unknown." << Endl;
00494 ClassInfo* clsInfo = DataInfo().GetClassInfo(idxCls);
00495 if( clsInfo )
00496 Log() << kINFO << " create Transformation " << trName << " with reference class " <<clsInfo->GetName() << "=("<< idxCls <<")"<<Endl;
00497 else
00498 Log() << kINFO << " create Transformation " << trName << " with events of all classes." << Endl;
00499
00500 }
00501 }
00502 }
00503
00504
// declare obsolete options kept only so that weight files written by older
// TMVA versions can still be parsed; the values are read but mostly unused
void TMVA::MethodBase::DeclareCompatibilityOptions()
{
DeclareOptionRef( fNormalise=kFALSE, "Normalise", "Normalise input variables" );
DeclareOptionRef( fUseDecorr=kFALSE, "D", "Use-decorrelated-variables flag" );
DeclareOptionRef( fVariableTransformTypeString="Signal", "VarTransformType",
"Use signal or background events to derive for variable transformation (the transformation is applied on both types of, course)" );
AddPreDefVal( TString("Signal") );
AddPreDefVal( TString("Background") );
DeclareOptionRef( fTxtWeightsOnly=kTRUE, "TxtWeightFilesOnly", "If True: write all training results (weights) as text files (False: some are written in ROOT format)" );
// note: the legacy name is "VerboseLevel" (vs. today's "VerbosityLevel")
DeclareOptionRef( fVerbosityLevelString="Default", "VerboseLevel", "Verbosity level" );
AddPreDefVal( TString("Default") );
AddPreDefVal( TString("Debug") );
AddPreDefVal( TString("Verbose") );
AddPreDefVal( TString("Info") );
AddPreDefVal( TString("Warning") );
AddPreDefVal( TString("Error") );
AddPreDefVal( TString("Fatal") );
DeclareOptionRef( fNbinsMVAPdf = 60, "NbinsMVAPdf", "Number of bins used for the PDFs of classifier outputs" );
DeclareOptionRef( fNsmoothMVAPdf = 2, "NsmoothMVAPdf", "Number of smoothing iterations for classifier PDFs" );
}
00525
00526
00527
00528 std::map<TString,Double_t> TMVA::MethodBase::OptimizeTuningParameters(TString , TString )
00529 {
00530
00531
00532
00533
00534
00535
00536
00537 Log() << kWARNING << "Parameter optimization is not yet implemented for method "
00538 << GetName() << Endl;
00539 Log() << kWARNING << "Currently we need to set hardcoded which parameter is tuned in which ranges"<<Endl;
00540
00541 std::map<TString,Double_t> tunedParameters;
00542 tunedParameters.size();
00543 return tunedParameters;
00544
00545 }
00546
00547
// interface for setting tuned parameters in the concrete method after
// OptimizeTuningParameters(); the base-class implementation deliberately
// does nothing
void TMVA::MethodBase::SetTuneParameters(std::map<TString,Double_t> )
{
}
00554
00555
// orchestrate the training of the method: compute the variable transformations,
// call the concrete Train(), then produce the training-sample MVA outputs and
// (unless disabled) write the weight file, standalone class and monitoring histograms
void TMVA::MethodBase::TrainMethod()
{
Data()->SetCurrentType(Types::kTraining);

// print help message if requested via the "H" option
if (Help()) PrintHelpMessage();

// all histograms should be created in the method's ROOT directory
BaseDir()->cd();

GetTransformationHandler().CalcTransformations(Data()->GetEventCollection());

// train the MVA method (implemented by the concrete class) and time it
Log() << kINFO << "Begin training" << Endl;
Long64_t nEvents = Data()->GetNEvents();
Timer traintimer( nEvents, GetName(), kTRUE );
Train();
Log() << kINFO << "End of training " << Endl;
SetTrainTime(traintimer.ElapsedSeconds());
Log() << kINFO << "Elapsed time for training with " << nEvents << " events: "
<< traintimer.GetElapsedTime() << " " << Endl;

Log() << kINFO << "Create MVA output for ";

// evaluate the trained method on the training sample, per analysis type
if (DoMulticlass()){
Log() << "Multiclass classification on training sample" << Endl;
AddMulticlassOutput(Types::kTraining);
}
else if (!DoRegression()) {

Log() << "classification on training sample" << Endl;
AddClassifierOutput(Types::kTraining);
if (HasMVAPdfs()) {
CreateMVAPdfs();
AddClassifierOutputProb(Types::kTraining);
}

} else {

Log() << "regression on training sample" << Endl;
AddRegressionOutput( Types::kTraining );

if (HasMVAPdfs() ) {
Log() << "Create PDFs" << Endl;
CreateMVAPdfs();
}
}

// write the method's state (weights) to file unless writing is disabled
if( !fDisableWriting ) WriteStateToFile();

// write a standalone C++ response class (classification only)
if ((!DoRegression()) && (!fDisableWriting)) MakeClass();

BaseDir()->cd();
WriteMonitoringHistosToFile();
}
00617
00618
00619 void TMVA::MethodBase::GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t& stddev, Double_t& stddev90Percent ) const
00620 {
00621 if (!DoRegression()) Log() << kFATAL << "Trying to use GetRegressionDeviation() with a classification job" << Endl;
00622 Log() << kINFO << "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
00623 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), Types::kTesting, Types::kRegression);
00624 bool truncate = false;
00625 TH1F* h1 = regRes->QuadraticDeviation( tgtNum , truncate, 1.);
00626 stddev = sqrt(h1->GetMean());
00627 truncate = true;
00628 Double_t yq[1], xq[]={0.9};
00629 h1->GetQuantiles(1,yq,xq);
00630 TH1F* h2 = regRes->QuadraticDeviation( tgtNum , truncate, yq[0]);
00631 stddev90Percent = sqrt(h2->GetMean());
00632 delete h1;
00633 delete h2;
00634 }
00635
00636
00637 void TMVA::MethodBase::AddRegressionOutput(Types::ETreeType type)
00638 {
00639
00640
00641 Data()->SetCurrentType(type);
00642
00643 Log() << kINFO << "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
00644
00645 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), type, Types::kRegression);
00646
00647 Long64_t nEvents = Data()->GetNEvents();
00648
00649
00650 Timer timer( nEvents, GetName(), kTRUE );
00651
00652 Log() << kINFO << "Evaluation of " << GetMethodName() << " on "
00653 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
00654
00655 regRes->Resize( nEvents );
00656 for (Int_t ievt=0; ievt<nEvents; ievt++) {
00657 Data()->SetCurrentEvent(ievt);
00658 std::vector< Float_t > vals = GetRegressionValues();
00659 regRes->SetValue( vals, ievt );
00660 timer.DrawProgressBar( ievt );
00661 }
00662
00663 Log() << kINFO << "Elapsed time for evaluation of " << nEvents << " events: "
00664 << timer.GetElapsedTime() << " " << Endl;
00665
00666
00667 if (type==Types::kTesting)
00668 SetTestTime(timer.ElapsedSeconds());
00669
00670 TString histNamePrefix(GetTestvarName());
00671 histNamePrefix += (type==Types::kTraining?"train":"test");
00672 regRes->CreateDeviationHistograms( histNamePrefix );
00673 }
00674
00675
00676 void TMVA::MethodBase::AddMulticlassOutput(Types::ETreeType type)
00677 {
00678
00679
00680 Data()->SetCurrentType(type);
00681
00682 Log() << kINFO << "Create results for " << (type==Types::kTraining?"training":"testing") << Endl;
00683
00684 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
00685 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in AddMulticlassOutput, exiting."<<Endl;
00686
00687 Long64_t nEvents = Data()->GetNEvents();
00688
00689
00690 Timer timer( nEvents, GetName(), kTRUE );
00691
00692 Log() << kINFO << "Multiclass evaluation of " << GetMethodName() << " on "
00693 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
00694
00695 resMulticlass->Resize( nEvents );
00696 for (Int_t ievt=0; ievt<nEvents; ievt++) {
00697 Data()->SetCurrentEvent(ievt);
00698 std::vector< Float_t > vals = GetMulticlassValues();
00699 resMulticlass->SetValue( vals, ievt );
00700 timer.DrawProgressBar( ievt );
00701 }
00702
00703 Log() << kINFO << "Elapsed time for evaluation of " << nEvents << " events: "
00704 << timer.GetElapsedTime() << " " << Endl;
00705
00706
00707 if (type==Types::kTesting)
00708 SetTestTime(timer.ElapsedSeconds());
00709
00710 TString histNamePrefix(GetTestvarName());
00711 histNamePrefix += (type==Types::kTraining?"_Train":"_Test");
00712 resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbins, fNbinsH );
00713 }
00714
00715
00716
00717
00718 void TMVA::MethodBase::NoErrorCalc(Double_t* const err, Double_t* const errUpper) {
00719 if(err) *err=-1;
00720 if(errUpper) *errUpper=-1;
00721 }
00722
00723
00724 Double_t TMVA::MethodBase::GetMvaValue( const Event* const ev, Double_t* err, Double_t* errUpper ) {
00725 fTmpEvent = ev;
00726 Double_t val = GetMvaValue(err, errUpper);
00727 fTmpEvent = 0;
00728 return val;
00729 }
00730
00731
00732 void TMVA::MethodBase::AddClassifierOutput( Types::ETreeType type )
00733 {
00734
00735
00736 Data()->SetCurrentType(type);
00737
00738 ResultsClassification* clRes =
00739 (ResultsClassification*)Data()->GetResults(GetMethodName(), type, Types::kClassification );
00740
00741 Long64_t nEvents = Data()->GetNEvents();
00742
00743
00744 Timer timer( nEvents, GetName(), kTRUE );
00745
00746 Log() << kINFO << "Evaluation of " << GetMethodName() << " on "
00747 << (type==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
00748
00749 clRes->Resize( nEvents );
00750 for (Int_t ievt=0; ievt<nEvents; ievt++) {
00751
00752 SetCurrentEvent(ievt);
00753 clRes->SetValue( GetMvaValue(), ievt );
00754
00755
00756 Int_t modulo = Int_t(nEvents/100);
00757 if( modulo <= 0 ) modulo = 1;
00758 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
00759 }
00760
00761 Log() << kINFO << "Elapsed time for evaluation of " << nEvents << " events: "
00762 << timer.GetElapsedTime() << " " << Endl;
00763
00764
00765 if (type==Types::kTesting)
00766 SetTestTime(timer.ElapsedSeconds());
00767
00768 }
00769
00770
00771 void TMVA::MethodBase::AddClassifierOutputProb( Types::ETreeType type )
00772 {
00773
00774
00775 Data()->SetCurrentType(type);
00776
00777 ResultsClassification* mvaProb =
00778 (ResultsClassification*)Data()->GetResults(TString("prob_")+GetMethodName(), type, Types::kClassification );
00779
00780 Long64_t nEvents = Data()->GetNEvents();
00781
00782
00783 Timer timer( nEvents, GetName(), kTRUE );
00784
00785 Log() << kINFO << "Evaluation of " << GetMethodName() << " on "
00786 << (type==Types::kTraining?"training":"testing") << " sample" << Endl;
00787
00788 mvaProb->Resize( nEvents );
00789 for (Int_t ievt=0; ievt<nEvents; ievt++) {
00790
00791 Data()->SetCurrentEvent(ievt);
00792 Float_t proba = ((Float_t)GetProba( GetMvaValue(), 0.5 ));
00793 if (proba < 0) break;
00794 mvaProb->SetValue( proba, ievt );
00795
00796
00797 Int_t modulo = Int_t(nEvents/100);
00798 if( modulo <= 0 ) modulo = 1;
00799 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
00800 }
00801
00802 Log() << kINFO << "Elapsed time for evaluation of " << nEvents << " events: "
00803 << timer.GetElapsedTime() << " " << Endl;
00804 }
00805
00806
00807 void TMVA::MethodBase::TestRegression( Double_t& bias, Double_t& biasT,
00808 Double_t& dev, Double_t& devT,
00809 Double_t& rms, Double_t& rmsT,
00810 Double_t& mInf, Double_t& mInfT,
00811 Double_t& corr,
00812 Types::ETreeType type )
00813 {
00814
00815
00816
00817
00818
00819
00820
00821 Types::ETreeType savedType = Data()->GetCurrentType();
00822 Data()->SetCurrentType(type);
00823
00824 bias = 0; biasT = 0; dev = 0; devT = 0; rms = 0; rmsT = 0;
00825 Double_t sumw = 0;
00826 Double_t m1 = 0, m2 = 0, s1 = 0, s2 = 0, s12 = 0;
00827 const Int_t nevt = GetNEvents();
00828 Float_t* rV = new Float_t[nevt];
00829 Float_t* tV = new Float_t[nevt];
00830 Float_t* wV = new Float_t[nevt];
00831 Float_t xmin = 1e30, xmax = -1e30;
00832 for (Long64_t ievt=0; ievt<nevt; ievt++) {
00833
00834 const Event* ev = Data()->GetEvent(ievt);
00835 Float_t t = ev->GetTarget(0);
00836 Float_t w = ev->GetWeight();
00837 Float_t r = GetRegressionValues()[0];
00838 Float_t d = (r-t);
00839
00840
00841 xmin = TMath::Min(xmin, TMath::Min(t, r));
00842 xmax = TMath::Max(xmax, TMath::Max(t, r));
00843
00844
00845 rV[ievt] = r;
00846 tV[ievt] = t;
00847 wV[ievt] = w;
00848
00849
00850 sumw += w;
00851 bias += w * d;
00852 dev += w * TMath::Abs(d);
00853 rms += w * d * d;
00854
00855
00856 m1 += t*w; s1 += t*t*w;
00857 m2 += r*w; s2 += r*r*w;
00858 s12 += t*r;
00859 }
00860
00861
00862 bias /= sumw;
00863 dev /= sumw;
00864 rms /= sumw;
00865 rms = TMath::Sqrt(rms - bias*bias);
00866
00867
00868 m1 /= sumw;
00869 m2 /= sumw;
00870 corr = s12/sumw - m1*m2;
00871 corr /= TMath::Sqrt( (s1/sumw - m1*m1) * (s2/sumw - m2*m2) );
00872
00873
00874 TH2F* hist = new TH2F( "hist", "hist", 150, xmin, xmax, 100, xmin, xmax );
00875 TH2F* histT = new TH2F( "histT", "histT", 150, xmin, xmax, 100, xmin, xmax );
00876
00877
00878 Double_t devMax = bias + 2*rms;
00879 Double_t devMin = bias - 2*rms;
00880 sumw = 0;
00881 int ic=0;
00882 for (Long64_t ievt=0; ievt<nevt; ievt++) {
00883 Float_t d = (rV[ievt] - tV[ievt]);
00884 hist->Fill( rV[ievt], tV[ievt], wV[ievt] );
00885 if (d >= devMin && d <= devMax) {
00886 sumw += wV[ievt];
00887 biasT += wV[ievt] * d;
00888 devT += wV[ievt] * TMath::Abs(d);
00889 rmsT += wV[ievt] * d * d;
00890 histT->Fill( rV[ievt], tV[ievt], wV[ievt] );
00891 ic++;
00892 }
00893 }
00894 biasT /= sumw;
00895 devT /= sumw;
00896 rmsT /= sumw;
00897 rmsT = TMath::Sqrt(rmsT - biasT*biasT);
00898 mInf = gTools().GetMutualInformation( *hist );
00899 mInfT = gTools().GetMutualInformation( *histT );
00900
00901 delete hist;
00902 delete histT;
00903
00904 delete [] rV;
00905 delete [] tV;
00906 delete [] wV;
00907
00908 Data()->SetCurrentType(savedType);
00909 }
00910
00911
00912
00913 void TMVA::MethodBase::TestMulticlass()
00914 {
00915
00916
00917 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
00918 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in TestMulticlass, exiting."<<Endl;
00919 Log() << kINFO << "Determine optimal multiclass cuts for test data..." << Endl;
00920 for(UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls){
00921 resMulticlass->GetBestMultiClassCuts(icls);
00922 }
00923 }
00924
00925
00926
// evaluate the classifier on the test sample: compute basic statistics of the
// MVA output, fill signal/background response histograms (plus optional
// probability/rarity histograms), and build the response PDFs fSplS/fSplB
void TMVA::MethodBase::TestClassification()
{
Data()->SetCurrentType(Types::kTesting);

ResultsClassification* mvaRes = dynamic_cast<ResultsClassification*>
( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );

// sanity check: classification results must exist (Cuts methods are exempt)
if (0==mvaRes && !(GetMethodTypeName().Contains("Cuts"))) {
Log() << "mvaRes " << mvaRes << " GetMethodTypeName " << GetMethodTypeName()
<< " contains " << !(GetMethodTypeName().Contains("Cuts")) << Endl;
Log() << kFATAL << "<TestInit> Test variable " << GetTestvarName()
<< " not found in tree" << Endl;
}

// mean/RMS/min/max of the MVA output, separately for signal and background
gTools().ComputeStat( GetEventCollection(Types::kTesting), mvaRes->GetValueVector(),
fMeanS, fMeanB, fRmsS, fRmsB, fXmin, fXmax, fSignalClass );

// clip the histogram range to mean +- nrms*RMS to suppress extreme outliers
Double_t nrms = 10;
fXmin = TMath::Max( TMath::Min( fMeanS - nrms*fRmsS, fMeanB - nrms*fRmsB ), fXmin );
fXmax = TMath::Min( TMath::Max( fMeanS + nrms*fRmsS, fMeanB + nrms*fRmsB ), fXmax );

// cut orientation: positive if signal lies above background
fCutOrientation = (fMeanS > fMeanB) ? kPositive : kNegative;

// slight upper-edge offset so the maximum value falls inside the last bin
Double_t sxmax = fXmax+0.00001;

// coarse-binned response histograms used for plotting
TH1* mva_s = new TH1F( GetTestvarName() + "_S",GetTestvarName() + "_S", fNbins, fXmin, sxmax );
TH1* mva_b = new TH1F( GetTestvarName() + "_B",GetTestvarName() + "_B", fNbins, fXmin, sxmax );
mvaRes->Store(mva_s, "MVA_S");
mvaRes->Store(mva_b, "MVA_B");
mva_s->Sumw2();
mva_b->Sumw2();

// probability and rarity histograms are only filled if PDFs were requested
TH1* proba_s = 0;
TH1* proba_b = 0;
TH1* rarity_s = 0;
TH1* rarity_b = 0;
if (HasMVAPdfs()) {
proba_s = new TH1F( GetTestvarName() + "_Proba_S", GetTestvarName() + "_Proba_S", fNbins, 0.0, 1.0 );
proba_b = new TH1F( GetTestvarName() + "_Proba_B", GetTestvarName() + "_Proba_B", fNbins, 0.0, 1.0 );
mvaRes->Store(proba_s, "Prob_S");
mvaRes->Store(proba_b, "Prob_B");
proba_s->Sumw2();
proba_b->Sumw2();

rarity_s = new TH1F( GetTestvarName() + "_Rarity_S", GetTestvarName() + "_Rarity_S", fNbins, 0.0, 1.0 );
rarity_b = new TH1F( GetTestvarName() + "_Rarity_B", GetTestvarName() + "_Rarity_B", fNbins, 0.0, 1.0 );
mvaRes->Store(rarity_s, "Rar_S");
mvaRes->Store(rarity_b, "Rar_B");
rarity_s->Sumw2();
rarity_b->Sumw2();
}

// fine-binned response histograms used for efficiency/ROC computations
TH1* mva_eff_s = new TH1F( GetTestvarName() + "_S_high", GetTestvarName() + "_S_high", fNbinsH, fXmin, sxmax );
TH1* mva_eff_b = new TH1F( GetTestvarName() + "_B_high", GetTestvarName() + "_B_high", fNbinsH, fXmin, sxmax );
mvaRes->Store(mva_eff_s, "MVA_HIGHBIN_S");
mvaRes->Store(mva_eff_b, "MVA_HIGHBIN_B");
mva_eff_s->Sumw2();
mva_eff_b->Sumw2();

// per-event signal probabilities, if previously computed
ResultsClassification* mvaProb = dynamic_cast<ResultsClassification*>
(Data()->GetResults( TString("prob_")+GetMethodName(), Types::kTesting, Types::kMaxAnalysisType ) );

Log() << kINFO << "Loop over test events and fill histograms with classifier response..." << Endl;
if (mvaProb) Log() << kINFO << "Also filling probability and rarity histograms (on request)..." << Endl;
for (Long64_t ievt=0; ievt<GetNEvents(); ievt++) {

const Event* ev = GetEvent(ievt);
Float_t v = (*mvaRes)[ievt][0];
Float_t w = ev->GetWeight();

// fill signal or background histograms according to the event's true class
if (DataInfo().IsSignal(ev)) {
mva_s ->Fill( v, w );
if (mvaProb) {
proba_s->Fill( (*mvaProb)[ievt][0], w );
rarity_s->Fill( GetRarity( v ), w );
}

mva_eff_s ->Fill( v, w );
}
else {
mva_b ->Fill( v, w );
if (mvaProb) {
proba_b->Fill( (*mvaProb)[ievt][0], w );
rarity_b->Fill( GetRarity( v ), w );
}
mva_eff_b ->Fill( v, w );
}
}

// normalise all distributions (NormHist tolerates null pointers)
gTools().NormHist( mva_s );
gTools().NormHist( mva_b );
gTools().NormHist( proba_s );
gTools().NormHist( proba_b );
gTools().NormHist( rarity_s );
gTools().NormHist( rarity_b );
gTools().NormHist( mva_eff_s );
gTools().NormHist( mva_eff_b );

// (re)build the signal/background PDFs of the classifier response
if (fSplS) { delete fSplS; fSplS = 0; }
if (fSplB) { delete fSplB; fSplB = 0; }
fSplS = new PDF( TString(GetName()) + " PDF Sig", mva_s, PDF::kSpline2 );
fSplB = new PDF( TString(GetName()) + " PDF Bkg", mva_b, PDF::kSpline2 );
}
01045
01046
01047 void TMVA::MethodBase::WriteStateToStream( std::ostream& tf ) const
01048 {
01049
01050
01051
01052 TString prefix = "";
01053 UserGroup_t * userInfo = gSystem->GetUserInfo();
01054
01055 tf << prefix << "#GEN -*-*-*-*-*-*-*-*-*-*-*- general info -*-*-*-*-*-*-*-*-*-*-*-" << endl << prefix << endl;
01056 tf << prefix << "Method : " << GetMethodTypeName() << "::" << GetMethodName() << endl;
01057 tf.setf(std::ios::left);
01058 tf << prefix << "TMVA Release : " << std::setw(10) << GetTrainingTMVAVersionString() << " ["
01059 << GetTrainingTMVAVersionCode() << "]" << endl;
01060 tf << prefix << "ROOT Release : " << std::setw(10) << GetTrainingROOTVersionString() << " ["
01061 << GetTrainingROOTVersionCode() << "]" << endl;
01062 tf << prefix << "Creator : " << userInfo->fUser << endl;
01063 tf << prefix << "Date : "; TDatime *d = new TDatime; tf << d->AsString() << endl; delete d;
01064 tf << prefix << "Host : " << gSystem->GetBuildNode() << endl;
01065 tf << prefix << "Dir : " << gSystem->WorkingDirectory() << endl;
01066 tf << prefix << "Training events: " << Data()->GetNTrainingEvents() << endl;
01067
01068 TString analysisType(((const_cast<TMVA::MethodBase*>(this)->GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification"));
01069
01070 tf << prefix << "Analysis type : " << "[" << ((GetAnalysisType()==Types::kRegression) ? "Regression" : "Classification") << "]" << endl;
01071 tf << prefix << endl;
01072
01073 delete userInfo;
01074
01075
01076 tf << prefix << endl << prefix << "#OPT -*-*-*-*-*-*-*-*-*-*-*-*- options -*-*-*-*-*-*-*-*-*-*-*-*-" << endl << prefix << endl;
01077 WriteOptionsToStream( tf, prefix );
01078 tf << prefix << endl;
01079
01080
01081 tf << prefix << endl << prefix << "#VAR -*-*-*-*-*-*-*-*-*-*-*-* variables *-*-*-*-*-*-*-*-*-*-*-*-" << endl << prefix << endl;
01082 WriteVarsToStream( tf, prefix );
01083 tf << prefix << endl;
01084 }
01085
01086
01087 void TMVA::MethodBase::AddInfoItem( void* gi, const TString& name, const TString& value) const
01088 {
01089
01090 void* it = gTools().AddChild(gi,"Info");
01091 gTools().AddAttr(it,"name", name);
01092 gTools().AddAttr(it,"value", value);
01093 }
01094
01095
01096 void TMVA::MethodBase::AddOutput( Types::ETreeType type, Types::EAnalysisType analysisType ) {
01097 if (analysisType == Types::kRegression) {
01098 AddRegressionOutput( type );
01099 } else if (analysisType == Types::kMulticlass ){
01100 AddMulticlassOutput( type );
01101 } else {
01102 AddClassifierOutput( type );
01103 if (HasMVAPdfs())
01104 AddClassifierOutputProb( type );
01105 }
01106 }
01107
01108
//_______________________________________________________________________
void TMVA::MethodBase::WriteStateToXML( void* parent ) const
{
   // general method used in writing the header of the weight files where
   // the used variables, variable transformation type etc. is specified
   // (XML counterpart of WriteStateToStream)

   if (!parent) return;

   UserGroup_t* userInfo = gSystem->GetUserInfo();

   // --- general info: versions, creator, date, host, training statistics
   void* gi = gTools().AddChild(parent, "GeneralInfo");
   AddInfoItem( gi, "TMVA Release", GetTrainingTMVAVersionString() + " [" + gTools().StringFromInt(GetTrainingTMVAVersionCode()) + "]" );
   AddInfoItem( gi, "ROOT Release", GetTrainingROOTVersionString() + " [" + gTools().StringFromInt(GetTrainingROOTVersionCode()) + "]");
   AddInfoItem( gi, "Creator", userInfo->fUser);
   TDatime dt; AddInfoItem( gi, "Date", dt.AsString());
   AddInfoItem( gi, "Host", gSystem->GetBuildNode() );
   AddInfoItem( gi, "Dir", gSystem->WorkingDirectory());
   AddInfoItem( gi, "Training events", gTools().StringFromInt(Data()->GetNTrainingEvents()));
   AddInfoItem( gi, "TrainingTime", gTools().StringFromDouble(const_cast<TMVA::MethodBase*>(this)->GetTrainTime()));

   // const_cast needed because GetAnalysisType() is non-const
   Types::EAnalysisType aType = const_cast<TMVA::MethodBase*>(this)->GetAnalysisType();
   TString analysisType((aType==Types::kRegression) ? "Regression" :
                        (aType==Types::kMulticlass ? "Multiclass" : "Classification"));
   AddInfoItem( gi, "AnalysisType", analysisType );
   delete userInfo;

   // --- options
   AddOptionsXMLTo( parent );

   // --- variables
   AddVarsXMLTo( parent );

   // --- spectators (skipped when writing is disabled)
   if(!fDisableWriting)
      AddSpectatorsXMLTo( parent );

   // --- classes (multiclass only)
   if(DoMulticlass())
      AddClassesXMLTo(parent);

   // --- targets (regression only)
   if(DoRegression())
      AddTargetsXMLTo(parent);

   // --- variable transformations
   GetTransformationHandler().AddXMLTo( parent );

   // --- MVA PDFs: the "MVAPdfs" node is written even when no PDF exists
   void* pdfs = gTools().AddChild(parent, "MVAPdfs");
   if (fMVAPdfS) fMVAPdfS->AddXMLTo(pdfs);
   if (fMVAPdfB) fMVAPdfB->AddXMLTo(pdfs);

   // --- method-specific weights (implemented by the derived method)
   AddWeightsXMLTo( parent );
}
01163
01164
01165 void TMVA::MethodBase::ReadStateFromStream( TFile& rf )
01166 {
01167
01168
01169
01170 Bool_t addDirStatus = TH1::AddDirectoryStatus();
01171 TH1::AddDirectory( 0 );
01172 fMVAPdfS = (TMVA::PDF*)rf.Get( "MVA_PDF_Signal" );
01173 fMVAPdfB = (TMVA::PDF*)rf.Get( "MVA_PDF_Background" );
01174
01175 TH1::AddDirectory( addDirStatus );
01176
01177 ReadWeightsFromStream( rf );
01178
01179 SetTestvarName();
01180 }
01181
01182
//_______________________________________________________________________
void TMVA::MethodBase::WriteStateToFile() const
{
   // write the complete method state (options, variables, weights, ...)
   // to the XML weight file

   // the weight file name (and its directory) is defined in GetWeightFileName()
   TString tfname( GetWeightFileName() );

   // force an .xml extension even if the configured name ends in .txt
   TString xmlfname( tfname ); xmlfname.ReplaceAll( ".txt", ".xml" );
   Log() << kINFO << "Creating weight file in xml format: "
         << gTools().Color("lightblue") << xmlfname << gTools().Color("reset") << Endl;
   // build the XML document: root node "MethodSetup" carries the
   // "<Type>::<Name>" identifier; WriteStateToXML fills in the rest
   void* doc = gTools().xmlengine().NewDoc();
   void* rootnode = gTools().AddChild(0,"MethodSetup", "", true);
   gTools().xmlengine().DocSetRootElement(doc,rootnode);
   gTools().AddAttr(rootnode,"Method", GetMethodTypeName() + "::" + GetMethodName());
   WriteStateToXML(rootnode);
   gTools().xmlengine().SaveDoc(doc,xmlfname);
   gTools().xmlengine().FreeDoc(doc);
}
01204
01205
01206 void TMVA::MethodBase::ReadStateFromFile()
01207 {
01208
01209
01210
01211
01212 TString tfname(GetWeightFileName());
01213
01214 Log() << kINFO << "Reading weight file: "
01215 << gTools().Color("lightblue") << tfname << gTools().Color("reset") << Endl;
01216
01217 if (tfname.EndsWith(".xml") ) {
01218 void* doc = gTools().xmlengine().ParseFile(tfname);
01219 void* rootnode = gTools().xmlengine().DocGetRootElement(doc);
01220 ReadStateFromXML(rootnode);
01221 gTools().xmlengine().FreeDoc(doc);
01222 }
01223 else {
01224 filebuf fb;
01225 fb.open(tfname.Data(),ios::in);
01226 if (!fb.is_open()) {
01227 Log() << kFATAL << "<ReadStateFromFile> "
01228 << "Unable to open input weight file: " << tfname << Endl;
01229 }
01230 istream fin(&fb);
01231 ReadStateFromStream(fin);
01232 fb.close();
01233 }
01234 if (!fTxtWeightsOnly) {
01235
01236 TString rfname( tfname ); rfname.ReplaceAll( ".txt", ".root" );
01237 Log() << kINFO << "Reading root weight file: "
01238 << gTools().Color("lightblue") << rfname << gTools().Color("reset") << Endl;
01239 TFile* rfile = TFile::Open( rfname, "READ" );
01240 ReadStateFromStream( *rfile );
01241 rfile->Close();
01242 }
01243 }
01244
//_______________________________________________________________________
void TMVA::MethodBase::ReadStateFromXMLString( const char* xmlstr ) {
   // read the method state from an in-memory XML string instead of a file
   // (requires TXMLEngine::ParseString, available from ROOT 5.26/00 on)

#if (ROOT_SVN_REVISION >= 32259) && (ROOT_VERSION_CODE >= 334336) // 5.26/00
   void* doc = gTools().xmlengine().ParseString(xmlstr);
   void* rootnode = gTools().xmlengine().DocGetRootElement(doc);
   ReadStateFromXML(rootnode);
   gTools().xmlengine().FreeDoc(doc);
#else
   Log() << kFATAL << "Method MethodBase::ReadStateFromXMLString( const char* xmlstr = "
         << xmlstr << " ) is not available for ROOT versions prior to 5.26/00." << Endl;
#endif

   return;
}
01260
01261
//_______________________________________________________________________
void TMVA::MethodBase::ReadStateFromXML( void* methodNode )
{
   // read the complete method state (general info, options, variables,
   // transformations, PDFs and weights) from the "MethodSetup" XML node

   TString fullMethodName;
   gTools().ReadAttr( methodNode, "Method", fullMethodName );
   // the attribute has the form "<MethodTypeName>::<MethodName>";
   // keep only the instance name
   fMethodName = fullMethodName(fullMethodName.Index("::")+2,fullMethodName.Length());

   // update logger source now that the method name is known
   Log().SetSource( GetName() );
   Log() << kINFO << "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;

   // the test-variable name depends on the method name
   SetTestvarName();

   // loop over the direct children of the method node and dispatch by tag
   TString nodeName("");
   void* ch = gTools().GetChild(methodNode);
   while (ch!=0) {
      nodeName = TString( gTools().GetName(ch) );

      if (nodeName=="GeneralInfo") {
         // training time, analysis type and the TMVA/ROOT versions the
         // weights were produced with
         TString name(""),val("");
         void* antypeNode = gTools().GetChild(ch);
         while (antypeNode) {
            gTools().ReadAttr( antypeNode, "name", name );

            if (name == "TrainingTime")
               gTools().ReadAttr( antypeNode, "value", fTrainTime );

            if (name == "AnalysisType") {
               gTools().ReadAttr( antypeNode, "value", val );
               val.ToLower();
               if (val == "regression" ) SetAnalysisType( Types::kRegression );
               else if (val == "classification" ) SetAnalysisType( Types::kClassification );
               else if (val == "multiclass" ) SetAnalysisType( Types::kMulticlass );
               else Log() << kFATAL << "Analysis type " << val << " is not known." << Endl;
            }

            if (name == "TMVA Release" || name == "TMVA" ){
               // the integer version code is stored between brackets,
               // e.g. "4.1.0 [262400]"
               TString s;
               gTools().ReadAttr( antypeNode, "value", s);
               fTMVATrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
               Log() << kINFO << "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
            }

            if (name == "ROOT Release" || name == "ROOT" ){
               TString s;
               gTools().ReadAttr( antypeNode, "value", s);
               fROOTTrainingVersion = TString(s(s.Index("[")+1,s.Index("]")-s.Index("[")-1)).Atoi();
               Log() << kINFO << "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
            }
            antypeNode = gTools().GetNextChild(antypeNode);
         }
      }
      else if (nodeName=="Options") {
         ReadOptionsFromXML(ch);
         ParseOptions();
      }
      else if (nodeName=="Variables") {
         ReadVariablesFromXML(ch);
      }
      else if (nodeName=="Spectators") {
         ReadSpectatorsFromXML(ch);
      }
      else if (nodeName=="Classes") {
         // only create classes when none are defined yet (multiclass mode)
         if(DataInfo().GetNClasses()==0 && DoMulticlass())
            ReadClassesFromXML(ch);
      }
      else if (nodeName=="Targets") {
         // only create targets when none are defined yet (regression mode)
         if(DataInfo().GetNTargets()==0 && DoRegression())
            ReadTargetsFromXML(ch);
      }
      else if (nodeName=="Transformations") {
         GetTransformationHandler().ReadFromXML(ch);
      }
      else if (nodeName=="MVAPdfs") {
         // signal PDF is the first child, background PDF the second;
         // any previously existing PDFs are replaced
         TString pdfname;
         if (fMVAPdfS) { delete fMVAPdfS; fMVAPdfS=0; }
         if (fMVAPdfB) { delete fMVAPdfB; fMVAPdfB=0; }
         void* pdfnode = gTools().GetChild(ch);
         if (pdfnode) {
            gTools().ReadAttr(pdfnode, "Name", pdfname);
            fMVAPdfS = new PDF(pdfname);
            fMVAPdfS->ReadXML(pdfnode);
            pdfnode = gTools().GetNextChild(pdfnode);
            gTools().ReadAttr(pdfnode, "Name", pdfname);
            fMVAPdfB = new PDF(pdfname);
            fMVAPdfB->ReadXML(pdfnode);
         }
      }
      else if (nodeName=="Weights") {
         ReadWeightsFromXML(ch);
      }
      else {
         Log() << kWARNING << "Unparsed XML node: '" << nodeName << "'" << Endl;
      }
      ch = gTools().GetNextChild(ch);

   }

   // the transformation handler needs a caller name for its log output
   if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );
}
01366
01367
//_______________________________________________________________________
void TMVA::MethodBase::ReadStateFromStream( std::istream& fin )
{
   // read the method state from a legacy plain-text weight file; the
   // sections (#OPT, #VAR, #MAT, #MVAPDFS, #WGT) must appear in this order

   char buf[512];

   // default analysis type; may be overridden by the header (via GetLine)
   SetAnalysisType(Types::kClassification);

   // first, the "Method" line yields the method type and instance name
   GetLine(fin,buf);
   while (!TString(buf).BeginsWith("Method")) GetLine(fin,buf);
   TString namestr(buf);

   TString methodType = namestr(0,namestr.Index("::"));
   methodType = methodType(methodType.Last(' '),methodType.Length());
   methodType = methodType.Strip(TString::kLeading);

   TString methodName = namestr(namestr.Index("::")+2,namestr.Length());
   methodName = methodName.Strip(TString::kLeading);
   // older files may carry only the type; fall back to it as the name
   if (methodName == "") methodName = methodType;
   fMethodName = methodName;

   Log() << kINFO << "Read method \"" << GetMethodName() << "\" of type \"" << GetMethodTypeName() << "\"" << Endl;

   // update logger source now that the method name is known
   Log().SetSource( GetName() );

   // --- options section
   GetLine(fin,buf);
   while (!TString(buf).BeginsWith("#OPT")) GetLine(fin,buf);
   ReadOptionsFromStream(fin);
   ParseOptions();

   // --- variable section (plain getline: no header keywords expected here)
   fin.getline(buf,512);
   while (!TString(buf).BeginsWith("#VAR")) fin.getline(buf,512);
   ReadVarsFromStream(fin);

   // process the options now that variables are known
   ProcessOptions();

   // rebuild the transformation chain from the option strings, mirroring
   // what was set up at training time
   if(IsNormalised()) {
      VariableNormalizeTransform* norm = (VariableNormalizeTransform*)
         GetTransformationHandler().AddTransformation( new VariableNormalizeTransform(DataInfo()), -1 );
      norm->BuildTransformationFromVarInfo( DataInfo().GetVariableInfos() );
   }
   VariableTransformBase *varTrafo(0), *varTrafo2(0);
   if ( fVarTransformString == "None") {
      if (fUseDecorr)
         varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
   } else if ( fVarTransformString == "Decorrelate" ) {
      varTrafo = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
   } else if ( fVarTransformString == "PCA" ) {
      varTrafo = GetTransformationHandler().AddTransformation( new VariablePCATransform(DataInfo()), -1 );
   } else if ( fVarTransformString == "Uniform" ) {
      varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo(),"Uniform"), -1 );
   } else if ( fVarTransformString == "Gauss" ) {
      varTrafo = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
   } else if ( fVarTransformString == "GaussDecorr" ) {
      varTrafo  = GetTransformationHandler().AddTransformation( new VariableGaussTransform(DataInfo()), -1 );
      varTrafo2 = GetTransformationHandler().AddTransformation( new VariableDecorrTransform(DataInfo()), -1 );
   } else {
      Log() << kFATAL << "<ProcessOptions> Variable transform '"
            << fVarTransformString << "' unknown." << Endl;
   }

   // --- transformation matrices (#MAT), only present when a transformation
   // chain was configured above
   if (GetTransformationHandler().GetTransformationList().GetSize() > 0) {
      fin.getline(buf,512);
      while (!TString(buf).BeginsWith("#MAT")) fin.getline(buf,512);
      if(varTrafo) {
         TString trafo(fVariableTransformTypeString); trafo.ToLower();
         varTrafo->ReadTransformationFromStream(fin, trafo );
      }
      if(varTrafo2) {
         TString trafo(fVariableTransformTypeString); trafo.ToLower();
         varTrafo2->ReadTransformationFromStream(fin, trafo );
      }
   }

   // --- MVA PDFs section (#MVAPDFS), only when the method was trained
   // with CreateMVAPdfs
   if (HasMVAPdfs()) {
      fin.getline(buf,512);
      while (!TString(buf).BeginsWith("#MVAPDFS")) fin.getline(buf,512);
      if (fMVAPdfS != 0) { delete fMVAPdfS; fMVAPdfS = 0; }
      if (fMVAPdfB != 0) { delete fMVAPdfB; fMVAPdfB = 0; }
      fMVAPdfS = new PDF(TString(GetName()) + " MVA PDF Sig");
      fMVAPdfB = new PDF(TString(GetName()) + " MVA PDF Bkg");
      // tell the PDFs which TMVA version wrote them so they can adapt
      // their stream format
      fMVAPdfS->SetReadingVersion( GetTrainingTMVAVersionCode() );
      fMVAPdfB->SetReadingVersion( GetTrainingTMVAVersionCode() );

      fin >> *fMVAPdfS;
      fin >> *fMVAPdfB;
   }

   // --- method-specific weights section (#WGT); the extra getline skips
   // the line following the section marker before the weights start
   fin.getline(buf,512);
   while (!TString(buf).BeginsWith("#WGT")) fin.getline(buf,512);
   fin.getline(buf,512);
   ReadWeightsFromStream( fin );;

   // the transformation handler needs a caller name for its log output
   if (GetTransformationHandler().GetCallerName() == "") GetTransformationHandler().SetCallerName( GetName() );

}
01484
01485
01486 void TMVA::MethodBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const
01487 {
01488
01489
01490 o << prefix << "NVar " << DataInfo().GetNVariables() << endl;
01491 std::vector<VariableInfo>::const_iterator varIt = DataInfo().GetVariableInfos().begin();
01492 for (; varIt!=DataInfo().GetVariableInfos().end(); varIt++) { o << prefix; varIt->WriteToStream(o); }
01493 o << prefix << "NSpec " << DataInfo().GetNSpectators() << endl;
01494 varIt = DataInfo().GetSpectatorInfos().begin();
01495 for (; varIt!=DataInfo().GetSpectatorInfos().end(); varIt++) { o << prefix; varIt->WriteToStream(o); }
01496 }
01497
01498
//_______________________________________________________________________
void TMVA::MethodBase::ReadVarsFromStream( std::istream& istr )
{
   // read the variable definitions from the input stream and check them
   // against the variables declared to the Reader; a mismatch in number,
   // name or order is fatal

   TString dummy;
   UInt_t readNVar;
   istr >> dummy >> readNVar;   // "NVar <count>"

   if (readNVar!=DataInfo().GetNVariables()) {
      Log() << kFATAL << "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
            << " while there are " << readNVar << " variables declared in the file"
            << Endl;
   }

   // read each variable; on an expression match, keep the Reader's external
   // link but replace the rest of the info with what the file declares
   VariableInfo varInfo;
   std::vector<VariableInfo>::iterator varIt = DataInfo().GetVariableInfos().begin();
   int varIdx = 0;
   for (; varIt!=DataInfo().GetVariableInfos().end(); varIt++, varIdx++) {
      varInfo.ReadFromStream(istr);
      if (varIt->GetExpression() == varInfo.GetExpression()) {
         varInfo.SetExternalLink((*varIt).GetExternalLink());
         (*varIt) = varInfo;
      }
      else {
         Log() << kINFO << "ERROR in <ReadVarsFromStream>" << Endl;
         Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
         Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
         Log() << kINFO << "the correct working of the method):" << Endl;
         Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
         Log() << kINFO << " var #" << varIdx <<" declared in file : " << varInfo.GetExpression() << Endl;
         Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
      }
   }
}
01535
01536
01537 void TMVA::MethodBase::AddVarsXMLTo( void* parent ) const
01538 {
01539
01540 void* vars = gTools().AddChild(parent, "Variables");
01541 gTools().AddAttr( vars, "NVar", gTools().StringFromInt(DataInfo().GetNVariables()) );
01542
01543 for (UInt_t idx=0; idx<DataInfo().GetVariableInfos().size(); idx++) {
01544 VariableInfo& vi = DataInfo().GetVariableInfos()[idx];
01545 void* var = gTools().AddChild( vars, "Variable" );
01546 gTools().AddAttr( var, "VarIndex", idx );
01547 vi.AddToXML( var );
01548 }
01549 }
01550
01551
01552 void TMVA::MethodBase::AddSpectatorsXMLTo( void* parent ) const
01553 {
01554
01555 void* specs = gTools().AddChild(parent, "Spectators");
01556
01557 UInt_t writeIdx=0;
01558 for (UInt_t idx=0; idx<DataInfo().GetSpectatorInfos().size(); idx++) {
01559
01560 VariableInfo& vi = DataInfo().GetSpectatorInfos()[idx];
01561
01562
01563
01564 if( vi.GetVarType()=='C' )
01565 continue;
01566
01567 void* spec = gTools().AddChild( specs, "Spectator" );
01568 gTools().AddAttr( spec, "SpecIndex", writeIdx++ );
01569 vi.AddToXML( spec );
01570 }
01571 gTools().AddAttr( specs, "NSpec", gTools().StringFromInt(writeIdx) );
01572 }
01573
01574
01575 void TMVA::MethodBase::AddClassesXMLTo( void* parent ) const
01576 {
01577
01578 void* targets = gTools().AddChild(parent, "Classes");
01579 gTools().AddAttr( targets, "NClass", gTools().StringFromInt(DataInfo().GetNClasses()) );
01580
01581 }
01582
01583 void TMVA::MethodBase::AddTargetsXMLTo( void* parent ) const
01584 {
01585
01586 void* targets = gTools().AddChild(parent, "Targets");
01587 gTools().AddAttr( targets, "NTrgt", gTools().StringFromInt(DataInfo().GetNTargets()) );
01588
01589 for (UInt_t idx=0; idx<DataInfo().GetTargetInfos().size(); idx++) {
01590 VariableInfo& vi = DataInfo().GetTargetInfos()[idx];
01591 void* tar = gTools().AddChild( targets, "Target" );
01592 gTools().AddAttr( tar, "TargetIndex", idx );
01593 vi.AddToXML( tar );
01594 }
01595 }
01596
01597
01598 void TMVA::MethodBase::ReadVariablesFromXML( void* varnode )
01599 {
01600
01601 UInt_t readNVar;
01602 gTools().ReadAttr( varnode, "NVar", readNVar);
01603
01604 if (readNVar!=DataInfo().GetNVariables()) {
01605 Log() << kFATAL << "You declared "<< DataInfo().GetNVariables() << " variables in the Reader"
01606 << " while there are " << readNVar << " variables declared in the file"
01607 << Endl;
01608 }
01609
01610
01611 VariableInfo readVarInfo, existingVarInfo;
01612 int varIdx = 0;
01613 void* ch = gTools().GetChild(varnode);
01614 while (ch) {
01615 gTools().ReadAttr( ch, "VarIndex", varIdx);
01616 existingVarInfo = DataInfo().GetVariableInfos()[varIdx];
01617 readVarInfo.ReadFromXML(ch);
01618
01619 if (existingVarInfo.GetExpression() == readVarInfo.GetExpression()) {
01620 readVarInfo.SetExternalLink(existingVarInfo.GetExternalLink());
01621 existingVarInfo = readVarInfo;
01622 }
01623 else {
01624 Log() << kINFO << "ERROR in <ReadVariablesFromXML>" << Endl;
01625 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
01626 Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
01627 Log() << kINFO << "the correct working of the method):" << Endl;
01628 Log() << kINFO << " var #" << varIdx <<" declared in Reader: " << existingVarInfo.GetExpression() << Endl;
01629 Log() << kINFO << " var #" << varIdx <<" declared in file : " << readVarInfo.GetExpression() << Endl;
01630 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
01631 }
01632 ch = gTools().GetNextChild(ch);
01633 }
01634 }
01635
01636
01637 void TMVA::MethodBase::ReadSpectatorsFromXML( void* specnode )
01638 {
01639
01640 UInt_t readNSpec;
01641 gTools().ReadAttr( specnode, "NSpec", readNSpec);
01642
01643 if (readNSpec!=DataInfo().GetNSpectators(kFALSE)) {
01644 Log() << kFATAL << "You declared "<< DataInfo().GetNSpectators(kFALSE) << " spectators in the Reader"
01645 << " while there are " << readNSpec << " spectators declared in the file"
01646 << Endl;
01647 }
01648
01649
01650 VariableInfo readSpecInfo, existingSpecInfo;
01651 int specIdx = 0;
01652 void* ch = gTools().GetChild(specnode);
01653 while (ch) {
01654 gTools().ReadAttr( ch, "SpecIndex", specIdx);
01655 existingSpecInfo = DataInfo().GetSpectatorInfos()[specIdx];
01656 readSpecInfo.ReadFromXML(ch);
01657
01658 if (existingSpecInfo.GetExpression() == readSpecInfo.GetExpression()) {
01659 readSpecInfo.SetExternalLink(existingSpecInfo.GetExternalLink());
01660 existingSpecInfo = readSpecInfo;
01661 }
01662 else {
01663 Log() << kINFO << "ERROR in <ReadVariablesFromXML>" << Endl;
01664 Log() << kINFO << "The definition (or the order) of the variables found in the input file is" << Endl;
01665 Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
01666 Log() << kINFO << "the correct working of the method):" << Endl;
01667 Log() << kINFO << " var #" << specIdx <<" declared in Reader: " << existingSpecInfo.GetExpression() << Endl;
01668 Log() << kINFO << " var #" << specIdx <<" declared in file : " << readSpecInfo.GetExpression() << Endl;
01669 Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
01670 }
01671 ch = gTools().GetNextChild(ch);
01672 }
01673 }
01674
01675
01676 void TMVA::MethodBase::ReadClassesFromXML( void* clsnode )
01677 {
01678
01679 UInt_t readNCls;
01680
01681 gTools().ReadAttr( clsnode, "NClass", readNCls);
01682
01683 for(UInt_t icls = 0; icls<readNCls;++icls){
01684 TString classname = Form("class%i",icls);
01685 DataInfo().AddClass(classname);
01686
01687 }
01688 }
01689
01690
01691 void TMVA::MethodBase::ReadTargetsFromXML( void* tarnode )
01692 {
01693
01694 UInt_t readNTar;
01695 gTools().ReadAttr( tarnode, "NTrgt", readNTar);
01696
01697 int tarIdx = 0;
01698 TString expression;
01699 void* ch = gTools().GetChild(tarnode);
01700 while (ch) {
01701 gTools().ReadAttr( ch, "TargetIndex", tarIdx);
01702 gTools().ReadAttr( ch, "Expression", expression);
01703 DataInfo().AddTarget(expression,"","",0,0);
01704
01705 ch = gTools().GetNextChild(ch);
01706 }
01707 }
01708
01709
//_______________________________________________________________________
TDirectory* TMVA::MethodBase::BaseDir() const
{
   // return the ROOT directory for this method instance (named after the
   // method) inside the method-type directory; the directory is created
   // on first demand

   if (fBaseDir != 0) return fBaseDir;
   Log()<<kDEBUG<<" Base Directory for " << GetMethodTypeName() << " not set yet --> check if already there.." <<Endl;

   TDirectory* methodDir = MethodBaseDir();
   if (methodDir==0)
      Log() << kFATAL << "MethodBase::BaseDir() - MethodBaseDir() return a NULL pointer!" << Endl;

   TDirectory* dir = 0;

   TString defaultDir = GetMethodName();

   // look for an existing subdirectory named after the method instance
   TObject* o = methodDir->FindObject(defaultDir);
   if (o!=0 && o->InheritsFrom(TDirectory::Class())) dir = (TDirectory*)o;

   if (dir != 0) {
      Log()<<kDEBUG<<" Base Directory for " << GetMethodName() << " existed, return it.." <<Endl;
      return dir;
   }

   Log()<<kDEBUG<<" Base Directory for " << GetMethodName() << " does not exist yet--> created it" <<Endl;
   // NOTE(review): the freshly created directory is returned but NOT cached
   // in fBaseDir, so the lookup above is repeated on every call -- presumably
   // fBaseDir is set elsewhere; confirm before changing
   TDirectory *sdir = methodDir->mkdir(defaultDir);

   // record the training environment in the new directory
   sdir->cd();
   TObjString wfilePath( gSystem->WorkingDirectory() );
   TObjString wfileName( GetWeightFileName() );
   wfilePath.Write( "TrainingPath" );
   wfileName.Write( "WeightFileName" );

   return sdir;
}
01746
01747
//_______________________________________________________________________
TDirectory* TMVA::MethodBase::MethodBaseDir() const
{
   // return the ROOT directory shared by all methods of this type
   // ("Method_<Type>" below the factory's base directory); created on
   // first demand and cached in fMethodBaseDir

   if (fMethodBaseDir != 0) return fMethodBaseDir;

   Log()<<kDEBUG<<" Base Directory for " << GetMethodTypeName() << " not set yet --> check if already there.." <<Endl;

   const TString dirName(Form("Method_%s",GetMethodTypeName().Data()));

   TDirectory * dir = Factory::RootBaseDir()->GetDirectory(dirName);
   if (dir != 0){
      Log()<<kDEBUG<<" Base Directory for " << GetMethodTypeName() << " existed, return it.." <<Endl;
      return dir;
   }

   Log()<<kDEBUG<<" Base Directory for " << GetMethodTypeName() << " does not exist yet--> created it" <<Endl;
   // fMethodBaseDir is mutable state updated from a const method (cache)
   fMethodBaseDir = Factory::RootBaseDir()->mkdir(dirName,Form("Directory for all %s methods", GetMethodTypeName().Data()));

   Log()<<kDEBUG<<"Return from MethodBaseDir() after creating base directory "<<Endl;
   return fMethodBaseDir;
}
01771
01772
01773 void TMVA::MethodBase::SetWeightFileDir( TString fileDir )
01774 {
01775
01776
01777 fFileDir = fileDir;
01778 gSystem->MakeDirectory( fFileDir );
01779 }
01780
01781
//_______________________________________________________________________
void TMVA::MethodBase::SetWeightFileName( TString theWeightFile)
{
   // set the weight file name; overrides the default name assembled in
   // GetWeightFileName()
   fWeightFile = theWeightFile;
}
01787
01788
01789 TString TMVA::MethodBase::GetWeightFileName() const
01790 {
01791
01792 if (fWeightFile!="") return fWeightFile;
01793
01794
01795
01796 TString suffix = "";
01797 TString wFileDir(GetWeightFileDir());
01798 return ( wFileDir + (wFileDir[wFileDir.Length()-1]=='/' ? "" : "/")
01799 + GetJobName() + "_" + GetMethodName() +
01800 suffix + "." + gConfig().GetIONames().fWeightFileExtension + ".xml" );
01801 }
01802
01803
//_______________________________________________________________________
void TMVA::MethodBase::WriteEvaluationHistosToFile(Types::ETreeType treetype)
{
   // write the evaluation histograms (MVA PDFs and the Results storage)
   // into the method's base directory
   BaseDir()->cd();

   // NOTE(review): GetOriginalHist/GetSmoothedHist/GetPDFHist results are
   // written without null checks -- presumably guaranteed non-null once the
   // PDF exists; confirm against the PDF class
   if (0 != fMVAPdfS) {
      fMVAPdfS->GetOriginalHist()->Write();
      fMVAPdfS->GetSmoothedHist()->Write();
      fMVAPdfS->GetPDFHist()->Write();
   }
   if (0 != fMVAPdfB) {
      fMVAPdfB->GetOriginalHist()->Write();
      fMVAPdfB->GetSmoothedHist()->Write();
      fMVAPdfB->GetPDFHist()->Write();
   }

   // write the histograms accumulated in the Results container for the
   // requested tree type (training or testing)
   Results* results = Data()->GetResults( GetMethodName(), treetype, Types::kMaxAnalysisType );
   if (!results)
      Log() << kFATAL << "<WriteEvaluationHistosToFile> Unknown result: "
            << GetMethodName() << (treetype==Types::kTraining?"/kTraining":"/kTesting") << "/kMaxAnalysisType" << Endl;
   results->GetStorage()->Write();
   // variable distribution plots are produced for the test sample only
   if(treetype==Types::kTesting)
      GetTransformationHandler().PlotVariables( GetEventCollection( Types::kTesting ), BaseDir() );
}
01830
01831
//_______________________________________________________________________
void TMVA::MethodBase::WriteMonitoringHistosToFile( void ) const
{
   // write special monitoring histograms to file;
   // intentionally empty in the base class -- derived methods override
   // this to store their own monitoring output
}
01837
01838
01839 Bool_t TMVA::MethodBase::GetLine(std::istream& fin, char* buf )
01840 {
01841
01842
01843
01844 fin.getline(buf,512);
01845 TString line(buf);
01846 if (line.BeginsWith("TMVA Release")) {
01847 Ssiz_t start = line.First('[')+1;
01848 Ssiz_t length = line.Index("]",start)-start;
01849 TString code = line(start,length);
01850 std::stringstream s(code.Data());
01851 s >> fTMVATrainingVersion;
01852 Log() << kINFO << "MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
01853 }
01854 if (line.BeginsWith("ROOT Release")) {
01855 Ssiz_t start = line.First('[')+1;
01856 Ssiz_t length = line.Index("]",start)-start;
01857 TString code = line(start,length);
01858 std::stringstream s(code.Data());
01859 s >> fROOTTrainingVersion;
01860 Log() << kINFO << "MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
01861 }
01862 if (line.BeginsWith("Analysis type")) {
01863 Ssiz_t start = line.First('[')+1;
01864 Ssiz_t length = line.Index("]",start)-start;
01865 TString code = line(start,length);
01866 std::stringstream s(code.Data());
01867 std::string analysisType;
01868 s >> analysisType;
01869 if (analysisType == "regression" || analysisType == "Regression") SetAnalysisType( Types::kRegression );
01870 else if (analysisType == "classification" || analysisType == "Classification") SetAnalysisType( Types::kClassification );
01871 else if (analysisType == "multiclass" || analysisType == "Multiclass") SetAnalysisType( Types::kMulticlass );
01872 else Log() << kFATAL << "Analysis type " << analysisType << " from weight-file not known!" << std::endl;
01873
01874 Log() << kINFO << "Method was trained for "
01875 << (GetAnalysisType() == Types::kRegression ? "Regression" :
01876 (GetAnalysisType() == Types::kMulticlass ? "Multiclass" : "Classification")) << Endl;
01877 }
01878
01879 return true;
01880 }
01881
01882
//_______________________________________________________________________
void TMVA::MethodBase::CreateMVAPdfs()
{
   // create the signal and background MVA PDFs from the classifier
   // response on the training sample

   Data()->SetCurrentType(Types::kTraining);

   ResultsClassification * mvaRes = dynamic_cast<ResultsClassification*>
      ( Data()->GetResults(GetMethodName(), Types::kTraining, Types::kClassification) );

   if (mvaRes==0 || mvaRes->GetSize()==0) {
      Log() << kFATAL << "<CreateMVAPdfs> No result of classifier testing available" << Endl;
   }

   // histogram range spans the observed response values
   Double_t minVal = *std::min_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
   Double_t maxVal = *std::max_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());

   // NOTE(review): fMVAPdfS/fMVAPdfB are dereferenced here without a null
   // check -- presumably they are created earlier when the CreateMVAPdfs
   // option is set; confirm against ProcessOptions
   TH1* histMVAPdfS = new TH1F( GetMethodTypeName() + "_tr_S", GetMethodTypeName() + "_tr_S",
                                fMVAPdfS->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
   TH1* histMVAPdfB = new TH1F( GetMethodTypeName() + "_tr_B", GetMethodTypeName() + "_tr_B",
                                fMVAPdfB->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );

   // enable proper error propagation for weighted fills
   histMVAPdfS->Sumw2();
   histMVAPdfB->Sumw2();

   // fill the response histograms, separated by true class
   for (UInt_t ievt=0; ievt<mvaRes->GetSize(); ievt++) {
      Double_t theVal = mvaRes->GetValueVector()->at(ievt);
      Double_t theWeight = Data()->GetEvent(ievt)->GetWeight();

      if (DataInfo().IsSignal(Data()->GetEvent(ievt))) histMVAPdfS->Fill( theVal, theWeight );
      else histMVAPdfB->Fill( theVal, theWeight );
   }

   // normalise to unit area before building the PDFs
   gTools().NormHist( histMVAPdfS );
   gTools().NormHist( histMVAPdfB );

   // keep a copy of the raw histograms in the output file
   histMVAPdfS->Write();
   histMVAPdfB->Write();

   // build and sanity-check the PDFs
   fMVAPdfS->BuildPDF ( histMVAPdfS );
   fMVAPdfB->BuildPDF ( histMVAPdfB );
   fMVAPdfS->ValidatePDF( histMVAPdfS );
   fMVAPdfB->ValidatePDF( histMVAPdfB );

   // the separation figure of merit is only meaningful for two classes
   if (DataInfo().GetNClasses() == 2) {
      Log() << kINFO
            << Form( "<CreateMVAPdfs> Separation from histogram (PDF): %1.3f (%1.3f)",
                     GetSeparation( histMVAPdfS, histMVAPdfB ), GetSeparation( fMVAPdfS, fMVAPdfB ) )
            << Endl;
   }

   delete histMVAPdfS;
   delete histMVAPdfB;
}
01942
01943
01944 Double_t TMVA::MethodBase::GetProba( Double_t mvaVal, Double_t ap_sig )
01945 {
01946
01947 if (!fMVAPdfS || !fMVAPdfB) {
01948 Log() << kWARNING << "<GetProba> MVA PDFs for Signal and Background don't exist" << Endl;
01949 return -1.0;
01950 }
01951 Double_t p_s = fMVAPdfS->GetVal( mvaVal );
01952 Double_t p_b = fMVAPdfB->GetVal( mvaVal );
01953
01954 Double_t denom = p_s*ap_sig + p_b*(1 - ap_sig);
01955
01956 return (denom > 0) ? (p_s*ap_sig) / denom : -1;
01957 }
01958
01959
01960 Double_t TMVA::MethodBase::GetRarity( Double_t mvaVal, Types::ESBType reftype ) const
01961 {
01962
01963
01964
01965
01966 if ((reftype == Types::kSignal && !fMVAPdfS) || (reftype == Types::kBackground && !fMVAPdfB)) {
01967 Log() << kWARNING << "<GetRarity> Required MVA PDF for Signal or Backgroud does not exist: "
01968 << "select option \"CreateMVAPdfs\"" << Endl;
01969 return 0.0;
01970 }
01971
01972 PDF* thePdf = ((reftype == Types::kSignal) ? fMVAPdfS : fMVAPdfB);
01973
01974 return thePdf->GetIntegral( thePdf->GetXmin(), mvaVal );
01975 }
01976
01977
////////////////////////////////////////////////////////////////////////////////
/// Return the signal efficiency at a given background efficiency, or the
/// integral of the background-rejection-vs-signal-efficiency curve.
///
/// "theString" has the format "Efficiency:0.05" (the number is the reference
/// background efficiency); an empty/short string requests the area under the
/// ROC-like curve instead. "type" selects training or testing results.
/// "effSerr" receives the binomial error on the returned signal efficiency.
/// On the first call the efficiency histograms, splines, and the signal
/// reference cut are computed and cached in "results".
Double_t TMVA::MethodBase::GetEfficiency( const TString& theString, Types::ETreeType type,Double_t& effSerr )
{
   // switch the dataset to the requested tree type
   Data()->SetCurrentType(type);
   Results* results = Data()->GetResults( GetMethodName(), type, Types::kClassification );
   // NOTE(review): the dynamic_cast result is dereferenced without a null
   // check — assumes classification results always exist here; confirm callers
   std::vector<Float_t>* mvaRes = dynamic_cast<ResultsClassification*>(results)->GetValueVector();

   // parse the option string
   TList* list = gTools().ParseFormatLine( theString );

   // an empty string (or a single token) requests the area computation
   Bool_t computeArea = kFALSE;
   if (!list || list->GetSize() < 2) computeArea = kTRUE;
   else if (list->GetSize() > 2) {
      Log() << kFATAL << "<GetEfficiency> Wrong number of arguments"
            << " in string: " << theString
            << " | required format, e.g., Efficiency:0.05, or empty string" << Endl;
      delete list;
      return -1;
   }

   // sanity check: signal and background histograms must share the binning
   if ( results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
        results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
      Log() << kFATAL << "<GetEfficiency> Binning mismatch between signal and background histos" << Endl;
      delete list;
      return -1.0;
   }

   // take the MVA output range from the fine-binned signal histogram
   TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
   Double_t xmin = effhist->GetXaxis()->GetXmin();
   Double_t xmax = effhist->GetXaxis()->GetXmax();

   // weighted number of signal events; static so the value survives into
   // later calls that skip the (one-time) histogram-filling branch below
   // NOTE(review): static local — not thread-safe and shared across methods
   static Double_t nevtS;

   // first call only: create and fill the efficiency histograms
   if (results->GetHist("MVA_EFF_S")==0) {

      // cumulative efficiency-vs-cut histograms for signal and background
      TH1* eff_s = new TH1F( GetTestvarName() + "_effS", GetTestvarName() + " (signal)", fNbinsH, xmin, xmax );
      TH1* eff_b = new TH1F( GetTestvarName() + "_effB", GetTestvarName() + " (background)", fNbinsH, xmin, xmax );
      results->Store(eff_s, "MVA_EFF_S");
      results->Store(eff_b, "MVA_EFF_B");

      // cut orientation: +1 if signal populates large MVA values
      Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;

      // fill the cumulative histograms event by event
      nevtS = 0;
      for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {

         Bool_t isSignal = DataInfo().IsSignal(GetEvent(ievt));
         Float_t theWeight = GetEvent(ievt)->GetWeight();
         Float_t theVal = (*mvaRes)[ievt];

         TH1* theHist = isSignal ? eff_s : eff_b;

         // accumulate the weighted signal count (used for the error below)
         if (isSignal) nevtS+=theWeight;

         // last bin that the event passes, given the cut orientation
         TAxis* axis = theHist->GetXaxis();
         Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
         if (sign > 0 && maxbin > fNbinsH) continue; // can happen... event doesn't pass any cut
         if (sign < 0 && maxbin < 1 ) continue; // can happen... event doesn't pass any cut
         if (sign > 0 && maxbin < 1 ) maxbin = 1;
         if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;

         // the event passes all cuts up to (or above) maxbin
         if (sign > 0)
            for (Int_t ibin=1; ibin<=maxbin; ibin++) theHist->AddBinContent( ibin , theWeight);
         else if (sign < 0)
            for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theHist->AddBinContent( ibin , theWeight );
         else
            Log() << kFATAL << "<GetEfficiency> Mismatch in sign" << Endl;
      }

      // renormalise to maximum (first bin holds all events by construction)
      eff_s->Scale( 1.0/TMath::Max(1.,eff_s->GetMaximum()) );
      eff_b->Scale( 1.0/TMath::Max(1.,eff_b->GetMaximum()) );

      // background efficiency versus signal efficiency
      TH1* eff_BvsS = new TH1F( GetTestvarName() + "_effBvsS", GetTestvarName() + "", fNbins, 0, 1 );
      results->Store(eff_BvsS, "MVA_EFF_BvsS");
      eff_BvsS->SetXTitle( "Signal eff" );
      eff_BvsS->SetYTitle( "Backgr eff" );

      // background rejection (1 - eff) versus signal efficiency
      TH1* rej_BvsS = new TH1F( GetTestvarName() + "_rejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
      results->Store(rej_BvsS);
      rej_BvsS->SetXTitle( "Signal eff" );
      rej_BvsS->SetYTitle( "Backgr rejection (1-eff)" );

      // inverse background efficiency versus signal efficiency
      TH1* inveff_BvsS = new TH1F( GetTestvarName() + "_invBeffvsSeff",
                                   GetTestvarName(), fNbins, 0, 1 );
      results->Store(inveff_BvsS);
      inveff_BvsS->SetXTitle( "Signal eff" );
      inveff_BvsS->SetYTitle( "Inverse backgr. eff (1/eff)" );

      // use splines through the efficiency histograms for smooth
      // interpolation (validated against the histograms)
      if (Use_Splines_for_Eff_) {
         fSplRefS = new TSpline1( "spline2_signal", new TGraph( eff_s ) );
         fSplRefB = new TSpline1( "spline2_background", new TGraph( eff_b ) );

         // verify that the splines reproduce the histogram contents
         gTools().CheckSplines( eff_s, fSplRefS );
         gTools().CheckSplines( eff_b, fSplRefB );
      }

      // for each signal-efficiency bin, find the cut that yields that
      // efficiency via the root finder (which calls IGetEffForRoot through
      // the static fgThisBase pointer), then read off the background eff
      ResetThisBase();
      RootFinder rootFinder( &IGetEffForRoot, fXmin, fXmax );

      Double_t effB = 0;
      fEffS = eff_s; // used by GetEffForRoot during the root search
      for (Int_t bini=1; bini<=fNbins; bini++) {

         // target signal efficiency for this bin
         Double_t effS = eff_BvsS->GetBinCenter( bini );
         Double_t cut = rootFinder.Root( effS );

         // background efficiency at that cut
         if (Use_Splines_for_Eff_) effB = fSplRefB->Eval( cut );
         else effB = eff_b->GetBinContent( eff_b->FindBin( cut ) );

         // fill the derived histograms
         eff_BvsS->SetBinContent( bini, effB );
         rej_BvsS->SetBinContent( bini, 1.0-effB );
         if (effB>std::numeric_limits<double>::epsilon())
            inveff_BvsS->SetBinContent( bini, 1.0/effB );
      }

      // smooth spline through the background-vs-signal efficiency curve
      fSpleffBvsS = new TSpline1( "effBvsS", new TGraph( eff_BvsS ) );

      // locate the crossing point effS == rejB by scanning for a sign flip
      Double_t effS, rejB, effS_ = 0, rejB_ = 0;
      Int_t nbins_ = 5000;
      for (Int_t bini=1; bini<=nbins_; bini++) {

         effS = (bini - 0.5)/Float_t(nbins_);
         rejB = 1.0 - fSpleffBvsS->Eval( effS );

         // sign change brackets the crossing
         if ((effS - rejB)*(effS_ - rejB_) < 0) break;
         effS_ = effS;
         rejB_ = rejB;
      }

      // the cut at the midpoint of the bracket defines the reference cut
      Double_t cut = rootFinder.Root( 0.5*(effS + effS_) );
      SetSignalReferenceCut( cut );
      fEffS = 0;
   }

   // without the spline nothing can be evaluated
   if (0 == fSpleffBvsS) {
      delete list;
      return 0.0;
   }

   // evaluate the requested quantity from the cached spline
   Double_t effS = 0, effB = 0, effS_ = 0, effB_ = 0;
   Int_t nbins_ = 1000;

   if (computeArea) {

      // area under the rejection curve, midpoint rule over nbins_ steps
      Double_t integral = 0;
      for (Int_t bini=1; bini<=nbins_; bini++) {

         effS = (bini - 0.5)/Float_t(nbins_);
         effB = fSpleffBvsS->Eval( effS );
         integral += (1.0 - effB);
      }
      integral /= nbins_;

      delete list;
      return integral;
   }
   else {

      // reference background efficiency requested by the caller
      Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );

      // scan for the signal efficiency whose background efficiency crosses
      // the reference value
      for (Int_t bini=1; bini<=nbins_; bini++) {

         effS = (bini - 0.5)/Float_t(nbins_);
         effB = fSpleffBvsS->Eval( effS );

         // sign change brackets the crossing point
         if ((effB - effBref)*(effB_ - effBref) <= 0) break;
         effS_ = effS;
         effB_ = effB;
      }

      // interpolate between the bracketing points
      effS = 0.5*(effS + effS_);

      // binomial error on the signal efficiency
      effSerr = 0;
      if (nevtS > 0) effSerr = TMath::Sqrt( effS*(1.0 - effS)/nevtS );

      delete list;
      return effS;
   }

   return -1; // unreachable; both branches above return
}
02203
02204
////////////////////////////////////////////////////////////////////////////////
/// Return the signal efficiency on the TRAINING sample at the background
/// efficiency given in "theString" (format "Efficiency:0.05"). On the first
/// call the training MVA distributions and efficiency histograms are created,
/// filled by re-evaluating the classifier on each training event, and cached
/// in the results object together with interpolation splines.
Double_t TMVA::MethodBase::GetTrainingEfficiency(const TString& theString)
{
   Data()->SetCurrentType(Types::kTraining);

   // NOTE(review): results are fetched for Types::kTesting although the
   // current type is set to training — looks intentional (histograms are
   // cached alongside the testing results) but confirm against callers
   Results* results = Data()->GetResults(GetMethodName(), Types::kTesting, Types::kNoAnalysisType);

   // parse the option string
   TList* list = gTools().ParseFormatLine( theString );

   // the training variant requires exactly "Efficiency:value"
   if (list->GetSize() != 2) {
      Log() << kFATAL << "<GetTrainingEfficiency> Wrong number of arguments"
            << " in string: " << theString
            << " | required format, e.g., Efficiency:0.05" << Endl;
      delete list;
      return -1;
   }

   // reference background efficiency requested by the caller
   Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );

   delete list;

   // sanity check: signal and background histograms must share the binning
   if (results->GetHist("MVA_S")->GetNbinsX() != results->GetHist("MVA_B")->GetNbinsX() ||
       results->GetHist("MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist("MVA_HIGHBIN_B")->GetNbinsX() ) {
      Log() << kFATAL << "<GetTrainingEfficiency> Binning mismatch between signal and background histos"
            << Endl;
      return -1.0;
   }

   // take the MVA output range from the fine-binned signal histogram
   TH1 * effhist = results->GetHist("MVA_HIGHBIN_S");
   Double_t xmin = effhist->GetXaxis()->GetXmin();
   Double_t xmax = effhist->GetXaxis()->GetXmax();

   // first call only: build the training histograms and splines
   if (results->GetHist("MVA_TRAIN_S")==0) {

      // upper edge nudged so that the maximum MVA value falls inside the range
      Double_t sxmax = fXmax+0.00001;

      // training MVA output distributions
      TH1* mva_s_tr = new TH1F( GetTestvarName() + "_Train_S",GetTestvarName() + "_Train_S", fNbins, fXmin, sxmax );
      TH1* mva_b_tr = new TH1F( GetTestvarName() + "_Train_B",GetTestvarName() + "_Train_B", fNbins, fXmin, sxmax );
      results->Store(mva_s_tr, "MVA_TRAIN_S");
      results->Store(mva_b_tr, "MVA_TRAIN_B");
      mva_s_tr->Sumw2();
      mva_b_tr->Sumw2();

      // cumulative efficiency-vs-cut histograms
      TH1* mva_eff_tr_s = new TH1F( GetTestvarName() + "_trainingEffS", GetTestvarName() + " (signal)",
                                    fNbinsH, xmin, xmax );
      TH1* mva_eff_tr_b = new TH1F( GetTestvarName() + "_trainingEffB", GetTestvarName() + " (background)",
                                    fNbinsH, xmin, xmax );
      results->Store(mva_eff_tr_s, "MVA_TRAINEFF_S");
      results->Store(mva_eff_tr_b, "MVA_TRAINEFF_B");

      // cut orientation: +1 if signal populates large MVA values
      Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;

      // re-evaluate the classifier on every training event
      for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {

         Data()->SetCurrentEvent(ievt);
         const Event* ev = GetEvent();

         Double_t theVal = GetMvaValue();
         Double_t theWeight = ev->GetWeight();

         TH1* theEffHist = DataInfo().IsSignal(ev) ? mva_eff_tr_s : mva_eff_tr_b;
         TH1* theClsHist = DataInfo().IsSignal(ev) ? mva_s_tr : mva_b_tr;

         theClsHist->Fill( theVal, theWeight );

         // last bin that the event passes, given the cut orientation
         TAxis* axis = theEffHist->GetXaxis();
         Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
         if (sign > 0 && maxbin > fNbinsH) continue; // event doesn't pass any cut
         if (sign < 0 && maxbin < 1 ) continue; // event doesn't pass any cut
         if (sign > 0 && maxbin < 1 ) maxbin = 1;
         if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;

         // the event passes all cuts up to (or above) maxbin
         if (sign > 0)
            for (Int_t ibin=1; ibin<=maxbin; ibin++) theEffHist->AddBinContent( ibin , theWeight );
         else
            for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theEffHist->AddBinContent( ibin , theWeight );
      }

      // normalise the MVA output distributions
      gTools().NormHist( mva_s_tr );
      gTools().NormHist( mva_b_tr );

      // renormalise the cumulative histograms to maximum
      mva_eff_tr_s->Scale( 1.0/TMath::Max(1.0, mva_eff_tr_s->GetMaximum()) );
      mva_eff_tr_b->Scale( 1.0/TMath::Max(1.0, mva_eff_tr_b->GetMaximum()) );

      // background efficiency and rejection versus signal efficiency
      TH1* eff_bvss = new TH1F( GetTestvarName() + "_trainingEffBvsS", GetTestvarName() + "", fNbins, 0, 1 );

      TH1* rej_bvss = new TH1F( GetTestvarName() + "_trainingRejBvsS", GetTestvarName() + "", fNbins, 0, 1 );
      results->Store(eff_bvss, "EFF_BVSS_TR");
      results->Store(rej_bvss, "REJ_BVSS_TR");

      // use splines through the efficiency histograms for smooth
      // interpolation (old splines are released first)
      if (Use_Splines_for_Eff_) {
         if (fSplTrainRefS) delete fSplTrainRefS;
         if (fSplTrainRefB) delete fSplTrainRefB;
         fSplTrainRefS = new TSpline1( "spline2_signal", new TGraph( mva_eff_tr_s ) );
         fSplTrainRefB = new TSpline1( "spline2_background", new TGraph( mva_eff_tr_b ) );

         // verify that the splines reproduce the histogram contents
         gTools().CheckSplines( mva_eff_tr_s, fSplTrainRefS );
         gTools().CheckSplines( mva_eff_tr_b, fSplTrainRefB );
      }

      // for each signal-efficiency bin, find the corresponding cut via the
      // root finder (which calls IGetEffForRoot through the static
      // fgThisBase pointer), then read off the background efficiency
      ResetThisBase();
      RootFinder rootFinder(&IGetEffForRoot, fXmin, fXmax );

      Double_t effB = 0;
      fEffS = results->GetHist("MVA_TRAINEFF_S"); // used by GetEffForRoot
      for (Int_t bini=1; bini<=fNbins; bini++) {

         // target signal efficiency for this bin
         Double_t effS = eff_bvss->GetBinCenter( bini );

         Double_t cut = rootFinder.Root( effS );

         // background efficiency at that cut
         if (Use_Splines_for_Eff_) effB = fSplTrainRefB->Eval( cut );
         else effB = mva_eff_tr_b->GetBinContent( mva_eff_tr_b->FindBin( cut ) );

         // fill the derived histograms
         eff_bvss->SetBinContent( bini, effB );
         rej_bvss->SetBinContent( bini, 1.0-effB );
      }
      fEffS = 0;

      // smooth spline through the background-vs-signal efficiency curve
      // NOTE(review): a previously created fSplTrainEffBvsS would leak here
      fSplTrainEffBvsS = new TSpline1( "effBvsS", new TGraph( eff_bvss ) );
   }

   // without the spline nothing can be evaluated
   if (0 == fSplTrainEffBvsS) return 0.0;

   // scan for the signal efficiency whose background efficiency crosses
   // the reference value
   Double_t effS, effB, effS_ = 0, effB_ = 0;
   Int_t nbins_ = 1000;
   for (Int_t bini=1; bini<=nbins_; bini++) {

      effS = (bini - 0.5)/Float_t(nbins_);
      effB = fSplTrainEffBvsS->Eval( effS );

      // sign change brackets the crossing point
      if ((effB - effBref)*(effB_ - effBref) <= 0) break;
      effS_ = effS;
      effB_ = effB;
   }

   // interpolate between the bracketing points
   return 0.5*(effS + effS_);
}
02377
02378
02379
02380
02381 std::vector<Float_t> TMVA::MethodBase::GetMulticlassEfficiency(std::vector<std::vector<Float_t> >& purity)
02382 {
02383 Data()->SetCurrentType(Types::kTesting);
02384 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
02385 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in GetMulticlassEfficiency, exiting."<<Endl;
02386
02387 purity.push_back(resMulticlass->GetAchievablePur());
02388 return resMulticlass->GetAchievableEff();
02389 }
02390
02391
02392
02393 std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity)
02394 {
02395 Data()->SetCurrentType(Types::kTraining);
02396 ResultsMulticlass* resMulticlass = dynamic_cast<ResultsMulticlass*>(Data()->GetResults(GetMethodName(), Types::kTraining, Types::kMulticlass));
02397 if (!resMulticlass) Log() << kFATAL<< "unable to create pointer in GetMulticlassTrainingEfficiency, exiting."<<Endl;
02398
02399 Log() << kINFO << "Determine optimal multiclass cuts for training data..." << Endl;
02400 for(UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls){
02401 resMulticlass->GetBestMultiClassCuts(icls);
02402 }
02403
02404 purity.push_back(resMulticlass->GetAchievablePur());
02405 return resMulticlass->GetAchievableEff();
02406 }
02407
02408
02409
02410 Double_t TMVA::MethodBase::GetSignificance( void ) const
02411 {
02412
02413
02414 Double_t rms = sqrt( fRmsS*fRmsS + fRmsB*fRmsB );
02415
02416 return (rms > 0) ? TMath::Abs(fMeanS - fMeanB)/rms : 0;
02417 }
02418
02419
02420 Double_t TMVA::MethodBase::GetSeparation( TH1* histoS, TH1* histoB ) const
02421 {
02422
02423
02424 return gTools().GetSeparation( histoS, histoB );
02425 }
02426
02427
02428 Double_t TMVA::MethodBase::GetSeparation( PDF* pdfS, PDF* pdfB ) const
02429 {
02430
02431
02432
02433
02434
02435 if ((!pdfS && pdfB) || (pdfS && !pdfB))
02436 Log() << kFATAL << "<GetSeparation> Mismatch in pdfs" << Endl;
02437 if (!pdfS) pdfS = fSplS;
02438 if (!pdfB) pdfB = fSplB;
02439
02440 if (!fSplS || !fSplB){
02441 Log()<<kWARNING<< "could not calculate the separation, distributions"
02442 << " fSplS or fSplB are not yet filled" << Endl;
02443 return 0;
02444 }else{
02445 return gTools().GetSeparation( *pdfS, *pdfB );
02446 }
02447 }
02448
02449
02450 Double_t TMVA::MethodBase::GetROCIntegral(PDF *pdfS, PDF *pdfB) const
02451 {
02452
02453
02454
02455
02456
02457 if ((!pdfS && pdfB) || (pdfS && !pdfB))
02458 Log() << kFATAL << "<GetSeparation> Mismatch in pdfs" << Endl;
02459 if (!pdfS) pdfS = fSplS;
02460 if (!pdfB) pdfB = fSplB;
02461
02462 if(pdfS==0 || pdfB==0) return 0.;
02463
02464 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
02465 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
02466
02467 Double_t integral = 0;
02468 UInt_t nsteps = 1000;
02469 Double_t step = (xmax-xmin)/Double_t(nsteps);
02470 Double_t cut = xmin;
02471 for (UInt_t i=0; i<nsteps; i++){
02472 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
02473 cut+=step;
02474 }
02475 return integral*step;
02476 }
02477
02478
02479 Double_t TMVA::MethodBase::GetMaximumSignificance( Double_t SignalEvents,
02480 Double_t BackgroundEvents,
02481 Double_t& max_significance_value ) const
02482 {
02483
02484
02485
02486
02487 Results* results = Data()->GetResults( GetMethodName(), Types::kTesting, Types::kMaxAnalysisType );
02488
02489 Double_t max_significance(0);
02490 Double_t effS(0),effB(0),significance(0);
02491 TH1F *temp_histogram = new TH1F("temp", "temp", fNbinsH, fXmin, fXmax );
02492
02493 if (SignalEvents <= 0 || BackgroundEvents <= 0) {
02494 Log() << kFATAL << "<GetMaximumSignificance> "
02495 << "Number of signal or background events is <= 0 ==> abort"
02496 << Endl;
02497 }
02498
02499 Log() << kINFO << "Using ratio SignalEvents/BackgroundEvents = "
02500 << SignalEvents/BackgroundEvents << Endl;
02501
02502 TH1* eff_s = results->GetHist("MVA_EFF_S");
02503 TH1* eff_b = results->GetHist("MVA_EFF_B");
02504
02505 if ( (eff_s==0) || (eff_b==0) ) {
02506 Log() << kWARNING << "Efficiency histograms empty !" << Endl;
02507 Log() << kWARNING << "no maximum cut found, return 0" << Endl;
02508 return 0;
02509 }
02510
02511 for (Int_t bin=1; bin<=fNbinsH; bin++) {
02512 effS = eff_s->GetBinContent( bin );
02513 effB = eff_b->GetBinContent( bin );
02514
02515
02516 significance = sqrt(SignalEvents)*( effS )/sqrt( effS + ( BackgroundEvents / SignalEvents) * effB );
02517
02518 temp_histogram->SetBinContent(bin,significance);
02519 }
02520
02521
02522 max_significance = temp_histogram->GetBinCenter( temp_histogram->GetMaximumBin() );
02523 max_significance_value = temp_histogram->GetBinContent( temp_histogram->GetMaximumBin() );
02524
02525
02526 delete temp_histogram;
02527
02528 Log() << kINFO << "Optimal cut at : " << max_significance << Endl;
02529 Log() << kINFO << "Maximum significance: " << max_significance_value << Endl;
02530
02531 return max_significance;
02532 }
02533
02534
02535 void TMVA::MethodBase::Statistics( Types::ETreeType treeType, const TString& theVarName,
02536 Double_t& meanS, Double_t& meanB,
02537 Double_t& rmsS, Double_t& rmsB,
02538 Double_t& xmin, Double_t& xmax )
02539 {
02540
02541
02542
02543
02544 Types::ETreeType previousTreeType = Data()->GetCurrentType();
02545 Data()->SetCurrentType(treeType);
02546
02547 Long64_t entries = Data()->GetNEvents();
02548
02549
02550 if (entries <=0)
02551 Log() << kFATAL << "<CalculateEstimator> Wrong tree type: " << treeType << Endl;
02552
02553
02554 UInt_t varIndex = DataInfo().FindVarIndex( theVarName );
02555
02556
02557 xmin = +DBL_MAX;
02558 xmax = -DBL_MAX;
02559 Long64_t nEventsS = -1;
02560 Long64_t nEventsB = -1;
02561
02562
02563 meanS = 0;
02564 meanB = 0;
02565 rmsS = 0;
02566 rmsB = 0;
02567 Double_t sumwS = 0, sumwB = 0;
02568
02569
02570 for (Int_t ievt = 0; ievt < entries; ievt++) {
02571
02572 const Event* ev = GetEvent(ievt);
02573
02574 Double_t theVar = ev->GetValue(varIndex);
02575 Double_t weight = ev->GetWeight();
02576
02577 if (DataInfo().IsSignal(ev)) {
02578 sumwS += weight;
02579 meanS += weight*theVar;
02580 rmsS += weight*theVar*theVar;
02581 }
02582 else {
02583 sumwB += weight;
02584 meanB += weight*theVar;
02585 rmsB += weight*theVar*theVar;
02586 }
02587 xmin = TMath::Min( xmin, theVar );
02588 xmax = TMath::Max( xmax, theVar );
02589 }
02590 ++nEventsS;
02591 ++nEventsB;
02592
02593 meanS = meanS/sumwS;
02594 meanB = meanB/sumwB;
02595 rmsS = TMath::Sqrt( rmsS/sumwS - meanS*meanS );
02596 rmsB = TMath::Sqrt( rmsB/sumwB - meanB*meanB );
02597
02598 Data()->SetCurrentType(previousTreeType);
02599 }
02600
02601
////////////////////////////////////////////////////////////////////////////////
/// Generate a standalone C++ response class ("Read<MethodName>") that can
/// evaluate this trained classifier without any TMVA/ROOT dependency.
/// The file name defaults to "<weightdir>/<job>_<method>.class.C" when
/// "theClassFileName" is empty. The generated file contains the abstract
/// IClassifierReader interface, the concrete reader class, and the
/// method-specific pieces emitted by MakeClassSpecific[Header]() and the
/// transformation handler.
void TMVA::MethodBase::MakeClass( const TString& theClassFileName ) const
{
   // determine the output file name
   TString classFileName = "";
   if (theClassFileName == "")
      classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class.C";
   else
      classFileName = theClassFileName;

   TString className = TString("Read") + GetMethodName();

   TString tfname( classFileName );
   Log() << kINFO << "Creating standalone response class: "
         << gTools().Color("lightblue") << classFileName << gTools().Color("reset") << Endl;

   ofstream fout( classFileName );
   if (!fout.good()) {
      Log() << kFATAL << "<MakeClass> Unable to open file: " << classFileName << Endl;
   }

   // generated-file banner
   fout << "// Class: " << className << endl;
   fout << "// Automatically generated by MethodBase::MakeClass" << endl << "//" << endl;

   // dump the full training configuration as a comment block
   fout << endl;
   fout << "/* configuration options =====================================================" << endl << endl;
   WriteStateToStream( fout );
   fout << endl;
   fout << "============================================================================ */" << endl;

   // standard-library includes needed by the generated code
   fout << "" << endl;
   fout << "#include <vector>" << endl;
   fout << "#include <cmath>" << endl;
   fout << "#include <string>" << endl;
   fout << "#include <iostream>" << endl;
   fout << "" << endl;

   // method-specific header material (e.g. helper classes)
   this->MakeClassSpecificHeader( fout, className );

   // abstract reader interface, guarded so multiple generated classes
   // can be compiled into one translation unit
   fout << "#ifndef IClassifierReader__def" << endl;
   fout << "#define IClassifierReader__def" << endl;
   fout << endl;
   fout << "class IClassifierReader {" << endl;
   fout << endl;
   fout << " public:" << endl;
   fout << endl;
   fout << "    // constructor" << endl;
   fout << "    IClassifierReader() : fStatusIsClean( true ) {}" << endl;
   fout << "    virtual ~IClassifierReader() {}" << endl;
   fout << endl;
   fout << "    // return classifier response" << endl;
   fout << "    virtual double GetMvaValue( const std::vector<double>& inputValues ) const = 0;" << endl;
   fout << endl;
   fout << "    // returns classifier status" << endl;
   fout << "    bool IsStatusClean() const { return fStatusIsClean; }" << endl;
   fout << endl;
   fout << " protected:" << endl;
   fout << endl;
   fout << "    bool fStatusIsClean;" << endl;
   fout << "};" << endl;
   fout << endl;
   fout << "#endif" << endl;
   fout << endl;

   // concrete reader class: constructor validates the caller's variable list
   fout << "class " << className << " : public IClassifierReader {" << endl;
   fout << endl;
   fout << " public:" << endl;
   fout << endl;
   fout << "    // constructor" << endl;
   fout << "    " << className << "( std::vector<std::string>& theInputVars ) " << endl;
   fout << "       : IClassifierReader()," << endl;
   fout << "         fClassName( \"" << className << "\" )," << endl;
   fout << "         fNvars( " << GetNvar() << " )," << endl;
   fout << "         fIsNormalised( " << (IsNormalised() ? "true" : "false") << " )" << endl;
   fout << "    { " << endl;
   fout << "       // the training input variables" << endl;
   fout << "       const char* inputVars[] = { ";
   // emit the training variable names as a comma-separated list
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      fout << "\"" << GetOriginalVarName(ivar) << "\"";
      if (ivar<GetNvar()-1) fout << ", ";
   }
   fout << " };" << endl;
   fout << endl;
   fout << "       // sanity checks" << endl;
   fout << "       if (theInputVars.size() <= 0) {" << endl;
   fout << "          std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": empty input vector\" << std::endl;" << endl;
   fout << "          fStatusIsClean = false;" << endl;
   fout << "       }" << endl;
   fout << endl;
   fout << "       if (theInputVars.size() != fNvars) {" << endl;
   fout << "          std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in number of input values: \"" << endl;
   fout << "                    << theInputVars.size() << \" != \" << fNvars << std::endl;" << endl;
   fout << "          fStatusIsClean = false;" << endl;
   fout << "       }" << endl;
   fout << endl;
   fout << "       // validate input variables" << endl;
   fout << "       for (size_t ivar = 0; ivar < theInputVars.size(); ivar++) {" << endl;
   fout << "          if (theInputVars[ivar] != inputVars[ivar]) {" << endl;
   fout << "             std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in input variable names\" << std::endl" << endl;
   fout << "                       << \" for variable [\" << ivar << \"]: \" << theInputVars[ivar].c_str() << \" != \" << inputVars[ivar] << std::endl;" << endl;
   fout << "             fStatusIsClean = false;" << endl;
   fout << "          }" << endl;
   fout << "       }" << endl;
   fout << endl;
   fout << "       // initialize min and max vectors (for normalisation)" << endl;
   // emit the per-variable normalisation range with full double precision
   for (UInt_t ivar = 0; ivar < GetNvar(); ivar++) {
      fout << "       fVmin[" << ivar << "] = " << std::setprecision(15) << GetXmin( ivar ) << ";" << endl;
      fout << "       fVmax[" << ivar << "] = " << std::setprecision(15) << GetXmax( ivar ) << ";" << endl;
   }
   fout << endl;
   fout << "       // initialize input variable types" << endl;
   // emit the per-variable type characters ('F' or 'I')
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      fout << "       fType[" << ivar << "] = \'" << DataInfo().GetVariableInfo(ivar).GetVarType() << "\';" << endl;
   }
   fout << endl;
   fout << "       // initialize constants" << endl;
   fout << "       Initialize();" << endl;
   fout << endl;
   // transformation initialisation is only emitted when transformations exist
   if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
      fout << "       // initialize transformation" << endl;
      fout << "       InitTransform();" << endl;
   }
   fout << "    }" << endl;
   fout << endl;
   fout << "    // destructor" << endl;
   fout << "    virtual ~" << className << "() {" << endl;
   fout << "       Clear(); // method-specific" << endl;
   fout << "    }" << endl;
   fout << endl;
   fout << "    // the classifier response" << endl;
   fout << "    // \"inputValues\" is a vector of input values in the same order as the " << endl;
   fout << "    // variables given to the constructor" << endl;
   fout << "    double GetMvaValue( const std::vector<double>& inputValues ) const;" << endl;
   fout << endl;
   fout << " private:" << endl;
   fout << endl;
   fout << "    // method-specific destructor" << endl;
   fout << "    void Clear();" << endl;
   fout << endl;
   // declarations for the variable transformation, if any
   if (GetTransformationHandler().GetTransformationList().GetSize()!=0) {
      fout << "    // input variable transformation" << endl;
      GetTransformationHandler().MakeFunction(fout, className,1);
      fout << "    void InitTransform();" << endl;
      fout << "    void Transform( std::vector<double> & iv, int sigOrBgd ) const;" << endl;
      fout << endl;
   }
   fout << "    // common member variables" << endl;
   fout << "    const char* fClassName;" << endl;
   fout << endl;
   fout << "    const size_t fNvars;" << endl;
   fout << "    size_t GetNvar()           const { return fNvars; }" << endl;
   fout << "    char   GetType( int ivar ) const { return fType[ivar]; }" << endl;
   fout << endl;
   fout << "    // normalisation of input variables" << endl;
   fout << "    const bool fIsNormalised;" << endl;
   fout << "    bool IsNormalised() const { return fIsNormalised; }" << endl;
   fout << "    double fVmin[" << GetNvar() << "];" << endl;
   fout << "    double fVmax[" << GetNvar() << "];" << endl;
   fout << "    double NormVariable( double x, double xmin, double xmax ) const {" << endl;
   fout << "       // normalise to output range: [-1, 1]" << endl;
   fout << "       return 2*(x - xmin)/(xmax - xmin) - 1.0;" << endl;
   fout << "    }" << endl;
   fout << endl;
   fout << "    // type of input variable: 'F' or 'I'" << endl;
   fout << "    char   fType[" << GetNvar() << "];" << endl;
   fout << endl;
   fout << "    // initialize internal variables" << endl;
   fout << "    void Initialize();" << endl;
   fout << "    double GetMvaValue__( const std::vector<double>& inputValues ) const;" << endl;
   fout << "" << endl;
   fout << "    // private members (method specific)" << endl;

   // the method-specific implementation closes the class declaration
   MakeClassSpecific( fout, className );

   // public GetMvaValue(): normalise and/or transform the inputs, then
   // delegate to the method-specific GetMvaValue__()
   fout << "   inline double " << className << "::GetMvaValue( const std::vector<double>& inputValues ) const" << endl;
   fout << "   {" << endl;
   fout << "      // classifier response value" << endl;
   fout << "      double retval = 0;" << endl;
   fout << endl;
   fout << "      // classifier response, sanity check first" << endl;
   fout << "      if (!IsStatusClean()) {" << endl;
   fout << "         std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": cannot return classifier response\"" << endl;
   fout << "                   << \" because status is dirty\" << std::endl;" << endl;
   fout << "         retval = 0;" << endl;
   fout << "      }" << endl;
   fout << "      else {" << endl;
   fout << "         if (IsNormalised()) {" << endl;
   fout << "            // normalise variables" << endl;
   fout << "            std::vector<double> iV;" << endl;
   fout << "            int ivar = 0;" << endl;
   fout << "            for (std::vector<double>::const_iterator varIt = inputValues.begin();" << endl;
   fout << "                 varIt != inputValues.end(); varIt++, ivar++) {" << endl;
   fout << "               iV.push_back(NormVariable( *varIt, fVmin[ivar], fVmax[ivar] ));" << endl;
   fout << "            }" << endl;
   // likelihood applies its transformation internally, hence the exception
   if (GetTransformationHandler().GetTransformationList().GetSize()!=0 && GetMethodType() != Types::kLikelihood)
      fout << "            Transform( iV, -1 );" << endl;
   fout << "            retval = GetMvaValue__( iV );" << endl;
   fout << "         }" << endl;
   fout << "         else {" << endl;
   if (GetTransformationHandler().GetTransformationList().GetSize()!=0 && GetMethodType() != Types::kLikelihood) {
      fout << "            std::vector<double> iV;" << endl;
      fout << "            int ivar = 0;" << endl;
      fout << "            for (std::vector<double>::const_iterator varIt = inputValues.begin();" << endl;
      fout << "                 varIt != inputValues.end(); varIt++, ivar++) {" << endl;
      fout << "               iV.push_back(*varIt);" << endl;
      fout << "            }" << endl;
      fout << "            Transform( iV, -1 );" << endl;
      fout << "            retval = GetMvaValue__( iV );" << endl;
   }
   else {
      fout << "            retval = GetMvaValue__( inputValues );" << endl;
   }
   fout << "         }" << endl;
   fout << "      }" << endl;
   fout << endl;
   fout << "      return retval;" << endl;
   fout << "   }" << endl;

   // emit the transformation function definitions, if any
   if (GetTransformationHandler().GetTransformationList().GetSize()!=0)
      GetTransformationHandler().MakeFunction(fout, className,2);

   fout.close();
}
02833
02834
02835 void TMVA::MethodBase::PrintHelpMessage() const
02836 {
02837
02838
02839
02840 std::streambuf* cout_sbuf = std::cout.rdbuf();
02841 std::ofstream* o = 0;
02842 if (gConfig().WriteOptionsReference()) {
02843 Log() << kINFO << "Print Help message for class " << GetName() << " into file: " << GetReferenceFile() << Endl;
02844 o = new std::ofstream( GetReferenceFile(), std::ios::app );
02845 if (!o->good()) {
02846 Log() << kFATAL << "<PrintHelpMessage> Unable to append to output file: " << GetReferenceFile() << Endl;
02847 }
02848 std::cout.rdbuf( o->rdbuf() );
02849 }
02850
02851
02852 if (!o) {
02853 Log() << kINFO << Endl;
02854 Log() << gTools().Color("bold")
02855 << "================================================================"
02856 << gTools().Color( "reset" )
02857 << Endl;
02858 Log() << gTools().Color("bold")
02859 << "H e l p f o r M V A m e t h o d [ " << GetName() << " ] :"
02860 << gTools().Color( "reset" )
02861 << Endl;
02862 }
02863 else {
02864 Log() << "Help for MVA method [ " << GetName() << " ] :" << Endl;
02865 }
02866
02867
02868 GetHelpMessage();
02869
02870 if (!o) {
02871 Log() << Endl;
02872 Log() << "<Suppress this message by specifying \"!H\" in the booking option>" << Endl;
02873 Log() << gTools().Color("bold")
02874 << "================================================================"
02875 << gTools().Color( "reset" )
02876 << Endl;
02877 Log() << Endl;
02878 }
02879 else {
02880
02881 Log() << "# End of Message___" << Endl;
02882 }
02883
02884 std::cout.rdbuf( cout_sbuf );
02885 if (o) o->close();
02886 }
02887
02888
02889
// static pointer to the current method instance; set via ResetThisBase() and
// read by the static IGetEffForRoot() callback, which the RootFinder uses to
// reach the instance-level GetEffForRoot()
TMVA::MethodBase* TMVA::MethodBase::fgThisBase = 0;
02891
02892
02893 Double_t TMVA::MethodBase::IGetEffForRoot( Double_t theCut )
02894 {
02895
02896 return MethodBase::GetThisBase()->GetEffForRoot( theCut );
02897 }
02898
02899
02900 Double_t TMVA::MethodBase::GetEffForRoot( Double_t theCut )
02901 {
02902
02903 Double_t retval=0;
02904
02905
02906 if (Use_Splines_for_Eff_) {
02907 retval = fSplRefS->Eval( theCut );
02908 }
02909 else retval = fEffS->GetBinContent( fEffS->FindBin( theCut ) );
02910
02911
02912
02913
02914
02915
02916
02917 Double_t eps = 1.0e-5;
02918 if (theCut-fXmin < eps) retval = (GetCutOrientation() == kPositive) ? 1.0 : 0.0;
02919 else if (fXmax-theCut < eps) retval = (GetCutOrientation() == kPositive) ? 0.0 : 1.0;
02920
02921 return retval;
02922 }
02923
02924
02925 const std::vector<TMVA::Event*>& TMVA::MethodBase::GetEventCollection( Types::ETreeType type)
02926 {
02927 if (GetTransformationHandler().GetTransformationList().GetEntries() <= 0) {
02928 return (Data()->GetEventCollection(type));
02929 }
02930 Int_t idx = Data()->TreeIndex(type);
02931 if (fEventCollections.at(idx) == 0) {
02932 fEventCollections.at(idx) = &(Data()->GetEventCollection(type));
02933 fEventCollections.at(idx) = GetTransformationHandler().CalcTransformations(*(fEventCollections.at(idx)),kTRUE);
02934 }
02935 return *(fEventCollections.at(idx));
02936 }
02937
02938
02939 TString TMVA::MethodBase::GetTrainingTMVAVersionString() const
02940 {
02941
02942 UInt_t a = GetTrainingTMVAVersionCode() & 0xff0000; a>>=16;
02943 UInt_t b = GetTrainingTMVAVersionCode() & 0x00ff00; b>>=8;
02944 UInt_t c = GetTrainingTMVAVersionCode() & 0x0000ff;
02945
02946 return TString(Form("%i.%i.%i",a,b,c));
02947 }
02948
02949
02950 TString TMVA::MethodBase::GetTrainingROOTVersionString() const
02951 {
02952
02953 UInt_t a = GetTrainingROOTVersionCode() & 0xff0000; a>>=16;
02954 UInt_t b = GetTrainingROOTVersionCode() & 0x00ff00; b>>=8;
02955 UInt_t c = GetTrainingROOTVersionCode() & 0x0000ff;
02956
02957 return TString(Form("%i.%02i/%02i",a,b,c));
02958 }
02959
02960
TMVA::MethodBase* TMVA::MethodBase::GetThisBase()
{
   // returns the static pointer to the current MethodBase instance, used by
   // the static RootFinder callback IGetEffForRoot()
   return fgThisBase;
}
02966
02967
void TMVA::MethodBase::ResetThisBase()
{
   // registers this instance as the target of the static RootFinder callback
   // IGetEffForRoot(); must be called before the root-finding is started
   fgThisBase = this;
}