CCTreeWrapper.cxx

Go to the documentation of this file.
00001 /**********************************************************************************
00002  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00003  * Package: TMVA                                                                  *
00004  * Class  : CCTreeWrapper                                                         *
00005  * Web    : http://tmva.sourceforge.net                                           *
00006  *                                                                                *
00007  * Description: a light wrapper of a decision tree, used to perform cost          *
00008  *              complexity pruning "in-place"                                     *
00009  *                                                                                *  
00010  * Author: Doug Schouten (dschoute@sfu.ca)                                        *
00011  *                                                                                *
00012  *                                                                                *
00013  * Copyright (c) 2007:                                                            *
00014  *      CERN, Switzerland                                                         *
00015  *      MPI-K Heidelberg, Germany                                                 *
00016  *      U. of Texas at Austin, USA                                                *
00017  *                                                                                *
00018  * Redistribution and use in source and binary forms, with or without             *
00019  * modification, are permitted according to the terms listed in LICENSE           *
00020  * (http://tmva.sourceforge.net/LICENSE)                                          *
00021  **********************************************************************************/
00022 
00023 #include "TMVA/CCTreeWrapper.h"
00024 
00025 #include <iostream>
00026 #include <limits>
00027 
00028 using namespace TMVA;
00029 
00030 //_______________________________________________________________________
00031 TMVA::CCTreeWrapper::CCTreeNode::CCTreeNode( DecisionTreeNode* n ) :
00032    Node(),
00033    fNLeafDaughters(0),
00034    fNodeResubstitutionEstimate(-1.0),
00035    fResubstitutionEstimate(-1.0),
00036    fAlphaC(-1.0),
00037    fMinAlphaC(-1.0),
00038    fDTNode(n)
00039 {
00040    //constructor of the CCTreeNode
00041    if ( n != NULL && n->GetRight() != NULL && n->GetLeft() != NULL ) {
00042       SetRight( new CCTreeNode( ((DecisionTreeNode*) n->GetRight()) ) );
00043       GetRight()->SetParent(this);
00044       SetLeft( new CCTreeNode( ((DecisionTreeNode*) n->GetLeft()) ) );
00045       GetLeft()->SetParent(this);
00046    }
00047 }
00048 
00049 //_______________________________________________________________________
00050 TMVA::CCTreeWrapper::CCTreeNode::~CCTreeNode() {
00051    // destructor of a CCTreeNode
00052 
00053    if(GetLeft() != NULL) delete GetLeftDaughter();
00054    if(GetRight() != NULL) delete GetRightDaughter();
00055 }
00056 
00057 //_______________________________________________________________________
00058 Bool_t TMVA::CCTreeWrapper::CCTreeNode::ReadDataRecord( std::istream& in, UInt_t /* tmva_Version_Code */ ) {
00059    // initialize a node from a data record
00060    
00061    std::string header, title;
00062    in >> header;
00063    in >> title; in >> fNLeafDaughters;
00064    in >> title; in >> fNodeResubstitutionEstimate;
00065    in >> title; in >> fResubstitutionEstimate;
00066    in >> title; in >> fAlphaC;
00067    in >> title; in >> fMinAlphaC;
00068    return true;
00069 }
00070 
00071 //_______________________________________________________________________
00072 void TMVA::CCTreeWrapper::CCTreeNode::Print( ostream& os ) const {
00073    // printout of the node (can be read in with ReadDataRecord)
00074 
00075    os << "----------------------" << std::endl 
00076       << "|~T_t| " << fNLeafDaughters << std::endl 
00077       << "R(t): " << fNodeResubstitutionEstimate << std::endl 
00078       << "R(T_t): " << fResubstitutionEstimate << std::endl
00079       << "g(t): " << fAlphaC << std::endl
00080       << "G(t): " << fMinAlphaC << std::endl;
00081 }
00082 
00083 //_______________________________________________________________________
00084 void TMVA::CCTreeWrapper::CCTreeNode::PrintRec( ostream& os ) const {
00085    // recursive printout of the node and its daughters 
00086 
00087    this->Print(os);
00088    if(this->GetLeft() != NULL && this->GetRight() != NULL) {
00089       this->GetLeft()->PrintRec(os);
00090       this->GetRight()->PrintRec(os);
00091    }
00092 }
00093 
00094 //_______________________________________________________________________
00095 TMVA::CCTreeWrapper::CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex ) :
00096    fRoot(NULL)
00097 {
00098    // constructor
00099 
00100    fDTParent = T;
00101    fRoot = new CCTreeNode( dynamic_cast<DecisionTreeNode*>(T->GetRoot()) );
00102    fQualityIndex = qualityIndex;
00103    InitTree(fRoot);
00104 }
00105   
00106 //_______________________________________________________________________
00107 TMVA::CCTreeWrapper::~CCTreeWrapper( ) {
00108    // destructor
00109 
00110    delete fRoot; 
00111 }  
00112 
00113 //_______________________________________________________________________
00114 void TMVA::CCTreeWrapper::InitTree( CCTreeNode* t )
00115 {
00116     // initialize the node t and all its descendants
00117    Double_t s = t->GetDTNode()->GetNSigEvents();
00118    Double_t b = t->GetDTNode()->GetNBkgEvents();
00119    //   Double_t s = t->GetDTNode()->GetNSigEvents_unweighted();
00120    //   Double_t b = t->GetDTNode()->GetNBkgEvents_unweighted();
00121    // set R(t) = Gini(t) or MisclassificationError(t), etc.
00122    t->SetNodeResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
00123 
00124    if(t->GetLeft() != NULL && t->GetRight() != NULL) { // n is an interior (non-leaf) node
00125       // traverse the tree 
00126       InitTree(t->GetLeftDaughter());
00127       InitTree(t->GetRightDaughter());
00128       // set |~T_t|
00129       t->SetNLeafDaughters(t->GetLeftDaughter()->GetNLeafDaughters() + 
00130                            t->GetRightDaughter()->GetNLeafDaughters());    
00131       // set R(T) = sum[t' in ~T]{ R(t) }
00132       t->SetResubstitutionEstimate(t->GetLeftDaughter()->GetResubstitutionEstimate() +
00133                                    t->GetRightDaughter()->GetResubstitutionEstimate());
00134       // set g(t)
00135       t->SetAlphaC((t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate()) / 
00136                    (t->GetNLeafDaughters() - 1));
00137       // G(t) = min( g(t), G(l(t)), G(r(t)) )
00138       t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(), 
00139                                                         t->GetRightDaughter()->GetMinAlphaC())));
00140    }
00141    else { // n is a terminal node
00142       t->SetNLeafDaughters(1);
00143       t->SetResubstitutionEstimate((s+b)*fQualityIndex->GetSeparationIndex(s,b));
00144       t->SetAlphaC(std::numeric_limits<double>::infinity( ));
00145       t->SetMinAlphaC(std::numeric_limits<double>::infinity( ));
00146    }
00147 }
00148 
00149 //_______________________________________________________________________
00150 void TMVA::CCTreeWrapper::PruneNode( CCTreeNode* t )
00151 {
00152    // remove the branch rooted at node t
00153 
00154    if( t->GetLeft() != NULL &&
00155        t->GetRight() != NULL ) {
00156       CCTreeNode* l = t->GetLeftDaughter();
00157       CCTreeNode* r = t->GetRightDaughter();
00158       t->SetNLeafDaughters( 1 );
00159       t->SetResubstitutionEstimate( t->GetNodeResubstitutionEstimate() );
00160       t->SetAlphaC( std::numeric_limits<double>::infinity( ) );
00161       t->SetMinAlphaC( std::numeric_limits<double>::infinity( ) );
00162       delete l;
00163       delete r;
00164       t->SetLeft(NULL);
00165       t->SetRight(NULL);
00166    }else{
00167       std::cout << " ERROR in CCTreeWrapper::PruneNode: you try to prune a leaf node.. that does not make sense " << std::endl;
00168    }
00169 }
00170 
00171 //_______________________________________________________________________
00172 Double_t TMVA::CCTreeWrapper::TestTreeQuality( const EventList* validationSample )
00173 {
00174    // return the misclassification rate of a pruned tree for a validation event sample
00175    // using an EventList
00176 
00177    Double_t ncorrect=0, nfalse=0;
00178    for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
00179       Bool_t isSignalType = (CheckEvent(*(*validationSample)[ievt]) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
00180       
00181       if (isSignalType == ((*validationSample)[ievt]->GetClass() == 0)) {
00182          ncorrect += (*validationSample)[ievt]->GetWeight();
00183       }
00184       else{
00185          nfalse += (*validationSample)[ievt]->GetWeight();
00186       }
00187    }
00188    return  ncorrect / (ncorrect + nfalse);
00189 }
00190 
00191 //_______________________________________________________________________
00192 Double_t TMVA::CCTreeWrapper::TestTreeQuality( const DataSet* validationSample )
00193 {
00194    // return the misclassification rate of a pruned tree for a validation event sample
00195    // using the DataSet
00196 
00197    validationSample->SetCurrentType(Types::kValidation);
00198    // test the tree quality.. in terms of Miscalssification
00199    Double_t ncorrect=0, nfalse=0;
00200    for (Long64_t ievt=0; ievt<validationSample->GetNEvents(); ievt++){
00201       Event *ev = validationSample->GetEvent(ievt);
00202 
00203       Bool_t isSignalType = (CheckEvent(*ev) > fDTParent->GetNodePurityLimit() ) ? 1 : 0;
00204       
00205       if (isSignalType == (ev->GetClass() == 0)) {
00206          ncorrect += ev->GetWeight();
00207       }
00208       else{
00209          nfalse += ev->GetWeight();
00210       }
00211    }
00212    return  ncorrect / (ncorrect + nfalse);
00213 }
00214 
00215 //_______________________________________________________________________
00216 Double_t TMVA::CCTreeWrapper::CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf )
00217 {
00218    // return the decision tree output for an event 
00219 
00220    const DecisionTreeNode* current = fRoot->GetDTNode();
00221    CCTreeNode* t = fRoot;
00222 
00223    while(//current->GetNodeType() == 0 &&
00224          t->GetLeft() != NULL &&
00225          t->GetRight() != NULL){ // at an interior (non-leaf) node
00226       if (current->GoesRight(e)) {
00227          //current = (DecisionTreeNode*)current->GetRight();
00228          t = t->GetRightDaughter();
00229          current = t->GetDTNode();
00230       }
00231       else {
00232          //current = (DecisionTreeNode*)current->GetLeft();
00233          t = t->GetLeftDaughter();
00234          current = t->GetDTNode();
00235       }
00236    }
00237   
00238    if (useYesNoLeaf) return (current->GetPurity() > fDTParent->GetNodePurityLimit() ? 1.0 : -1.0);
00239    else return current->GetPurity();
00240 }
00241 
00242 //_______________________________________________________________________
00243 void TMVA::CCTreeWrapper::CCTreeNode::AddAttributesToNode( void* /*node*/ ) const
00244 {}
00245 
00246 //_______________________________________________________________________
00247 void TMVA::CCTreeWrapper::CCTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
00248 {}
00249 
00250 //_______________________________________________________________________
00251 void TMVA::CCTreeWrapper::CCTreeNode::ReadAttributes( void* /*node*/, UInt_t /* tmva_Version_Code */  )
00252 {}
00253 
00254 //_______________________________________________________________________
00255 void TMVA::CCTreeWrapper::CCTreeNode::ReadContent( std::stringstream& /*s*/ )
00256 {}

Generated on Tue Jul 5 15:16:41 2011 for ROOT_528-00b_version by  doxygen 1.5.1