00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
#include "TMVA/CCTreeWrapper.h"

#include <algorithm>
#include <iostream>
#include <limits>
#include <string>

using namespace TMVA;
00029
00030
00031 TMVA::CCTreeWrapper::CCTreeNode::CCTreeNode( DecisionTreeNode* n ) :
00032 Node(),
00033 fNLeafDaughters(0),
00034 fNodeResubstitutionEstimate(-1.0),
00035 fResubstitutionEstimate(-1.0),
00036 fAlphaC(-1.0),
00037 fMinAlphaC(-1.0),
00038 fDTNode(n)
00039 {
00040
00041 if ( n != NULL && n->GetRight() != NULL && n->GetLeft() != NULL ) {
00042 SetRight( new CCTreeNode( ((DecisionTreeNode*) n->GetRight()) ) );
00043 GetRight()->SetParent(this);
00044 SetLeft( new CCTreeNode( ((DecisionTreeNode*) n->GetLeft()) ) );
00045 GetLeft()->SetParent(this);
00046 }
00047 }
00048
00049
00050 TMVA::CCTreeWrapper::CCTreeNode::~CCTreeNode() {
00051
00052
00053 if(GetLeft() != NULL) delete GetLeftDaughter();
00054 if(GetRight() != NULL) delete GetRightDaughter();
00055 }
00056
00057
00058 Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t ) {
00059
00060
00061 std::string header, title;
00062 in >> header;
00063 in >> title; in >> fNLeafDaughters;
00064 in >> title; in >> fNodeResubstitutionEstimate;
00065 in >> title; in >> fResubstitutionEstimate;
00066 in >> title; in >> fAlphaC;
00067 in >> title; in >> fMinAlphaC;
00068 return true;
00069 }
00070
00071
00072 void TMVA::CCTreeWrapper::CCTreeNode::Print( ostream& os ) const {
00073
00074
00075 os << "----------------------" << std::endl
00076 << "|~T_t| " << fNLeafDaughters << std::endl
00077 << "R(t): " << fNodeResubstitutionEstimate << std::endl
00078 << "R(T_t): " << fResubstitutionEstimate << std::endl
00079 << "g(t): " << fAlphaC << std::endl
00080 << "G(t): " << fMinAlphaC << std::endl;
00081 }
00082
00083
00084 void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( ostream& os ) const {
00085
00086
00087 this->Print(os);
00088 if(this->GetLeft() != NULL && this->GetRight() != NULL) {
00089 this->GetLeft()->PrintRec(os);
00090 this->GetRight()->PrintRec(os);
00091 }
00092 }
00093
00094
00095 TMVA::CCTreeWrapper::CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex ) :
00096 fRoot(NULL)
00097 {
00098
00099
00100 fDTParent = T;
00101 fRoot = new CCTreeNode( dynamic_cast<DecisionTreeNode*>(T->GetRoot()) );
00102 fQualityIndex = qualityIndex;
00103 InitTree(fRoot);
00104 }
00105
00106
00107 TMVA::CCTreeWrapper::~CCTreeWrapper( ) {
00108
00109
00110 delete fRoot;
00111 }
00112
00113
00114 void TMVA::CCTreeWrapper::InitTree( CCTreeNode* t )
00115 {
00116
00117 Double_t s = t->GetDTNode()->GetNSigEvents();
00118 Double_t b = t->GetDTNode()->GetNBkgEvents();
00119
00120
00121
00122 t->SetNodeResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
00123
00124 if(t->GetLeft() != NULL && t->GetRight() != NULL) {
00125
00126 InitTree(t->GetLeftDaughter());
00127 InitTree(t->GetRightDaughter());
00128
00129 t->SetNLeafDaughters(t->GetLeftDaughter()->GetNLeafDaughters() +
00130 t->GetRightDaughter()->GetNLeafDaughters());
00131
00132 t->SetResubstitutionEstimate(t->GetLeftDaughter()->GetResubstitutionEstimate() +
00133 t->GetRightDaughter()->GetResubstitutionEstimate());
00134
00135 t->SetAlphaC((t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate()) /
00136 (t->GetNLeafDaughters() - 1));
00137
00138 t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
00139 t->GetRightDaughter()->GetMinAlphaC())));
00140 }
00141 else {
00142 t->SetNLeafDaughters(1);
00143 t->SetResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
00144 t->SetAlphaC(std::numeric_limits<double>::infinity( ));
00145 t->SetMinAlphaC(std::numeric_limits<double>::infinity( ));
00146 }
00147 }
00148
00149
00150 void TMVA::CCTreeWrapper::PruneNode( CCTreeNode* t )
00151 {
00152
00153
00154 if( t->GetLeft() != NULL &&
00155 t->GetRight() != NULL ) {
00156 CCTreeNode* l = t->GetLeftDaughter();
00157 CCTreeNode* r = t->GetRightDaughter();
00158 t->SetNLeafDaughters( 1 );
00159 t->SetResubstitutionEstimate( t->GetNodeResubstitutionEstimate() );
00160 t->SetAlphaC( std::numeric_limits<double>::infinity( ) );
00161 t->SetMinAlphaC( std::numeric_limits<double>::infinity( ) );
00162 delete l;
00163 delete r;
00164 t->SetLeft(NULL);
00165 t->SetRight(NULL);
00166 }else{
00167 std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
00168 }
00169 }
00170
00171
00172 Double_t TMVA::CCTreeWrapper::TestTreeQuality( const EventList* validationSample )
00173 {
00174
00175
00176
00177 Double_t ncorrect=0, nfalse=0;
00178 for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
00179 Bool_t isSignalType = (CheckEvent(*(*validationSample)[ievt]) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
00180
00181 if (isSignalType == ((*validationSample)[ievt]->GetClass() == 0)) {
00182 ncorrect += (*validationSample)[ievt]->GetWeight();
00183 }
00184 else{
00185 nfalse += (*validationSample)[ievt]->GetWeight();
00186 }
00187 }
00188 return ncorrect / (ncorrect + nfalse);
00189 }
00190
00191
00192 Double_t TMVA::CCTreeWrapper::TestTreeQuality( const DataSet* validationSample )
00193 {
00194
00195
00196
00197 validationSample->SetCurrentType(Types::kValidation);
00198
00199 Double_t ncorrect=0, nfalse=0;
00200 for (Long64_t ievt=0; ievt<validationSample->GetNEvents(); ievt++){
00201 Event *ev = validationSample->GetEvent(ievt);
00202
00203 Bool_t isSignalType = (CheckEvent(*ev) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
00204
00205 if (isSignalType == (ev->GetClass() == 0)) {
00206 ncorrect += ev->GetWeight();
00207 }
00208 else{
00209 nfalse += ev->GetWeight();
00210 }
00211 }
00212 return ncorrect / (ncorrect + nfalse);
00213 }
00214
00215
00216 Double_t TMVA::CCTreeWrapper::CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf )
00217 {
00218
00219
00220 const DecisionTreeNode* current = fRoot->GetDTNode();
00221 CCTreeNode* t = fRoot;
00222
00223 while(
00224 t->GetLeft() != NULL &&
00225 t->GetRight() != NULL){
00226 if (current->GoesRight(e)) {
00227
00228 t = t->GetRightDaughter();
00229 current = t->GetDTNode();
00230 }
00231 else {
00232
00233 t = t->GetLeftDaughter();
00234 current = t->GetDTNode();
00235 }
00236 }
00237
00238 if (useYesNoLeaf) return (current->GetPurity() > fDTParent->GetNodePurityLimit() ? 1.0 : -1.0);
00239 else return current->GetPurity();
00240 }
00241
00242
00243 void TMVA::CCTreeWrapper::CCTreeNode::AddAttributesToNode( void* ) const
00244 {}
00245
00246
00247 void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& ) const
00248 {}
00249
00250
00251 void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* , UInt_t )
00252 {}
00253
00254
00255 void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& )
00256 {}