00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039 #include <algorithm>
00040 #include <exception>
00041 #include <iomanip>
00042
00043 #include "TMVA/MsgLogger.h"
00044 #include "TMVA/DecisionTreeNode.h"
00045 #include "TMVA/Tools.h"
00046 #include "TMVA/Event.h"
00047
00048 using std::string;
00049
00050 ClassImp(TMVA::DecisionTreeNode)
00051
00052 TMVA::MsgLogger* TMVA::DecisionTreeNode::fgLogger = 0;
00053 bool TMVA::DecisionTreeNode::fgIsTraining = false;
00054
00055
00056 TMVA::DecisionTreeNode::DecisionTreeNode()
00057 : TMVA::Node(),
00058 fCutValue(0),
00059 fCutType ( kTRUE ),
00060 fSelector ( -1 ),
00061 fResponse(-99 ),
00062 fRMS(0),
00063 fNodeType (-99 ),
00064 fPurity (-99),
00065 fIsTerminalNode( kFALSE )
00066 {
00067
00068 if (!fgLogger) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );
00069
00070 if (fgIsTraining){
00071 fTrainInfo = new DTNodeTrainingInfo();
00072
00073 }
00074 else {
00075
00076 fTrainInfo = 0;
00077 }
00078 }
00079
00080
00081 TMVA::DecisionTreeNode::DecisionTreeNode(TMVA::Node* p, char pos)
00082 : TMVA::Node(p, pos),
00083 fCutValue( 0 ),
00084 fCutType ( kTRUE ),
00085 fSelector( -1 ),
00086 fResponse(-99 ),
00087 fRMS(0),
00088 fNodeType( -99 ),
00089 fPurity (-99),
00090 fIsTerminalNode( kFALSE )
00091 {
00092
00093 if (!fgLogger) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );
00094
00095 if (fgIsTraining){
00096 fTrainInfo = new DTNodeTrainingInfo();
00097
00098 }
00099 else {
00100
00101 fTrainInfo = 0;
00102 }
00103 }
00104
00105
00106 TMVA::DecisionTreeNode::DecisionTreeNode(const TMVA::DecisionTreeNode &n,
00107 DecisionTreeNode* parent)
00108 : TMVA::Node(n),
00109 fCutValue( n.fCutValue ),
00110 fCutType ( n.fCutType ),
00111 fSelector( n.fSelector ),
00112 fResponse( n.fResponse ),
00113 fRMS ( n.fRMS),
00114 fNodeType( n.fNodeType ),
00115 fPurity ( n.fPurity),
00116 fIsTerminalNode( n.fIsTerminalNode )
00117 {
00118
00119
00120 if (!fgLogger) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );
00121
00122 this->SetParent( parent );
00123 if (n.GetLeft() == 0 ) this->SetLeft(NULL);
00124 else this->SetLeft( new DecisionTreeNode( *((DecisionTreeNode*)(n.GetLeft())),this));
00125
00126 if (n.GetRight() == 0 ) this->SetRight(NULL);
00127 else this->SetRight( new DecisionTreeNode( *((DecisionTreeNode*)(n.GetRight())),this));
00128
00129 if (fgIsTraining){
00130 fTrainInfo = new DTNodeTrainingInfo(*(n.fTrainInfo));
00131
00132 }
00133 else {
00134
00135 fTrainInfo = 0;
00136 }
00137 }
00138
00139
00140 TMVA::DecisionTreeNode::~DecisionTreeNode(){
00141
00142 delete fTrainInfo;
00143 }
00144
00145
00146
00147 Bool_t TMVA::DecisionTreeNode::GoesRight(const TMVA::Event & e) const
00148 {
00149
00150 Bool_t result;
00151
00152 if (GetNFisherCoeff() == 0){
00153
00154 result = (e.GetValue(this->GetSelector()) > this->GetCutValue() );
00155
00156 }else{
00157
00158 Double_t fisher = this->GetFisherCoeff(fFisherCoeff.size()-1);
00159 for (UInt_t ivar=0; ivar<fFisherCoeff.size()-1; ivar++)
00160 fisher += this->GetFisherCoeff(ivar)*(e.GetValue(ivar));
00161
00162 result = fisher > this->GetCutValue();
00163 }
00164
00165 if (fCutType == kTRUE) return result;
00166 else return !result;
00167 }
00168
00169
00170 Bool_t TMVA::DecisionTreeNode::GoesLeft(const TMVA::Event & e) const
00171 {
00172
00173 if (!this->GoesRight(e)) return kTRUE;
00174 else return kFALSE;
00175 }
00176
00177
00178
00179 void TMVA::DecisionTreeNode::SetPurity( void )
00180 {
00181
00182
00183
00184
00185 if ( ( this->GetNSigEvents() + this->GetNBkgEvents() ) > 0 ) {
00186 fPurity = this->GetNSigEvents() / ( this->GetNSigEvents() + this->GetNBkgEvents());
00187 }
00188 else {
00189 *fgLogger << kINFO << "Zero events in purity calcuation , return purity=0.5" << Endl;
00190 this->Print(*fgLogger);
00191 fPurity = 0.5;
00192 }
00193 return;
00194 }
00195
00196
00197
00198 void TMVA::DecisionTreeNode::Print(ostream& os) const
00199 {
00200
00201 os << "< *** " << std::endl;
00202 os << " d: " << this->GetDepth()
00203 << std::setprecision(6)
00204 << "NCoef: " << this->GetNFisherCoeff();
00205 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) { os << "fC"<<i<<": " << this->GetFisherCoeff(i);}
00206 os << " ivar: " << this->GetSelector()
00207 << " cut: " << this->GetCutValue()
00208 << " cType: " << this->GetCutType()
00209 << " s: " << this->GetNSigEvents()
00210 << " b: " << this->GetNBkgEvents()
00211 << " nEv: " << this->GetNEvents()
00212 << " suw: " << this->GetNSigEvents_unweighted()
00213 << " buw: " << this->GetNBkgEvents_unweighted()
00214 << " nEvuw: " << this->GetNEvents_unweighted()
00215 << " sepI: " << this->GetSeparationIndex()
00216 << " sepG: " << this->GetSeparationGain()
00217 << " nType: " << this->GetNodeType()
00218 << std::endl;
00219
00220 os << "My address is " << long(this) << ", ";
00221 if (this->GetParent() != NULL) os << " parent at addr: " << long(this->GetParent()) ;
00222 if (this->GetLeft() != NULL) os << " left daughter at addr: " << long(this->GetLeft());
00223 if (this->GetRight() != NULL) os << " right daughter at addr: " << long(this->GetRight()) ;
00224
00225 os << " **** > " << std::endl;
00226 }
00227
00228
00229 void TMVA::DecisionTreeNode::PrintRec(ostream& os) const
00230 {
00231
00232
00233 os << this->GetDepth()
00234 << std::setprecision(6)
00235 << " " << this->GetPos()
00236 << "NCoef: " << this->GetNFisherCoeff();
00237 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {os << "fC"<<i<<": " << this->GetFisherCoeff(i);}
00238 os << " ivar: " << this->GetSelector()
00239 << " cut: " << this->GetCutValue()
00240 << " cType: " << this->GetCutType()
00241 << " s: " << this->GetNSigEvents()
00242 << " b: " << this->GetNBkgEvents()
00243 << " nEv: " << this->GetNEvents()
00244 << " suw: " << this->GetNSigEvents_unweighted()
00245 << " buw: " << this->GetNBkgEvents_unweighted()
00246 << " nEvuw: " << this->GetNEvents_unweighted()
00247 << " sepI: " << this->GetSeparationIndex()
00248 << " sepG: " << this->GetSeparationGain()
00249 << " res: " << this->GetResponse()
00250 << " rms: " << this->GetRMS()
00251 << " nType: " << this->GetNodeType();
00252 if (this->GetCC() > 10000000000000.) os << " CC: " << 100000. << std::endl;
00253 else os << " CC: " << this->GetCC() << std::endl;
00254
00255 if (this->GetLeft() != NULL) this->GetLeft() ->PrintRec(os);
00256 if (this->GetRight() != NULL) this->GetRight()->PrintRec(os);
00257 }
00258
00259
00260 Bool_t TMVA::DecisionTreeNode::ReadDataRecord( istream& is, UInt_t tmva_Version_Code )
00261 {
00262
00263
00264 string tmp;
00265
00266 Float_t cutVal, cutType, nsig, nbkg, nEv, nsig_unweighted, nbkg_unweighted, nEv_unweighted;
00267 Float_t separationIndex, separationGain, response(-99), cc(0);
00268 Int_t depth, ivar, nodeType;
00269 ULong_t lseq;
00270 char pos;
00271
00272 is >> depth;
00273 if ( depth==-1 ) { return kFALSE; }
00274
00275 is >> pos ;
00276 this->SetDepth(depth);
00277 this->SetPos(pos);
00278
00279 if (tmva_Version_Code < TMVA_VERSION(4,0,0)) {
00280 is >> tmp >> lseq
00281 >> tmp >> ivar
00282 >> tmp >> cutVal
00283 >> tmp >> cutType
00284 >> tmp >> nsig
00285 >> tmp >> nbkg
00286 >> tmp >> nEv
00287 >> tmp >> nsig_unweighted
00288 >> tmp >> nbkg_unweighted
00289 >> tmp >> nEv_unweighted
00290 >> tmp >> separationIndex
00291 >> tmp >> separationGain
00292 >> tmp >> nodeType;
00293 } else {
00294 is >> tmp >> lseq
00295 >> tmp >> ivar
00296 >> tmp >> cutVal
00297 >> tmp >> cutType
00298 >> tmp >> nsig
00299 >> tmp >> nbkg
00300 >> tmp >> nEv
00301 >> tmp >> nsig_unweighted
00302 >> tmp >> nbkg_unweighted
00303 >> tmp >> nEv_unweighted
00304 >> tmp >> separationIndex
00305 >> tmp >> separationGain
00306 >> tmp >> response
00307 >> tmp >> nodeType
00308 >> tmp >> cc;
00309 }
00310
00311 this->SetSelector((UInt_t)ivar);
00312 this->SetCutValue(cutVal);
00313 this->SetCutType(cutType);
00314 this->SetNodeType(nodeType);
00315 if (fTrainInfo){
00316 this->SetNSigEvents(nsig);
00317 this->SetNBkgEvents(nbkg);
00318 this->SetNEvents(nEv);
00319 this->SetNSigEvents_unweighted(nsig_unweighted);
00320 this->SetNBkgEvents_unweighted(nbkg_unweighted);
00321 this->SetNEvents_unweighted(nEv_unweighted);
00322 this->SetSeparationIndex(separationIndex);
00323 this->SetSeparationGain(separationGain);
00324 this->SetPurity();
00325
00326 this->SetCC(cc);
00327 }
00328
00329 return kTRUE;
00330 }
00331
00332
00333 void TMVA::DecisionTreeNode::ClearNodeAndAllDaughters()
00334 {
00335
00336 SetNSigEvents(0);
00337 SetNBkgEvents(0);
00338 SetNEvents(0);
00339 SetNSigEvents_unweighted(0);
00340 SetNBkgEvents_unweighted(0);
00341 SetNEvents_unweighted(0);
00342 SetSeparationIndex(-1);
00343 SetSeparationGain(-1);
00344 SetPurity();
00345
00346 if (this->GetLeft() != NULL) ((DecisionTreeNode*)(this->GetLeft()))->ClearNodeAndAllDaughters();
00347 if (this->GetRight() != NULL) ((DecisionTreeNode*)(this->GetRight()))->ClearNodeAndAllDaughters();
00348 }
00349
00350
00351 void TMVA::DecisionTreeNode::ResetValidationData( ) {
00352
00353
00354 SetNBValidation( 0.0 );
00355 SetNSValidation( 0.0 );
00356 SetSumTarget( 0 );
00357 SetSumTarget2( 0 );
00358
00359 if(GetLeft() != NULL && GetRight() != NULL) {
00360 GetLeft()->ResetValidationData();
00361 GetRight()->ResetValidationData();
00362 }
00363 }
00364
00365
00366 void TMVA::DecisionTreeNode::PrintPrune( ostream& os ) const {
00367
00368
00369 os << "----------------------" << std::endl
00370 << "|~T_t| " << GetNTerminal() << std::endl
00371 << "R(t): " << GetNodeR() << std::endl
00372 << "R(T_t): " << GetSubTreeR() << std::endl
00373 << "g(t): " << GetAlpha() << std::endl
00374 << "G(t): " << GetAlphaMinSubtree() << std::endl;
00375 }
00376
00377
00378 void TMVA::DecisionTreeNode::PrintRecPrune( ostream& os ) const {
00379
00380
00381 this->PrintPrune(os);
00382 if(this->GetLeft() != NULL && this->GetRight() != NULL) {
00383 ((DecisionTreeNode*)this->GetLeft())->PrintRecPrune(os);
00384 ((DecisionTreeNode*)this->GetRight())->PrintRecPrune(os);
00385 }
00386 }
00387
00388
00389 void TMVA::DecisionTreeNode::SetCC(Double_t cc)
00390 {
00391 if (fTrainInfo) fTrainInfo->fCC = cc;
00392 else *fgLogger << kFATAL << "call to SetCC without trainingInfo" << Endl;
00393 }
00394
00395
00396 Float_t TMVA::DecisionTreeNode::GetSampleMin(UInt_t ivar) const {
00397
00398
00399 if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size()) return fTrainInfo->fSampleMin[ivar];
00400 else *fgLogger << kFATAL << "You asked for Min of the event sample in node for variable "
00401 << ivar << " that is out of range" << Endl;
00402 return -9999;
00403 }
00404
00405
00406 Float_t TMVA::DecisionTreeNode::GetSampleMax(UInt_t ivar) const {
00407
00408
00409 if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size()) return fTrainInfo->fSampleMax[ivar];
00410 else *fgLogger << kFATAL << "You asked for Max of the event sample in node for variable "
00411 << ivar << " that is out of range" << Endl;
00412 return 9999;
00413 }
00414
00415
00416 void TMVA::DecisionTreeNode::SetSampleMin(UInt_t ivar, Float_t xmin){
00417
00418
00419 if ( fTrainInfo) {
00420 if ( ivar >= fTrainInfo->fSampleMin.size()) fTrainInfo->fSampleMin.resize(ivar+1);
00421 fTrainInfo->fSampleMin[ivar]=xmin;
00422 }
00423 }
00424
00425
00426 void TMVA::DecisionTreeNode::SetSampleMax(UInt_t ivar, Float_t xmax){
00427
00428
00429 if( ! fTrainInfo ) return;
00430 if ( ivar >= fTrainInfo->fSampleMax.size() )
00431 fTrainInfo->fSampleMax.resize(ivar+1);
00432 fTrainInfo->fSampleMax[ivar]=xmax;
00433 }
00434
00435
00436 void TMVA::DecisionTreeNode::ReadAttributes(void* node, UInt_t )
00437 {
00438 Float_t tempNSigEvents,tempNBkgEvents;
00439
00440 Int_t nCoef;
00441 if (gTools().HasAttr(node, "NCoef")){
00442 gTools().ReadAttr(node, "NCoef", nCoef );
00443 this->SetNFisherCoeff(nCoef);
00444 Double_t tmp;
00445 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {
00446 gTools().ReadAttr(node, Form("fC%d",i), tmp );
00447 this->SetFisherCoeff(i,tmp);
00448 }
00449 }else{
00450 this->SetNFisherCoeff(0);
00451 }
00452 gTools().ReadAttr(node, "IVar", fSelector );
00453 gTools().ReadAttr(node, "Cut", fCutValue );
00454 gTools().ReadAttr(node, "cType", fCutType );
00455 if (gTools().HasAttr(node,"res")) gTools().ReadAttr(node, "res", fResponse);
00456 if (gTools().HasAttr(node,"rms")) gTools().ReadAttr(node, "rms", fRMS);
00457
00458 if( gTools().HasAttr(node, "purity") ) {
00459 gTools().ReadAttr(node, "purity",fPurity );
00460 } else {
00461 gTools().ReadAttr(node, "nS", tempNSigEvents );
00462 gTools().ReadAttr(node, "nB", tempNBkgEvents );
00463 fPurity = tempNSigEvents / (tempNSigEvents + tempNBkgEvents);
00464 }
00465
00466 gTools().ReadAttr(node, "nType", fNodeType );
00467 }
00468
00469
00470
00471 void TMVA::DecisionTreeNode::AddAttributesToNode(void* node) const
00472 {
00473
00474 gTools().AddAttr(node, "NCoef", GetNFisherCoeff());
00475 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++)
00476 gTools().AddAttr(node, Form("fC%d",i), this->GetFisherCoeff(i));
00477
00478 gTools().AddAttr(node, "IVar", GetSelector());
00479 gTools().AddAttr(node, "Cut", GetCutValue());
00480 gTools().AddAttr(node, "cType", GetCutType());
00481
00482
00483
00484 gTools().AddAttr(node, "res", GetResponse());
00485 gTools().AddAttr(node, "rms", GetRMS());
00486
00487 gTools().AddAttr(node, "purity",GetPurity());
00488
00489 gTools().AddAttr(node, "nType", GetNodeType());
00490 }
00491
00492
00493 void TMVA::DecisionTreeNode::SetFisherCoeff(Int_t ivar, Double_t coeff)
00494 {
00495
00496 if ((Int_t) fFisherCoeff.size()<ivar+1) fFisherCoeff.resize(ivar+1) ;
00497 fFisherCoeff[ivar]=coeff;
00498 }
00499
00500
00501 void TMVA::DecisionTreeNode::AddContentToNode( std::stringstream& ) const
00502 {
00503
00504
00505
00506 }
00507
00508
00509 void TMVA::DecisionTreeNode::ReadContent( std::stringstream& )
00510 {
00511
00512
00513
00514 }