00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef ROOT_TMVA_DecisionTreeNode
00031 #define ROOT_TMVA_DecisionTreeNode
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041 #ifndef ROOT_TMVA_Node
00042 #include "TMVA/Node.h"
00043 #endif
00044
00045 #ifndef ROOT_TMVA_Version
00046 #include "TMVA/Version.h"
00047 #endif
00048
00049 #include <iostream>
00050 #include <vector>
00051 #include <map>
00052 namespace TMVA {
00053
00054 class DTNodeTrainingInfo
00055 {
00056 public:
00057 DTNodeTrainingInfo():fSampleMin(),
00058 fSampleMax(),
00059 fNodeR(0),fSubTreeR(0),fAlpha(0),fG(0),fNTerminal(0),
00060 fNB(0),fNS(0),fSumTarget(0),fSumTarget2(0),fCC(0),
00061 fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
00062 fNEvents ( -1 ),
00063 fNSigEvents_unweighted ( 0 ),
00064 fNBkgEvents_unweighted ( 0 ),
00065 fNEvents_unweighted ( 0 ),
00066 fSeparationIndex (-1 ),
00067 fSeparationGain ( -1 )
00068 {
00069 }
00070 std::vector< Float_t > fSampleMin;
00071 std::vector< Float_t > fSampleMax;
00072 Double_t fNodeR;
00073 Double_t fSubTreeR;
00074 Double_t fAlpha;
00075 Double_t fG;
00076 Int_t fNTerminal;
00077 Double_t fNB;
00078 Double_t fNS;
00079 Float_t fSumTarget;
00080 Float_t fSumTarget2;
00081 Double_t fCC;
00082
00083 Float_t fNSigEvents;
00084 Float_t fNBkgEvents;
00085 Float_t fNEvents;
00086 Float_t fNSigEvents_unweighted;
00087 Float_t fNBkgEvents_unweighted;
00088 Float_t fNEvents_unweighted;
00089 Float_t fSeparationIndex;
00090 Float_t fSeparationGain;
00091
00092
00093 DTNodeTrainingInfo(const DTNodeTrainingInfo& n) :
00094 fSampleMin(),fSampleMax(),
00095 fNodeR(n.fNodeR), fSubTreeR(n.fSubTreeR),
00096 fAlpha(n.fAlpha), fG(n.fG),
00097 fNTerminal(n.fNTerminal),
00098 fNB(n.fNB), fNS(n.fNS),
00099 fSumTarget(0),fSumTarget2(0),
00100 fCC(0),
00101 fNSigEvents ( n.fNSigEvents ), fNBkgEvents ( n.fNBkgEvents ),
00102 fNEvents ( n.fNEvents ),
00103 fNSigEvents_unweighted ( n.fNSigEvents_unweighted ),
00104 fNBkgEvents_unweighted ( n.fNBkgEvents_unweighted ),
00105 fNEvents_unweighted ( n.fNEvents_unweighted ),
00106 fSeparationIndex( n.fSeparationIndex ),
00107 fSeparationGain ( n.fSeparationGain )
00108 { }
00109 };
00110
00111 class Event;
00112 class MsgLogger;
00113
00114 class DecisionTreeNode: public Node {
00115
00116 public:
00117
00118
00119 DecisionTreeNode ();
00120
00121 DecisionTreeNode (Node* p, char pos);
00122
00123
00124 DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = NULL);
00125
00126
00127 virtual ~DecisionTreeNode();
00128
00129 virtual Node* CreateNode() const { return new DecisionTreeNode(); }
00130
00131 inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
00132 inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size();}
00133
00134 void SetFisherCoeff(Int_t ivar, Double_t coeff);
00135
00136 Double_t GetFisherCoeff(Int_t ivar) const {return fFisherCoeff.at(ivar);}
00137
00138
00139 virtual Bool_t GoesRight( const Event & ) const;
00140
00141
00142 virtual Bool_t GoesLeft ( const Event & ) const;
00143
00144
00145 void SetSelector( Short_t i) { fSelector = i; }
00146
00147 Short_t GetSelector() const { return fSelector; }
00148
00149
00150 void SetCutValue ( Float_t c ) { fCutValue = c; }
00151
00152 Float_t GetCutValue ( void ) const { return fCutValue; }
00153
00154
00155 void SetCutType( Bool_t t ) { fCutType = t; }
00156
00157 Bool_t GetCutType( void ) const { return fCutType; }
00158
00159
00160 void SetNodeType( Int_t t ) { fNodeType = t;}
00161
00162 Int_t GetNodeType( void ) const { return fNodeType; }
00163
00164
00165 Float_t GetPurity( void ) const { return fPurity;}
00166
00167 void SetPurity( void );
00168
00169
00170 void SetResponse( Float_t r ) { fResponse = r;}
00171
00172
00173 Float_t GetResponse( void ) const { return fResponse;}
00174
00175
00176 void SetRMS( Float_t r ) { fRMS = r;}
00177
00178
00179 Float_t GetRMS( void ) const { return fRMS;}
00180
00181
00182 void SetNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents = s; }
00183
00184
00185 void SetNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents = b; }
00186
00187
00188 void SetNEvents( Float_t nev ){ fTrainInfo->fNEvents =nev ; }
00189
00190
00191 void SetNSigEvents_unweighted( Float_t s ) { fTrainInfo->fNSigEvents_unweighted = s; }
00192
00193
00194 void SetNBkgEvents_unweighted( Float_t b ) { fTrainInfo->fNBkgEvents_unweighted = b; }
00195
00196
00197 void SetNEvents_unweighted( Float_t nev ){ fTrainInfo->fNEvents_unweighted =nev ; }
00198
00199
00200 void IncrementNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents += s; }
00201
00202
00203 void IncrementNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents += b; }
00204
00205
00206 void IncrementNEvents( Float_t nev ){ fTrainInfo->fNEvents +=nev ; }
00207
00208
00209 void IncrementNSigEvents_unweighted( ) { fTrainInfo->fNSigEvents_unweighted += 1; }
00210
00211
00212 void IncrementNBkgEvents_unweighted( ) { fTrainInfo->fNBkgEvents_unweighted += 1; }
00213
00214
00215 void IncrementNEvents_unweighted( ){ fTrainInfo->fNEvents_unweighted +=1 ; }
00216
00217
00218 Float_t GetNSigEvents( void ) const { return fTrainInfo->fNSigEvents; }
00219
00220
00221 Float_t GetNBkgEvents( void ) const { return fTrainInfo->fNBkgEvents; }
00222
00223
00224 Float_t GetNEvents( void ) const { return fTrainInfo->fNEvents; }
00225
00226
00227 Float_t GetNSigEvents_unweighted( void ) const { return fTrainInfo->fNSigEvents_unweighted; }
00228
00229
00230 Float_t GetNBkgEvents_unweighted( void ) const { return fTrainInfo->fNBkgEvents_unweighted; }
00231
00232
00233 Float_t GetNEvents_unweighted( void ) const { return fTrainInfo->fNEvents_unweighted; }
00234
00235
00236
00237 void SetSeparationIndex( Float_t sep ){ fTrainInfo->fSeparationIndex =sep ; }
00238
00239 Float_t GetSeparationIndex( void ) const { return fTrainInfo->fSeparationIndex; }
00240
00241
00242 void SetSeparationGain( Float_t sep ){ fTrainInfo->fSeparationGain =sep ; }
00243
00244 Float_t GetSeparationGain( void ) const { return fTrainInfo->fSeparationGain; }
00245
00246
00247 virtual void Print( ostream& os ) const;
00248
00249
00250 virtual void PrintRec( ostream& os ) const;
00251
00252 virtual void AddAttributesToNode(void* node) const;
00253 virtual void AddContentToNode(std::stringstream& s) const;
00254
00255
00256 void ClearNodeAndAllDaughters();
00257
00258
00259
00260
00261 inline virtual DecisionTreeNode* GetLeft( ) const { return dynamic_cast<DecisionTreeNode*>(fLeft); }
00262 inline virtual DecisionTreeNode* GetRight( ) const { return dynamic_cast<DecisionTreeNode*>(fRight); }
00263 inline virtual DecisionTreeNode* GetParent( ) const { return dynamic_cast<DecisionTreeNode*>(fParent); }
00264
00265
00266 inline virtual void SetLeft (Node* l) { fLeft = dynamic_cast<DecisionTreeNode*>(l);}
00267 inline virtual void SetRight (Node* r) { fRight = dynamic_cast<DecisionTreeNode*>(r);}
00268 inline virtual void SetParent(Node* p) { fParent = dynamic_cast<DecisionTreeNode*>(p);}
00269
00270
00271
00272
00273
00274 inline void SetNodeR( Double_t r ) { fTrainInfo->fNodeR = r; }
00275 inline Double_t GetNodeR( ) const { return fTrainInfo->fNodeR; }
00276
00277
00278 inline void SetSubTreeR( Double_t r ) { fTrainInfo->fSubTreeR = r; }
00279 inline Double_t GetSubTreeR( ) const { return fTrainInfo->fSubTreeR; }
00280
00281
00282
00283
00284 inline void SetAlpha( Double_t alpha ) { fTrainInfo->fAlpha = alpha; }
00285 inline Double_t GetAlpha( ) const { return fTrainInfo->fAlpha; }
00286
00287
00288 inline void SetAlphaMinSubtree( Double_t g ) { fTrainInfo->fG = g; }
00289 inline Double_t GetAlphaMinSubtree( ) const { return fTrainInfo->fG; }
00290
00291
00292 inline void SetNTerminal( Int_t n ) { fTrainInfo->fNTerminal = n; }
00293 inline Int_t GetNTerminal( ) const { return fTrainInfo->fNTerminal; }
00294
00295
00296 inline void SetNBValidation( Double_t b ) { fTrainInfo->fNB = b; }
00297 inline void SetNSValidation( Double_t s ) { fTrainInfo->fNS = s; }
00298 inline Double_t GetNBValidation( ) const { return fTrainInfo->fNB; }
00299 inline Double_t GetNSValidation( ) const { return fTrainInfo->fNS; }
00300
00301
00302 inline void SetSumTarget(Float_t t) {fTrainInfo->fSumTarget = t; }
00303 inline void SetSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 = t2; }
00304
00305 inline void AddToSumTarget(Float_t t) {fTrainInfo->fSumTarget += t; }
00306 inline void AddToSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 += t2; }
00307
00308 inline Float_t GetSumTarget() const {return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
00309 inline Float_t GetSumTarget2() const {return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
00310
00311
00312
00313 void ResetValidationData( );
00314
00315
00316 inline Bool_t IsTerminal() const { return fIsTerminalNode; }
00317 inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s; }
00318 void PrintPrune( ostream& os ) const ;
00319 void PrintRecPrune( ostream& os ) const;
00320
00321 void SetCC(Double_t cc);
00322 Double_t GetCC() const {return (fTrainInfo? fTrainInfo->fCC : -1.);}
00323
00324 Float_t GetSampleMin(UInt_t ivar) const;
00325 Float_t GetSampleMax(UInt_t ivar) const;
00326 void SetSampleMin(UInt_t ivar, Float_t xmin);
00327 void SetSampleMax(UInt_t ivar, Float_t xmax);
00328
00329 static bool fgIsTraining;
00330
00331 protected:
00332
00333 static MsgLogger* fgLogger;
00334
00335 std::vector<Double_t> fFisherCoeff;
00336
00337 Float_t fCutValue;
00338 Bool_t fCutType;
00339 Short_t fSelector;
00340
00341 Float_t fResponse;
00342 Float_t fRMS;
00343 Int_t fNodeType;
00344 Float_t fPurity;
00345
00346 Bool_t fIsTerminalNode;
00347
00348 mutable DTNodeTrainingInfo* fTrainInfo;
00349
00350 private:
00351
00352 virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
00353 virtual Bool_t ReadDataRecord( istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
00354 virtual void ReadContent(std::stringstream& s);
00355
00356 ClassDef(DecisionTreeNode,0)
00357 };
00358 }
00359
00360 #endif