00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #ifndef ROOT_TMVA_MethodBase
00032 #define ROOT_TMVA_MethodBase
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042 #include <iosfwd>
00043 #include <vector>
00044 #include <map>
00045 #include "assert.h"
00046
00047 #ifndef ROOT_TString
00048 #include "TString.h"
00049 #endif
00050
00051 #ifndef ROOT_TMVA_IMethod
00052 #include "TMVA/IMethod.h"
00053 #endif
00054 #ifndef ROOT_TMVA_Configurable
00055 #include "TMVA/Configurable.h"
00056 #endif
00057 #ifndef ROOT_TMVA_Types
00058 #include "TMVA/Types.h"
00059 #endif
00060 #ifndef ROOT_TMVA_DataSet
00061 #include "TMVA/DataSet.h"
00062 #endif
00063 #ifndef ROOT_TMVA_Event
00064 #include "TMVA/Event.h"
00065 #endif
00066 #ifndef ROOT_TMVA_TransformationHandler
00067 #include "TMVA/TransformationHandler.h"
00068 #endif
00069 #ifndef ROOT_TMVA_OptimizeConfigParameters
00070 #include "TMVA/OptimizeConfigParameters.h"
00071 #endif
00072
00073 class TGraph;
00074 class TTree;
00075 class TDirectory;
00076 class TSpline;
00077
00078 namespace TMVA {
00079
00080 class Ranking;
00081 class PDF;
00082 class TSpline1;
00083 class MethodCuts;
00084 class MethodBoost;
00085 class DataSetInfo;
00086
00087 class MethodBase : virtual public IMethod, public Configurable {
00088 // Abstract base class (several members are pure virtual) bundling state and helpers shared by the concrete methods.
00089 public:
00090
00091 enum EWeightFileType { kROOT=0, kTEXT };
00092
00093 // constructor taking job name, method type and title, dataset info, option string and base directory
00094 MethodBase( const TString& jobName,
00095 Types::EMVA methodType,
00096 const TString& methodTitle,
00097 DataSetInfo& dsi,
00098 const TString& theOption = "",
00099 TDirectory* theBaseDir = 0 );
00100
00101
00102 // constructor taking a weight-file name in place of job name, title and option string
00103 MethodBase( Types::EMVA methodType,
00104 DataSetInfo& dsi,
00105 const TString& weightFile,
00106 TDirectory* theBaseDir = 0 );
00107
00108
00109 virtual ~MethodBase();
00110
00111 // setup steps, defined out of line; only CheckSetup is virtual
00112 void SetupMethod();
00113 void ProcessSetup();
00114 virtual void CheckSetup();
00115
00116
00117
00118
00119 void AddOutput( Types::ETreeType type, Types::EAnalysisType analysisType );
00120
00121
00122
00123 void TrainMethod();
00124
00125 // tuning hooks; virtual so that derived methods can override them
00126 virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA");
00127 virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
00128
00129 virtual void Train() = 0;
00130
00131 // accessor pair for fTrainTime
00132 void SetTrainTime( Double_t trainTime ) { fTrainTime = trainTime; }
00133 Double_t GetTrainTime() const { return fTrainTime; }
00134
00135 // accessor pair for fTestTime
00136 void SetTestTime ( Double_t testTime ) { fTestTime = testTime; }
00137 Double_t GetTestTime () const { return fTestTime; }
00138
00139
00140 virtual void TestClassification();
00141
00142
00143 virtual void TestMulticlass();
00144
00145 // all results of the regression test are handed back through the reference parameters
00146 virtual void TestRegression( Double_t& bias, Double_t& biasT,
00147 Double_t& dev, Double_t& devT,
00148 Double_t& rms, Double_t& rmsT,
00149 Double_t& mInf, Double_t& mInfT,
00150 Double_t& corr,
00151 Types::ETreeType type );
00152
00153 // Init, DeclareOptions and ProcessOptions are pure virtual and must be provided by each concrete method
00154 virtual void Init() = 0;
00155 virtual void DeclareOptions() = 0;
00156 virtual void ProcessOptions() = 0;
00157 virtual void DeclareCompatibilityOptions();
00158
00159
00160
00161
00162 // the base implementation of Reset does nothing
00163 virtual void Reset(){return;}
00164
00165
00166
00167 // pure virtual response; both pointer arguments default to 0 (NOTE(review): presumably optional error outputs - confirm)
00168 virtual Double_t GetMvaValue( Double_t* errLower = 0, Double_t* errUpper = 0) = 0;
00169
00170 // non-virtual overload taking an explicit event, defined out of line
00171 Double_t GetMvaValue( const TMVA::Event* const ev, Double_t* err = 0, Double_t* errUpper = 0 );
00172
00173 protected:
00174
00175 void NoErrorCalc(Double_t* const err, Double_t* const errUpper);
00176
00177 public:
00178 // base-class default: reference to a shared, immutable empty vector (nothing is allocated per call, so nothing leaks)
00179 virtual const std::vector<Float_t>& GetRegressionValues() {
00180 static const std::vector<Float_t> noValues;
00181 return noValues;
00182 }
00183
00184 // base-class default: reference to a shared, immutable empty vector (nothing is allocated per call, so nothing leaks)
00185 virtual const std::vector<Float_t>& GetMulticlassValues() {
00186 static const std::vector<Float_t> noValues;
00187 return noValues;
00188 }
00189
00190
00191 virtual Double_t GetProba( Double_t mvaVal, Double_t ap_sig );
00192
00193
00194 virtual Double_t GetRarity( Double_t mvaVal, Types::ESBType reftype = Types::kBackground ) const;
00195
00196
00197 virtual const Ranking* CreateRanking() = 0;
00198
00199 // the base implementation ignores its argument and returns kFALSE
00200 virtual Bool_t MonitorBoost(MethodBoost* ) {return kFALSE;}
00201
00202
00203 virtual void MakeClass( const TString& classFileName = TString("") ) const;
00204
00205
00206 void PrintHelpMessage() const;
00207
00208
00209
00210
00211 public:
00212 void WriteStateToFile () const;
00213 void ReadStateFromFile ();
00214
00215 protected:
00216 // weight (de)serialisation hooks: the TFile overload defaults to an empty body, the other three are pure virtual
00217 virtual void AddWeightsXMLTo ( void* parent ) const = 0;
00218 virtual void ReadWeightsFromXML ( void* wghtnode ) = 0;
00219 virtual void ReadWeightsFromStream( std::istream& ) = 0;
00220 virtual void ReadWeightsFromStream( TFile& ) {}
00221
00222 private:
00223 friend class MethodCategory;
00224 friend class MethodCommittee;
00225 friend class MethodCompositeBase;
00226 void WriteStateToXML ( void* parent ) const;
00227 void ReadStateFromXML ( void* parent );
00228 void WriteStateToStream ( std::ostream& tf ) const;
00229 void WriteVarsToStream ( std::ostream& tf, const TString& prefix = "" ) const;
00230
00231
00232 public:
00233 void ReadStateFromStream ( std::istream& tf );
00234 void ReadStateFromStream ( TFile& rf );
00235 void ReadStateFromXMLString( const char* xmlstr );
00236
00237 private:
00238
00239 void AddVarsXMLTo ( void* parent ) const;
00240 void AddSpectatorsXMLTo ( void* parent ) const;
00241 void AddTargetsXMLTo ( void* parent ) const;
00242 void AddClassesXMLTo ( void* parent ) const;
00243 void ReadVariablesFromXML ( void* varnode );
00244 void ReadSpectatorsFromXML( void* specnode);
00245 void ReadTargetsFromXML ( void* tarnode );
00246 void ReadClassesFromXML ( void* clsnode );
00247 void ReadVarsFromStream ( std::istream& istr );
00248
00249 public:
00250
00251
00252
00253 virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype);
00254
00255
00256 virtual void WriteMonitoringHistosToFile() const;
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267 // performance figures; all virtual and defined out of line
00268 virtual Double_t GetEfficiency( const TString&, Types::ETreeType, Double_t& err );
00269 virtual Double_t GetTrainingEfficiency(const TString& );
00270 virtual std::vector<Float_t> GetMulticlassEfficiency( std::vector<std::vector<Float_t> >& purity );
00271 virtual std::vector<Float_t> GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity );
00272 virtual Double_t GetSignificance() const;
00273 virtual Double_t GetROCIntegral(PDF *pdfS=0, PDF *pdfB=0) const;
00274 virtual Double_t GetMaximumSignificance( Double_t SignalEvents, Double_t BackgroundEvents,
00275 Double_t& optimal_significance_value ) const;
00276 virtual Double_t GetSeparation( TH1*, TH1* ) const;
00277 virtual Double_t GetSeparation( PDF* pdfS = 0, PDF* pdfB = 0 ) const;
00278
00279 virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t& stddev,Double_t& stddev90Percent ) const;
00280
00281
00282 // name and type accessors; GetProbaName appends "_Proba" to the test-variable name
00283 const TString& GetJobName () const { return fJobName; }
00284 const TString& GetMethodName () const { return fMethodName; }
00285 TString GetMethodTypeName() const { return Types::Instance().GetMethodName(fMethodType); }
00286 Types::EMVA GetMethodType () const { return fMethodType; }
00287 const char* GetName () const { return fMethodName.Data(); }
00288 const TString& GetTestvarName () const { return fTestvar; }
00289 const TString GetProbaName () const { return fTestvar + "_Proba"; }
00290 TString GetWeightFileName() const;
00291
00292
00293 // an empty argument selects the default name "MVA_" + GetMethodName()
00294 void SetTestvarName ( const TString & v="" ) { fTestvar = (v=="") ? ("MVA_" + GetMethodName()) : v; }
00295
00296 // variable and target counts are forwarded to the DataSetInfo (GetNvar and GetNVariables are synonyms)
00297 UInt_t GetNvar() const { return DataInfo().GetNVariables(); }
00298 UInt_t GetNVariables() const { return DataInfo().GetNVariables(); }
00299 UInt_t GetNTargets() const { return DataInfo().GetNTargets(); }
00300
00301 // per-variable names, looked up through DataInfo().GetVariableInfo(i)
00302 const TString& GetInputVar ( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetInternalName(); }
00303 const TString& GetInputLabel( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetLabel(); }
00304 const TString& GetInputTitle( Int_t i ) const { return DataInfo().GetVariableInfo(i).GetTitle(); }
00305
00306 // per-variable statistics, forwarded to the transformation handler
00307 Double_t GetMean( Int_t ivar ) const { return GetTransformationHandler().GetMean(ivar); }
00308 Double_t GetRMS ( Int_t ivar ) const { return GetTransformationHandler().GetRMS(ivar); }
00309 Double_t GetXmin( Int_t ivar ) const { return GetTransformationHandler().GetMin(ivar); }
00310 Double_t GetXmax( Int_t ivar ) const { return GetTransformationHandler().GetMax(ivar); }
00311
00312 // accessor pair for fSignalReferenceCut (compared against the MVA value in IsSignalLike)
00313 Double_t GetSignalReferenceCut() const { return fSignalReferenceCut; }
00314
00315
00316 void SetSignalReferenceCut( Double_t cut ) { fSignalReferenceCut = cut; }
00317
00318 // directory handling; SetMethodDir assigns both fBaseDir and fMethodBaseDir
00319 TDirectory* BaseDir() const;
00320 TDirectory* MethodBaseDir() const;
00321 void SetMethodDir ( TDirectory* methodDir ) { fBaseDir = fMethodBaseDir = methodDir; }
00322 void SetBaseDir( TDirectory* methodDir ){ fBaseDir = methodDir; }
00323 void SetMethodBaseDir( TDirectory* methodDir ){ fMethodBaseDir = methodDir; }
00324
00325
00326
00327
00328 // version codes as stored in fTMVATrainingVersion / fROOTTrainingVersion
00329 UInt_t GetTrainingTMVAVersionCode() const { return fTMVATrainingVersion; }
00330 UInt_t GetTrainingROOTVersionCode() const { return fROOTTrainingVersion; }
00331 TString GetTrainingTMVAVersionString() const;
00332 TString GetTrainingROOTVersionString() const;
00333
00334 TransformationHandler& GetTransformationHandler() { return fTransformation; }
00335 const TransformationHandler& GetTransformationHandler() const { return fTransformation; }
00336
00337
00338
00339 // hands out the non-const DataSetInfo reference even from const member functions
00340 DataSetInfo& DataInfo() const { return fDataSetInfo; }
00341 // public and mutable: when non-null, the argument-less GetEvent() transforms this event; the indexed overloads assert it is 0
00342 mutable const Event* fTmpEvent;
00343
00344 // event access; the GetEvent/GetTrainingEvent/GetTestingEvent overloads are defined inline after the class
00345 UInt_t GetNEvents () const { return Data()->GetNEvents(); }
00346 const Event* GetEvent () const;
00347 const Event* GetEvent ( const TMVA::Event* ev ) const;
00348 const Event* GetEvent ( Long64_t ievt ) const;
00349 const Event* GetEvent ( Long64_t ievt , Types::ETreeType type ) const;
00350 const Event* GetTrainingEvent( Long64_t ievt ) const;
00351 const Event* GetTestingEvent ( Long64_t ievt ) const;
00352 const std::vector<TMVA::Event*>& GetEventCollection( Types::ETreeType type );
00353
00354
00355
00356
00357
00358 // kTRUE when GetMvaValue() is strictly greater than the signal reference cut
00359 virtual Bool_t IsSignalLike() { return GetMvaValue() > GetSignalReferenceCut() ? kTRUE : kFALSE; }
00360
00361 DataSet* Data() const { return DataInfo().GetDataSet(); }
00362
00363 Bool_t HasMVAPdfs() const { return fHasMVAPdfs; }
00364 virtual void SetAnalysisType( Types::EAnalysisType type ) { fAnalysisType = type; }
00365 Types::EAnalysisType GetAnalysisType() const { return fAnalysisType; }
00366 Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
00367 Bool_t DoMulticlass() const { return fAnalysisType == Types::kMulticlass; }
00368
00369 // stores the flag in fDisableWriting
00370 void DisableWriting(Bool_t setter){ fDisableWriting = setter; }
00371
00372 protected:
00373
00374
00375
00376
00377
00378
00379 void SetWeightFileName( TString );
00380
00381 const TString& GetWeightFileDir() const { return fFileDir; }
00382 void SetWeightFileDir( TString fileDir );
00383
00384
00385 Bool_t IsNormalised() const { return fNormalise; }
00386 void SetNormalised( Bool_t norm ) { fNormalise = norm; }
00387
00388
00389
00390
00391
00392 Bool_t Verbose() const { return fVerbose; }
00393 Bool_t Help () const { return fHelp; }
00394
00395
00396
00397
00398 // GetInternalVarName indexes *fInputVars directly (no null or bounds check)
00399 const TString& GetInternalVarName( Int_t ivar ) const { return (*fInputVars)[ivar]; }
00400 const TString& GetOriginalVarName( Int_t ivar ) const { return DataInfo().GetVariableInfo(ivar).GetExpression(); }
00401
00402 Bool_t HasTrainingTree() const { return Data()->GetNTrainingEvents() != 0; }
00403
00404
00405
00406 protected:
00407
00408 // code-generation hooks with empty default bodies
00409 virtual void MakeClassSpecific( std::ostream&, const TString& = "" ) const {}
00410
00411
00412 virtual void MakeClassSpecificHeader( std::ostream&, const TString& = "" ) const {}
00413
00414
00415 static MethodBase* GetThisBase();
00416
00417
00418 void Statistics( Types::ETreeType treeType, const TString& theVarName,
00419 Double_t&, Double_t&, Double_t&,
00420 Double_t&, Double_t&, Double_t& );
00421
00422 // always kTRUE; the fTxtWeightsOnly data member is not consulted here
00423 Bool_t TxtWeightsOnly() const { return kTRUE; }
00424
00425 protected:
00426
00427
00428 // event weight, except that a negative weight is replaced by 0 when fIgnoreNegWeightsInTraining is set
00429 Float_t GetTWeight( const Event* ev ) const {
00430 return (fIgnoreNegWeightsInTraining && (ev->GetWeight() < 0)) ? 0. : ev->GetWeight();
00431 }
00432
00433 Bool_t IsConstructedFromWeightFile() const { return fConstructedFromWeightFile; }
00434
00435 public:
00436 virtual void SetCurrentEvent( Long64_t ievt ) const {
00437 Data()->SetCurrentEvent(ievt); // forwarded to the DataSet
00438 }
00439
00440
00441 private:
00442
00443
00444 // base-class counterparts of Init / DeclareOptions / ProcessOptions, defined out of line
00445 void InitBase();
00446 void DeclareBaseOptions();
00447 void ProcessBaseOptions();
00448
00449
00450 enum ECutOrientation { kNegative = -1, kPositive = +1 };
00451 ECutOrientation GetCutOrientation() const { return fCutOrientation; }
00452
00453
00454
00455
00456 void ResetThisBase();
00457
00458
00459
00460
00461 void CreateMVAPdfs();
00462
00463
00464 static Double_t IGetEffForRoot( Double_t );
00465 Double_t GetEffForRoot ( Double_t );
00466
00467
00468 Bool_t GetLine( std::istream& fin, char * buf );
00469
00470 // one virtual output-filling routine per kind of output, defined out of line
00471 virtual void AddClassifierOutput ( Types::ETreeType type );
00472 virtual void AddClassifierOutputProb( Types::ETreeType type );
00473 virtual void AddRegressionOutput ( Types::ETreeType type );
00474 virtual void AddMulticlassOutput ( Types::ETreeType type );
00475
00476 private:
00477
00478 void AddInfoItem( void* gi, const TString& name, const TString& value) const;
00479 void CreateVariableTransforms(const TString& trafoDefinition );
00480
00481
00482
00483
00484 protected:
00485
00486 // data members accessible to derived classes
00487 Ranking* fRanking;
00488 std::vector<TString>* fInputVars;
00489
00490
00491 Int_t fNbins;
00492 Int_t fNbinsH;
00493
00494 Types::EAnalysisType fAnalysisType;
00495
00496 std::vector<Float_t>* fRegressionReturnVal;
00497 std::vector<Float_t>* fMulticlassReturnVal;
00498
00499 private:
00500
00501
00502 friend class MethodCuts;
00503
00504 Bool_t fDisableWriting;
00505
00506 // held by reference, so it must be bound in every constructor's initialiser list
00507 DataSetInfo& fDataSetInfo;
00508
00509 Double_t fSignalReferenceCut;
00510 Types::ESBType fVariableTransformType;
00511
00512
00513 TString fJobName;
00514 TString fMethodName;
00515 Types::EMVA fMethodType;
00516 TString fTestvar;
00517 UInt_t fTMVATrainingVersion;
00518 UInt_t fROOTTrainingVersion;
00519 Bool_t fConstructedFromWeightFile;
00520
00521
00522
00523
00524 TDirectory* fBaseDir;
00525 mutable TDirectory* fMethodBaseDir;
00526
00527 TString fParentDir;
00528
00529 TString fFileDir;
00530 TString fWeightFile;
00531
00532 private:
00533
00534 TH1* fEffS;
00535
00536 PDF* fDefaultPDF;
00537 PDF* fMVAPdfS;
00538 PDF* fMVAPdfB;
00539
00540 PDF* fSplS;
00541 PDF* fSplB;
00542 TSpline* fSpleffBvsS;
00543
00544 PDF* fSplTrainS;
00545 PDF* fSplTrainB;
00546 TSpline* fSplTrainEffBvsS;
00547
00548 private:
00549
00550
00551 Double_t fMeanS;
00552 Double_t fMeanB;
00553 Double_t fRmsS;
00554 Double_t fRmsB;
00555 Double_t fXmin;
00556 Double_t fXmax;
00557
00558
00559 TString fVarTransformString;
00560
00561 TransformationHandler fTransformation;
00562
00563
00564
00565 Bool_t fVerbose;
00566 TString fVerbosityLevelString;
00567 EMsgType fVerbosityLevel;
00568 Bool_t fHelp;
00569 Bool_t fHasMVAPdfs;
00570
00571 Bool_t fIgnoreNegWeightsInTraining;
00572
00573 protected:
00574 // read-only access to the private fIgnoreNegWeightsInTraining flag for derived classes
00575 Bool_t IgnoreEventsWithNegWeightsInTraining() const { return fIgnoreNegWeightsInTraining; }
00576
00577
00578 UInt_t fSignalClass;
00579 UInt_t fBackgroundClass;
00580
00581 private:
00582
00583
00584 Double_t fTrainTime;
00585 Double_t fTestTime;
00586
00587
00588 ECutOrientation fCutOrientation;
00589
00590
00591 TSpline1* fSplRefS;
00592 TSpline1* fSplRefB;
00593
00594 TSpline1* fSplTrainRefS;
00595 TSpline1* fSplTrainRefB;
00596
00597 mutable std::vector<const std::vector<TMVA::Event*>*> fEventCollections;
00598
00599 public:
00600 Bool_t fSetupCompleted;
00601
00602 private:
00603
00604 // static pointer to a MethodBase instance; see GetThisBase() / ResetThisBase()
00605 static MethodBase* fgThisBase;
00606
00607
00608
00609 private:
00610
00611 Bool_t fNormalise;
00612 Bool_t fUseDecorr;
00613 TString fVariableTransformTypeString;
00614 Bool_t fTxtWeightsOnly;
00615 Int_t fNbinsMVAPdf;
00616 Int_t fNsmoothMVAPdf;
00617
00618 protected:
00619
00620 ClassDef(MethodBase,0)
00621
00622 };
00623 }
00624
00625
00626
00627
00628
00629
00630
00631
00632
00633
00634
00635 inline const TMVA::Event* TMVA::MethodBase::GetEvent( const TMVA::Event* ev ) const
00636 { // returns whatever this method's transformation handler produces for the caller-supplied event ev
00637 return GetTransformationHandler().Transform(ev);
00638 }
00639
00640 inline const TMVA::Event* TMVA::MethodBase::GetEvent() const
00641 { // returns the transformed temporary event if fTmpEvent is set, otherwise the transformed current dataset event
00642 // pick the input first; the dataset is only queried when no temporary event is installed
00643 const Event* chosen = fTmpEvent;
00644 if (chosen == 0) chosen = Data()->GetEvent();
00645 return GetTransformationHandler().Transform(chosen);
00646 }
00647
00648 inline const TMVA::Event* TMVA::MethodBase::GetEvent( Long64_t ievt ) const
00649 { // returns the transformed dataset event with index ievt
00650 assert(fTmpEvent==0); // indexed access requires that no temporary event is set (checked only when assert is enabled)
00651 return GetTransformationHandler().Transform(Data()->GetEvent(ievt));
00652 }
00653
00654 inline const TMVA::Event* TMVA::MethodBase::GetEvent( Long64_t ievt, Types::ETreeType type ) const
00655 { // returns the transformed dataset event with index ievt taken from the tree selected by type
00656 assert(fTmpEvent==0); // indexed access requires that no temporary event is set (checked only when assert is enabled)
00657 return GetTransformationHandler().Transform(Data()->GetEvent(ievt, type));
00658 }
00659
00660 inline const TMVA::Event* TMVA::MethodBase::GetTrainingEvent( Long64_t ievt ) const
00661 { // convenience wrapper: GetEvent(ievt, Types::kTraining)
00662 assert(fTmpEvent==0); // same precondition as the overload called below, which repeats this check
00663 return GetEvent(ievt, Types::kTraining);
00664 }
00665
00666 inline const TMVA::Event* TMVA::MethodBase::GetTestingEvent( Long64_t ievt ) const
00667 { // convenience wrapper: GetEvent(ievt, Types::kTesting)
00668 assert(fTmpEvent==0); // same precondition as the overload called below, which repeats this check
00669 return GetEvent(ievt, Types::kTesting);
00670 }
00671
00672 #endif