00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089 #include "TMVA/Reader.h"
00090
00091 #include "TTree.h"
00092 #include "TLeaf.h"
00093 #include "TString.h"
00094 #include "TClass.h"
00095 #include "TH1D.h"
00096 #include "TKey.h"
00097 #include "TVector.h"
00098 #include "TXMLEngine.h"
00099
00100 #include <cstdlib>
00101
00102 #include <string>
00103 #include <vector>
00104 #include <fstream>
00105
00106 #include <iostream>
00107 #ifndef ROOT_TMVA_Tools
00108 #include "TMVA/Tools.h"
00109 #endif
00110 #include "TMVA/Config.h"
00111 #include "TMVA/ClassifierFactory.h"
00112 #include "TMVA/IMethod.h"
00113 #include "TMVA/MethodCuts.h"
00114 #include "TMVA/MethodCategory.h"
00115 #include "TMVA/DataSetManager.h"
00116
00117 ClassImp(TMVA::Reader)
00118
00119
00120 TMVA::Reader::Reader( const TString& theOption, Bool_t verbose )
00121 : Configurable( theOption ),
00122 fDataSetManager( NULL ),
00123 fDataSetInfo(),
00124 fVerbose( verbose ),
00125 fSilent ( kFALSE ),
00126 fColor ( kFALSE ),
00127 fCalculateError(kFALSE),
00128 fMvaEventError( 0 ),
00129 fMvaEventErrorUpper( 0 ),
00130 fLogger ( 0 )
00131 {
00132
00133 fDataSetManager = new DataSetManager( fDataInputHandler );
00134 fDataSetManager->AddDataSetInfo(fDataSetInfo);
00135 fLogger = new MsgLogger(this);
00136 SetConfigName( GetName() );
00137 DeclareOptions();
00138 ParseOptions();
00139
00140 Init();
00141 }
00142
00143
00144 TMVA::Reader::Reader( std::vector<TString>& inputVars, const TString& theOption, Bool_t verbose )
00145 : Configurable( theOption ),
00146 fDataSetManager( NULL ),
00147 fDataSetInfo(),
00148 fVerbose( verbose ),
00149 fSilent ( kFALSE ),
00150 fColor ( kFALSE ),
00151 fCalculateError(kFALSE),
00152 fMvaEventError( 0 ),
00153 fMvaEventErrorUpper( 0 ),
00154 fLogger ( 0 )
00155 {
00156
00157
00158 fDataSetManager = new DataSetManager( fDataInputHandler );
00159 fDataSetManager->AddDataSetInfo(fDataSetInfo);
00160 fLogger = new MsgLogger(this);
00161 SetConfigName( GetName() );
00162 DeclareOptions();
00163 ParseOptions();
00164
00165
00166
00167 for (std::vector<TString>::iterator ivar = inputVars.begin(); ivar != inputVars.end(); ivar++) {
00168 DataInfo().AddVariable( *ivar );
00169 }
00170
00171 Init();
00172 }
00173
00174
00175 TMVA::Reader::Reader( std::vector<std::string>& inputVars, const TString& theOption, Bool_t verbose )
00176 : Configurable( theOption ),
00177 fDataSetManager( NULL ),
00178 fDataSetInfo(),
00179 fVerbose( verbose ),
00180 fSilent ( kFALSE ),
00181 fColor ( kFALSE ),
00182 fCalculateError(kFALSE),
00183 fMvaEventError( 0 ),
00184 fMvaEventErrorUpper( 0 ),
00185 fLogger ( 0 )
00186 {
00187
00188 fDataSetManager = new DataSetManager( fDataInputHandler );
00189 fDataSetManager->AddDataSetInfo(fDataSetInfo);
00190 fLogger = new MsgLogger(this);
00191 SetConfigName( GetName() );
00192 DeclareOptions();
00193 ParseOptions();
00194
00195
00196
00197 for (std::vector<std::string>::iterator ivar = inputVars.begin(); ivar != inputVars.end(); ivar++) {
00198 DataInfo().AddVariable( ivar->c_str() );
00199 }
00200
00201 Init();
00202 }
00203
00204
00205 TMVA::Reader::Reader( const std::string& varNames, const TString& theOption, Bool_t verbose )
00206 : Configurable( theOption ),
00207 fDataSetManager( NULL ),
00208 fDataSetInfo(),
00209 fVerbose( verbose ),
00210 fSilent ( kFALSE ),
00211 fColor ( kFALSE ),
00212 fCalculateError(kFALSE),
00213 fMvaEventError( 0 ),
00214 fMvaEventErrorUpper( 0 ),
00215 fLogger ( 0 )
00216 {
00217
00218 fDataSetManager = new DataSetManager( fDataInputHandler );
00219 fDataSetManager->AddDataSetInfo(fDataSetInfo);
00220 fLogger = new MsgLogger(this);
00221 SetConfigName( GetName() );
00222 DeclareOptions();
00223 ParseOptions();
00224
00225
00226
00227 DecodeVarNames(varNames);
00228 Init();
00229 }
00230
00231
00232 TMVA::Reader::Reader( const TString& varNames, const TString& theOption, Bool_t verbose )
00233 : Configurable( theOption ),
00234 fDataSetManager( NULL ),
00235 fDataSetInfo(),
00236 fVerbose( verbose ),
00237 fSilent ( kFALSE ),
00238 fColor ( kFALSE ),
00239 fCalculateError(kFALSE),
00240 fMvaEventError( 0 ),
00241 fMvaEventErrorUpper( 0 ),
00242 fLogger ( 0 )
00243 {
00244
00245 fDataSetManager = new DataSetManager( fDataInputHandler );
00246 fDataSetManager->AddDataSetInfo(fDataSetInfo);
00247 fLogger = new MsgLogger(this);
00248 SetConfigName( GetName() );
00249 DeclareOptions();
00250 ParseOptions();
00251
00252
00253
00254 DecodeVarNames(varNames);
00255 Init();
00256 }
00257
00258
00259 void TMVA::Reader::DeclareOptions()
00260 {
00261
00262 if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput();
00263
00264 DeclareOptionRef( fVerbose, "V", "Verbose flag" );
00265 DeclareOptionRef( fColor, "Color", "Color flag (default True)" );
00266 DeclareOptionRef( fSilent, "Silent", "Boolean silent flag (default False)" );
00267 DeclareOptionRef( fCalculateError, "Error", "Calculates errors (default False)" );
00268 }
00269
00270
00271 TMVA::Reader::~Reader( void )
00272 {
00273
00274
00275 delete fDataSetManager;
00276
00277 delete fLogger;
00278 }
00279
00280
00281 void TMVA::Reader::Init( void )
00282 {
00283
00284
00285 if (Verbose()) fLogger->SetMinType( kVERBOSE );
00286
00287 gConfig().SetUseColor( fColor );
00288 gConfig().SetSilent ( fSilent );
00289 }
00290
00291
00292 void TMVA::Reader::AddVariable( const TString& expression, Float_t* datalink )
00293 {
00294
00295 DataInfo().AddVariable( expression, "", "", 0, 0, 'F', kFALSE ,(void*)datalink );
00296 }
00297
00298
00299 void TMVA::Reader::AddVariable( const TString& expression, Int_t* datalink )
00300 {
00301 Log() << kFATAL << "Reader::AddVariable( const TString& expression, Int_t* datalink ), this function is deprecated, please provide all variables to the reader as floats" << Endl;
00302
00303 Log() << kFATAL << "Reader::AddVariable( const TString& expression, Int_t* datalink ), this function is deprecated, please provide all variables to the reader as floats" << Endl;
00304 DataInfo().AddVariable(expression, "", "", 0, 0, 'I', kFALSE, (void*)datalink );
00305 }
00306
00307
00308 void TMVA::Reader::AddSpectator( const TString& expression, Float_t* datalink )
00309 {
00310
00311 DataInfo().AddSpectator( expression, "", "", 0, 0, 'F', kFALSE ,(void*)datalink );
00312 }
00313
00314
00315 void TMVA::Reader::AddSpectator( const TString& expression, Int_t* datalink )
00316 {
00317
00318 DataInfo().AddSpectator(expression, "", "", 0, 0, 'I', kFALSE, (void*)datalink );
00319 }
00320
00321
00322 TString TMVA::Reader::GetMethodTypeFromFile( const TString& filename )
00323 {
00324
00325
00326 ifstream fin( filename );
00327 if (!fin.good()) {
00328 Log() << kFATAL << "<BookMVA> fatal error: "
00329 << "unable to open input weight file: " << filename << Endl;
00330 }
00331
00332 TString fullMethodName("");
00333 if (filename.EndsWith(".xml")) {
00334 fin.close();
00335 void* doc = gTools().xmlengine().ParseFile(filename);
00336 void* rootnode = gTools().xmlengine().DocGetRootElement(doc);
00337 gTools().ReadAttr(rootnode, "Method", fullMethodName);
00338 gTools().xmlengine().FreeDoc(doc);
00339 }
00340 else {
00341 char buf[512];
00342 fin.getline(buf,512);
00343 while (!TString(buf).BeginsWith("Method")) fin.getline(buf,512);
00344 fullMethodName = TString(buf);
00345 fin.close();
00346 }
00347 TString methodType = fullMethodName(0,fullMethodName.Index("::"));
00348 if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
00349 return methodType;
00350 }
00351
00352
00353 TMVA::IMethod* TMVA::Reader::BookMVA( const TString& methodTag, const TString& weightfile )
00354 {
00355
00356
00357
00358 if (fMethodMap.find( methodTag ) != fMethodMap.end())
00359 Log() << kFATAL << "<BookMVA> method tag \"" << methodTag << "\" already exists!" << Endl;
00360
00361 TString methodType(GetMethodTypeFromFile(weightfile));
00362
00363 Log() << kINFO << "Booking \"" << methodTag << "\" of type \"" << methodType << "\" from " << weightfile << "." << Endl;
00364
00365 MethodBase* method = dynamic_cast<MethodBase*>(this->BookMVA( Types::Instance().GetMethodType(methodType),
00366 weightfile ) );
00367 if( method && method->GetMethodType() == Types::kCategory ){
00368 MethodCategory *methCat = (dynamic_cast<MethodCategory*>(method));
00369 if( !methCat )
00370 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
00371 methCat->fDataSetManager = fDataSetManager;
00372 }
00373
00374 return fMethodMap[methodTag] = method;
00375 }
00376
00377
00378 TMVA::IMethod* TMVA::Reader::BookMVA( TMVA::Types::EMVA methodType, const TString& weightfile )
00379 {
00380
00381 IMethod* im = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( methodType )),
00382 DataInfo(), weightfile );
00383
00384 MethodBase *method = (dynamic_cast<MethodBase*>(im));
00385
00386 if (method==0) return im;
00387
00388 if( method->GetMethodType() == Types::kCategory ){
00389 MethodCategory *methCat = (dynamic_cast<MethodCategory*>(method));
00390 if( !methCat )
00391 Log() << kERROR << "Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
00392 methCat->fDataSetManager = fDataSetManager;
00393 }
00394
00395 method->SetupMethod();
00396
00397
00398
00399 method->DeclareCompatibilityOptions();
00400
00401
00402 method->ReadStateFromFile();
00403
00404
00405 method->CheckSetup();
00406
00407 Log() << kINFO << "Booked classifier \"" << method->GetMethodName()
00408 << "\" of type: \"" << method->GetMethodTypeName() << "\"" << Endl;
00409
00410 return method;
00411 }
00412
00413
00414 TMVA::IMethod* TMVA::Reader::BookMVA( TMVA::Types::EMVA methodType, const char* xmlstr )
00415 {
00416
00417 #if (ROOT_SVN_REVISION >= 32259) && (ROOT_VERSION_CODE >= 334336) // 5.26/00
00418
00419
00420 IMethod* im = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( methodType )),
00421 DataInfo(), "" );
00422
00423 MethodBase *method = (dynamic_cast<MethodBase*>(im));
00424
00425 if(!method) return 0;
00426
00427 if( method->GetMethodType() == Types::kCategory ){
00428 MethodCategory *methCat = (dynamic_cast<MethodCategory*>(method));
00429 if( !methCat )
00430 Log() << kFATAL << "Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
00431 methCat->fDataSetManager = fDataSetManager;
00432 }
00433
00434 method->SetupMethod();
00435
00436
00437
00438 method->DeclareCompatibilityOptions();
00439
00440
00441 method->ReadStateFromXMLString( xmlstr );
00442
00443
00444 method->CheckSetup();
00445
00446 Log() << kINFO << "Booked classifier \"" << method->GetMethodName()
00447 << "\" of type: \"" << method->GetMethodTypeName() << "\"" << Endl;
00448
00449 return method;
00450 #else
00451 Log() << kFATAL << "Method Reader::BookMVA(TMVA::Types::EMVA methodType = " << methodType
00452 << ", const char* xmlstr = " << xmlstr
00453 << " ) is not available for ROOT versions prior to 5.26/00." << Endl;
00454 return 0;
00455 #endif
00456 }
00457
00458
00459 Double_t TMVA::Reader::EvaluateMVA( const std::vector<Float_t>& inputVec, const TString& methodTag, Double_t aux )
00460 {
00461
00462
00463
00464
00465 IMethod* imeth = FindMVA( methodTag );
00466 MethodBase* meth = dynamic_cast<TMVA::MethodBase*>(imeth);
00467 if(meth==0) return 0;
00468 Event* tmpEvent=new Event(inputVec, 2);
00469 if (meth->GetMethodType() == TMVA::Types::kCuts) {
00470 TMVA::MethodCuts* mc = dynamic_cast<TMVA::MethodCuts*>(meth);
00471 if(mc)
00472 mc->SetTestSignalEfficiency( aux );
00473 }
00474 Double_t val = meth->GetMvaValue( tmpEvent, (fCalculateError?&fMvaEventError:0));
00475 delete tmpEvent;
00476 return val;
00477 }
00478
00479
00480 Double_t TMVA::Reader::EvaluateMVA( const std::vector<Double_t>& inputVec, const TString& methodTag, Double_t aux )
00481 {
00482
00483
00484
00485
00486 if(fTmpEvalVec.size() != inputVec.size())
00487 fTmpEvalVec.resize(inputVec.size());
00488
00489 for (UInt_t idx=0; idx!=inputVec.size(); idx++ )
00490 fTmpEvalVec[idx]=inputVec[idx];
00491
00492 return EvaluateMVA( fTmpEvalVec, methodTag, aux );
00493 }
00494
00495
00496 Double_t TMVA::Reader::EvaluateMVA( const TString& methodTag, Double_t aux )
00497 {
00498
00499 IMethod* method = 0;
00500
00501 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
00502 if (it == fMethodMap.end()) {
00503 Log() << kINFO << "<EvaluateMVA> unknown classifier in map; "
00504 << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
00505 for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << " --> " << it->first << Endl;
00506 Log() << "Check calling string" << kFATAL << Endl;
00507 }
00508
00509 else method = it->second;
00510
00511 MethodBase * kl = dynamic_cast<TMVA::MethodBase*>(method);
00512
00513 if(kl==0)
00514 Log() << kFATAL << methodTag << " is not a method" << Endl;
00515
00516 return this->EvaluateMVA( kl, aux );
00517 }
00518
00519
00520 Double_t TMVA::Reader::EvaluateMVA( MethodBase* method, Double_t aux )
00521 {
00522
00523
00524
00525
00526 if (method->GetMethodType() == TMVA::Types::kCuts) {
00527 TMVA::MethodCuts* mc = dynamic_cast<TMVA::MethodCuts*>(method);
00528 if(mc)
00529 mc->SetTestSignalEfficiency( aux );
00530 }
00531
00532 return method->GetMvaValue( (fCalculateError?&fMvaEventError:0),
00533 (fCalculateError?&fMvaEventErrorUpper:0) );
00534 }
00535
00536
00537 const std::vector< Float_t >& TMVA::Reader::EvaluateRegression( const TString& methodTag, Double_t aux )
00538 {
00539
00540 IMethod* method = 0;
00541
00542 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
00543 if (it == fMethodMap.end()) {
00544 Log() << kINFO << "<EvaluateMVA> unknown method in map; "
00545 << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
00546 for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << " --> " << it->first << Endl;
00547 Log() << "Check calling string" << kFATAL << Endl;
00548 }
00549 else method = it->second;
00550
00551 MethodBase * kl = dynamic_cast<TMVA::MethodBase*>(method);
00552
00553 if(kl==0)
00554 Log() << kFATAL << methodTag << " is not a method" << Endl;
00555
00556 return this->EvaluateRegression( kl, aux );
00557 }
00558
00559
00560 const std::vector< Float_t >& TMVA::Reader::EvaluateRegression( MethodBase* method, Double_t )
00561 {
00562
00563 return method->GetRegressionValues();
00564 }
00565
00566
00567
00568 Float_t TMVA::Reader::EvaluateRegression( UInt_t tgtNumber, const TString& methodTag, Double_t aux )
00569 {
00570
00571 try {
00572 return EvaluateRegression(methodTag, aux).at(tgtNumber);
00573 }
00574 catch (std::out_of_range e) {
00575 Log() << kWARNING << "Regression could not be evaluated for target-number " << tgtNumber << Endl;
00576 return 0;
00577 }
00578 }
00579
00580
00581
00582
00583 const std::vector< Float_t >& TMVA::Reader::EvaluateMulticlass( const TString& methodTag, Double_t aux )
00584 {
00585
00586 IMethod* method = 0;
00587
00588 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
00589 if (it == fMethodMap.end()) {
00590 Log() << kINFO << "<EvaluateMVA> unknown method in map; "
00591 << "you looked for \"" << methodTag << "\" within available methods: " << Endl;
00592 for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << " --> " << it->first << Endl;
00593 Log() << "Check calling string" << kFATAL << Endl;
00594 }
00595 else method = it->second;
00596
00597 MethodBase * kl = dynamic_cast<TMVA::MethodBase*>(method);
00598
00599 if(kl==0)
00600 Log() << kFATAL << methodTag << " is not a method" << Endl;
00601
00602 return this->EvaluateMulticlass( kl, aux );
00603 }
00604
00605
00606 const std::vector< Float_t >& TMVA::Reader::EvaluateMulticlass( MethodBase* method, Double_t )
00607 {
00608
00609 return method->GetMulticlassValues();
00610 }
00611
00612
00613
00614 Float_t TMVA::Reader::EvaluateMulticlass( UInt_t clsNumber, const TString& methodTag, Double_t aux )
00615 {
00616
00617 try {
00618 return EvaluateMulticlass(methodTag, aux).at(clsNumber);
00619 }
00620 catch (std::out_of_range e) {
00621 Log() << kWARNING << "Multiclass could not be evaluated for class-number " << clsNumber << Endl;
00622 return 0;
00623 }
00624 }
00625
00626
00627
00628 TMVA::IMethod* TMVA::Reader::FindMVA( const TString& methodTag )
00629 {
00630
00631 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
00632 if (it != fMethodMap.end()) return it->second;
00633 Log() << kERROR << "Method " << methodTag << " not found!" << Endl;
00634 return 0;
00635 }
00636
00637
00638 TMVA::MethodCuts* TMVA::Reader::FindCutsMVA( const TString& methodTag )
00639 {
00640
00641
00642 return dynamic_cast<MethodCuts*>(FindMVA(methodTag));
00643 }
00644
00645
00646 Double_t TMVA::Reader::GetProba( const TString& methodTag, Double_t ap_sig, Double_t mvaVal )
00647 {
00648
00649 IMethod* method = 0;
00650 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
00651 if (it == fMethodMap.end()) {
00652 for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << "M" << it->first << Endl;
00653 Log() << kFATAL << "<EvaluateMVA> unknown classifier in map: " << method << "; "
00654 << "you looked for " << methodTag<< " while the available methods are : " << Endl;
00655 }
00656 else method = it->second;
00657
00658 MethodBase* kl = dynamic_cast<MethodBase*>(method);
00659 if(kl==0) return -1;
00660
00661 if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();
00662
00663 return kl->GetProba( mvaVal, ap_sig );
00664 }
00665
00666
00667 Double_t TMVA::Reader::GetRarity( const TString& methodTag, Double_t mvaVal )
00668 {
00669
00670 IMethod* method = 0;
00671 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
00672 if (it == fMethodMap.end()) {
00673 for (it = fMethodMap.begin(); it!=fMethodMap.end(); it++) Log() << "M" << it->first << Endl;
00674 Log() << kFATAL << "<EvaluateMVA> unknown classifier in map: \"" << method << "\"; "
00675 << "you looked for \"" << methodTag<< "\" while the available methods are : " << Endl;
00676 }
00677 else method = it->second;
00678
00679 MethodBase* kl = dynamic_cast<MethodBase*>(method);
00680 if(kl==0) return -1;
00681
00682 if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();
00683
00684 return kl->GetRarity( mvaVal );
00685 }
00686
00687
00688
00689
00690
00691
00692 void TMVA::Reader::DecodeVarNames( const std::string& varNames )
00693 {
00694
00695 size_t ipos = 0, f = 0;
00696 while (f != varNames.length()) {
00697 f = varNames.find( ':', ipos );
00698 if (f > varNames.length()) f = varNames.length();
00699 std::string subs = varNames.substr( ipos, f-ipos ); ipos = f+1;
00700 DataInfo().AddVariable( subs.c_str() );
00701 }
00702 }
00703
00704
00705 void TMVA::Reader::DecodeVarNames( const TString& varNames )
00706 {
00707
00708
00709 TString format;
00710 Int_t n = varNames.Length();
00711 TString format_obj;
00712
00713 for (int i=0; i< n+1 ; i++) {
00714 format.Append(varNames(i));
00715 if (varNames(i) == ':' || i == n) {
00716 format.Chop();
00717 format_obj = format;
00718 format_obj.ReplaceAll("@","");
00719 DataInfo().AddVariable( format_obj );
00720 format.Resize(0);
00721 }
00722 }
00723 }