Rule.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: Rule.cxx 36134 2010-10-06 18:29:59Z stelzer $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : Rule                                                                  *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      A class describing a 'rule'                                               *
00012  *      Each internal node of a tree defines a rule from all the parental nodes.  *
00013  *      A rule consists of at least 2 nodes.                                      *
00014  *      Input: a decision tree (in the constructor)                               *
00015  *                                                                                *
00016  * Authors (alphabetical):                                                        *
00017  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
00018  *      Helge Voss         <Helge.Voss@cern.ch>         - MPI-KP Heidelberg, Ger. *
00019  *                                                                                *
00020  * Copyright (c) 2005:                                                            *
00021  *      CERN, Switzerland                                                         *
00022  *      Iowa State U.                                                             *
00023  *      MPI-K Heidelberg, Germany                                                 *
00024  *                                                                                *
00025  * Redistribution and use in source and binary forms, with or without             *
00026  * modification, are permitted according to the terms listed in LICENSE           *
00027  * (http://tmva.sourceforge.net/LICENSE)                                          *
00028  **********************************************************************************/
00029 
00030 //________________________________________________________________________________
00031 //
00032 // Implementation of a rule
00033 //
00034 // A rule is simply a branch or a part of a branch in a tree.
00035 // It fulfills the following:
00036 // * First node is the root node of the originating tree
00037 // * Consists of a minimum of 2 nodes
00038 // * A rule returns for a given event:
00039 //    0 : if the event fails at any node
00040 //    1 : otherwise
00041 // * If the rule contains <2 nodes, it returns 0 SHOULD NOT HAPPEN!
00042 //
00043 // The coefficient is found by either brute force or some sort of
00044 // intelligent fitting. See the RuleEnsemble class for more info.
00045 //________________________________________________________________________________
00046 
00047 #include "TMVA/Event.h"
00048 #include "TMVA/RuleCut.h"
00049 #include "TMVA/Rule.h"
00050 #include "TMVA/RuleFit.h"
00051 #include "TMVA/RuleEnsemble.h"
00052 #include "TMVA/MethodRuleFit.h"
00053 #include "TMVA/Tools.h"
00054 
00055 //_______________________________________________________________________
00056 TMVA::Rule::Rule( RuleEnsemble *re,
00057                   const std::vector< const Node * >& nodes )
00058    : fCut           ( 0 )
00059    , fNorm          ( 1.0 )
00060    , fSupport       ( 0.0 )
00061    , fSigma         ( 0.0 )
00062    , fCoefficient   ( 0.0 )
00063    , fImportance    ( 0.0 )
00064    , fImportanceRef ( 1.0 )
00065    , fRuleEnsemble  ( re )
00066    , fSSB           ( 0 )
00067    , fSSBNeve       ( 0 )
00068    , fLogger( new MsgLogger("RuleFit") )
00069 {
00070    // the main constructor for a Rule
00071 
00072    //
00073    // input:
00074    //   nodes  - a vector of Node; from these all possible rules will be created
00075    //
00076    //
00077 
00078    fCut     = new RuleCut( nodes );
00079    fSSB     = fCut->GetPurity();
00080    fSSBNeve = fCut->GetCutNeve();
00081 }
00082 
00083 //_______________________________________________________________________
00084 TMVA::Rule::Rule( RuleEnsemble *re )
00085    : fCut           ( 0 )
00086    , fNorm          ( 1.0 )
00087    , fSupport       ( 0.0 )
00088    , fSigma         ( 0.0 )
00089    , fCoefficient   ( 0.0 )
00090    , fImportance    ( 0.0 )
00091    , fImportanceRef ( 1.0 )
00092    , fRuleEnsemble  ( re )
00093    , fSSB           ( 0 )
00094    , fSSBNeve       ( 0 )
00095    , fLogger( new MsgLogger("RuleFit") )
00096 {
00097    // the simple constructor
00098 }
00099 
00100 //_______________________________________________________________________
00101 TMVA::Rule::Rule()
00102    : fCut           ( 0 )
00103    , fNorm          ( 1.0 )
00104    , fSupport       ( 0.0 )
00105    , fSigma         ( 0.0 )
00106    , fCoefficient   ( 0.0 )
00107    , fImportance    ( 0.0 )
00108    , fImportanceRef ( 1.0 )
00109    , fRuleEnsemble  ( 0 )
00110    , fSSB           ( 0 )
00111    , fSSBNeve       ( 0 )
00112    , fLogger( new MsgLogger("RuleFit") )
00113 {
00114    // the simple constructor
00115 }
00116 
00117 //_______________________________________________________________________
00118 TMVA::Rule::~Rule() 
00119 {
00120    // destructor
00121    delete fCut;
00122    delete fLogger;
00123 }
00124 
00125 //_______________________________________________________________________
00126 Bool_t TMVA::Rule::ContainsVariable(UInt_t iv) const
00127 {
00128    // check if variable in node
00129    Bool_t found    = kFALSE;
00130    Bool_t doneLoop = kFALSE;
00131    UInt_t nvars    = fCut->GetNvars();
00132    UInt_t i        = 0;
00133    //
00134    while (!doneLoop) {
00135       found = (fCut->GetSelector(i) == iv);
00136       i++;
00137       doneLoop = (found || (i==nvars));
00138    }
00139    return found;
00140 }
00141 
00142 //_______________________________________________________________________
00143 void TMVA::Rule::SetMsgType( EMsgType t ) 
00144 {
00145    fLogger->SetMinType(t);
00146 }
00147 
00148 
00149 //_______________________________________________________________________
00150 Bool_t TMVA::Rule::Equal( const Rule& other, Bool_t useCutValue, Double_t mindist ) const
00151 {
00152    //
00153    // Compare two rules.
00154    // useCutValue: true -> calculate a distance between the two rules based on the cut values
00155    //                      if the rule cuts are not equal, the distance is < 0 (-1.0)
00156    //                      return true if d<mindist
00157    //              false-> ignore mindist, return true if rules are equal, ignoring cut values
00158    // mindist:     min distance allowed between rules; if < 0 => set useCutValue=false;
00159    //
00160    Bool_t rval=kFALSE;
00161    if (mindist<0) useCutValue=kFALSE;
00162    Double_t d = RuleDist( other, useCutValue );
00163    // cut value used - return true if 0<=d<mindist
00164    if (useCutValue) rval = ( (!(d<0)) && (d<mindist) );
00165    else rval = (!(d<0));
00166    // cut value not used, return true if <> -1
00167    return rval;
00168 }
00169 
00170 //_______________________________________________________________________
00171 Double_t TMVA::Rule::RuleDist( const Rule& other, Bool_t useCutValue ) const
00172 {
00173    // Returns:
00174    // -1.0 : rules are NOT equal, i.e, variables and/or cut directions are wrong
00175    //   >=0: rules are equal apart from the cutvalue, returns d = sqrt(sum(c1-c2)^2)
00176    // If not useCutValue, the distance is exactly zero if they are equal
00177    //
00178    if (fCut->GetNvars()!=other.GetRuleCut()->GetNvars()) return -1.0; // check number of cuts
00179    //
00180    const UInt_t nvars  = fCut->GetNvars();
00181    //
00182    Int_t    sel;         // cut variable
00183    Double_t rms;         // rms of cut variable
00184    Double_t smin;        // distance between the lower range
00185    Double_t smax;        // distance between the upper range
00186    Double_t vminA,vmaxA; // min,max range of cut A (cut from this Rule)
00187    Double_t vminB,vmaxB; // idem from other Rule
00188    //
00189    // compare nodes
00190    // A 'distance' is assigned if the two rules has exactly the same set of cuts but with
00191    // different cut values.
00192    // The distance is given in number of sigmas
00193    //
00194    UInt_t   in     = 0;    // cut index
00195    Double_t sumdc2 = 0;    // sum of 'distances'
00196    Bool_t   equal  = true; // flag if cut are equal
00197    //
00198    const RuleCut *otherCut = other.GetRuleCut();
00199    while ((equal) && (in<nvars)) {
00200       // check equality in cut topology
00201       equal = ( (fCut->GetSelector(in) == (otherCut->GetSelector(in))) &&
00202                 (fCut->GetCutDoMin(in) == (otherCut->GetCutDoMin(in))) &&
00203                 (fCut->GetCutDoMax(in) == (otherCut->GetCutDoMax(in))) );
00204       // if equal topology, check cut values
00205       if (equal) {
00206          if (useCutValue) {
00207             sel   = fCut->GetSelector(in);
00208             vminA = fCut->GetCutMin(in);
00209             vmaxA = fCut->GetCutMax(in);
00210             vminB = other.GetRuleCut()->GetCutMin(in);
00211             vmaxB = other.GetRuleCut()->GetCutMax(in);
00212             // messy - but ok...
00213             rms = fRuleEnsemble->GetRuleFit()->GetMethodBase()->GetRMS(sel);
00214             smin=0;
00215             smax=0;
00216             if (fCut->GetCutDoMin(in))
00217                smin = ( rms>0 ? (vminA-vminB)/rms : 0 );
00218             if (fCut->GetCutDoMax(in))
00219                smax = ( rms>0 ? (vmaxA-vmaxB)/rms : 0 );
00220             sumdc2 += smin*smin + smax*smax;
00221             //            sumw   += 1.0/(rms*rms); // TODO: probably not needed
00222          }
00223       }
00224       in++;
00225    }
00226    if (!useCutValue) sumdc2 = (equal ? 0.0:-1.0); // ignore cut values
00227    else              sumdc2 = (equal ? sqrt(sumdc2) : -1.0);
00228 
00229    return sumdc2;
00230 }
00231 
00232 //_______________________________________________________________________
00233 Bool_t TMVA::Rule::operator==( const Rule& other ) const
00234 {
00235    // comparison operator ==
00236 
00237    return this->Equal( other, kTRUE, 1e-3 );
00238 }
00239 
00240 //_______________________________________________________________________
00241 Bool_t TMVA::Rule::operator<( const Rule& other ) const
00242 {
00243    // comparison operator <
00244    return (fImportance < other.GetImportance());
00245 }
00246 
00247 //_______________________________________________________________________
00248 ostream& TMVA::operator<< ( ostream& os, const Rule& rule )
00249 {
00250    // ostream operator
00251    rule.Print( os );
00252    return os;
00253 }
00254 
00255 //_______________________________________________________________________
00256 const TString & TMVA::Rule::GetVarName( Int_t i ) const
00257 {
00258    // returns the name of a rule
00259 
00260    return fRuleEnsemble->GetMethodBase()->GetInputLabel(i);
00261 }
00262 
00263 //_______________________________________________________________________
00264 void TMVA::Rule::Copy( const Rule& other )
00265 {
00266    // copy function
00267    if(this != &other) {
00268       SetRuleEnsemble( other.GetRuleEnsemble() );
00269       fCut = new RuleCut( *(other.GetRuleCut()) );
00270       fSSB     = other.GetSSB();
00271       fSSBNeve = other.GetSSBNeve();
00272       SetCoefficient(other.GetCoefficient());
00273       SetSupport( other.GetSupport() );
00274       SetSigma( other.GetSigma() );
00275       SetNorm( other.GetNorm() );
00276       CalcImportance();
00277       SetImportanceRef( other.GetImportanceRef() );
00278    }
00279 }
00280 
00281 //_______________________________________________________________________
00282 void TMVA::Rule::Print( ostream& os ) const
00283 {
00284    // print function
00285    const UInt_t nvars = fCut->GetNvars();
00286    if (nvars<1) os << "     *** WARNING - <EMPTY RULE> ***" << std::endl; // TODO: Fix this, use fLogger
00287    //
00288    Int_t sel;
00289    Double_t valmin, valmax;
00290    //
00291    os << "    Importance  = " << Form("%1.4f", fImportance/fImportanceRef) << std::endl;
00292    os << "    Coefficient = " << Form("%1.4f", fCoefficient) << std::endl;
00293    os << "    Support     = " << Form("%1.4f", fSupport)  << std::endl;
00294    os << "    S/(S+B)     = " << Form("%1.4f", fSSB)  << std::endl;  
00295 
00296    for ( UInt_t i=0; i<nvars; i++) {
00297       os << "    ";
00298       sel    = fCut->GetSelector(i);
00299       valmin = fCut->GetCutMin(i);
00300       valmax = fCut->GetCutMax(i);
00301       //
00302       os << Form("* Cut %2d",i+1) << " : " << std::flush;
00303       if (fCut->GetCutDoMin(i)) os << Form("%10.3g",valmin) << " < " << std::flush;
00304       else                      os << "             " << std::flush;
00305       os << GetVarName(sel) << std::flush;
00306       if (fCut->GetCutDoMax(i)) os << " < " << Form("%10.3g",valmax) << std::flush;
00307       else                      os << "             " << std::flush;
00308       os << std::endl;
00309    }
00310 }
00311 
00312 //_______________________________________________________________________
00313 void TMVA::Rule::PrintLogger(const char *title) const
00314 {
00315    // print function
00316    const UInt_t nvars = fCut->GetNvars();
00317    if (nvars<1) Log() << kWARNING << "BUG TRAP: EMPTY RULE!!!" << Endl;
00318    //
00319    Int_t sel;
00320    Double_t valmin, valmax;
00321    //
00322    if (title) Log() << kINFO << title;
00323    Log() << kINFO
00324            << "Importance  = " << Form("%1.4f", fImportance/fImportanceRef) << Endl;
00325 
00326    for ( UInt_t i=0; i<nvars; i++) {
00327       
00328       Log() << kINFO << "            ";
00329       sel    = fCut->GetSelector(i);
00330       valmin = fCut->GetCutMin(i);
00331       valmax = fCut->GetCutMax(i);
00332       //
00333       Log() << kINFO << Form("Cut %2d",i+1) << " : ";
00334       if (fCut->GetCutDoMin(i)) Log() << kINFO << Form("%10.3g",valmin) << " < ";
00335       else                      Log() << kINFO << "             ";
00336       Log() << kINFO << GetVarName(sel);
00337       if (fCut->GetCutDoMax(i)) Log() << kINFO << " < " << Form("%10.3g",valmax);
00338       else                      Log() << kINFO << "             ";
00339       Log() << Endl;
00340    }
00341 }
00342 
00343 //_______________________________________________________________________
00344 void TMVA::Rule::PrintRaw( ostream& os ) const
00345 {
00346    // extensive print function used to print info for the weight file
00347    Int_t dp = os.precision();
00348    const UInt_t nvars = fCut->GetNvars();
00349    os << "Parameters: "
00350       << std::setprecision(10)
00351       << fImportance << " "
00352       << fImportanceRef << " "
00353       << fCoefficient << " "
00354       << fSupport << " "
00355       << fSigma << " "
00356       << fNorm << " "
00357       << fSSB << " "
00358       << fSSBNeve << " "
00359       << std::endl;                                         \
00360    os << "N(cuts): " << nvars << std::endl; // mark end of nodes
00361    for ( UInt_t i=0; i<nvars; i++) {
00362       os << "Cut " << i << " : " << std::flush;
00363       os <<        fCut->GetSelector(i)
00364          << std::setprecision(10)
00365          << " " << fCut->GetCutMin(i)
00366          << " " << fCut->GetCutMax(i)
00367          << " " << (fCut->GetCutDoMin(i) ? "T":"F")
00368          << " " << (fCut->GetCutDoMax(i) ? "T":"F")
00369          << std::endl;
00370    }
00371    os << std::setprecision(dp);
00372 }
00373 
00374 //_______________________________________________________________________
00375 void* TMVA::Rule::AddXMLTo( void* parent ) const 
00376 {
00377    void* rule = gTools().AddChild( parent, "Rule" );
00378    const UInt_t nvars = fCut->GetNvars();
00379 
00380    gTools().AddAttr( rule, "Importance", fImportance    );
00381    gTools().AddAttr( rule, "Ref",        fImportanceRef );
00382    gTools().AddAttr( rule, "Coeff",      fCoefficient   );
00383    gTools().AddAttr( rule, "Support",    fSupport       );
00384    gTools().AddAttr( rule, "Sigma",      fSigma         );
00385    gTools().AddAttr( rule, "Norm",       fNorm          );
00386    gTools().AddAttr( rule, "SSB",        fSSB           );
00387    gTools().AddAttr( rule, "SSBNeve",    fSSBNeve       );
00388    gTools().AddAttr( rule, "Nvars",      nvars          );
00389 
00390    for (UInt_t i=0; i<nvars; i++) {
00391       void* cut = gTools().AddChild( rule, "Cut" );
00392       gTools().AddAttr( cut, "Selector", fCut->GetSelector(i) );
00393       gTools().AddAttr( cut, "Min",      fCut->GetCutMin(i) );
00394       gTools().AddAttr( cut, "Max",      fCut->GetCutMax(i) );
00395       gTools().AddAttr( cut, "DoMin",    (fCut->GetCutDoMin(i) ? "T":"F") );
00396       gTools().AddAttr( cut, "DoMax",    (fCut->GetCutDoMax(i) ? "T":"F") );
00397    }
00398 
00399    return rule;
00400 }
00401 
00402 //_______________________________________________________________________
00403 void TMVA::Rule::ReadFromXML( void* wghtnode )
00404 {
00405    // read rule from XML
00406    TString nodeName = TString( gTools().GetName(wghtnode) );
00407    if (nodeName != "Rule") Log() << kFATAL << "<ReadFromXML> Unexpected node name: " << nodeName << Endl;
00408 
00409    gTools().ReadAttr( wghtnode, "Importance", fImportance    );
00410    gTools().ReadAttr( wghtnode, "Ref",        fImportanceRef );
00411    gTools().ReadAttr( wghtnode, "Coeff",      fCoefficient   );
00412    gTools().ReadAttr( wghtnode, "Support",    fSupport       );
00413    gTools().ReadAttr( wghtnode, "Sigma",      fSigma         );
00414    gTools().ReadAttr( wghtnode, "Norm",       fNorm          );
00415    gTools().ReadAttr( wghtnode, "SSB",        fSSB           );
00416    gTools().ReadAttr( wghtnode, "SSBNeve",    fSSBNeve       );
00417 
00418    UInt_t nvars;
00419    gTools().ReadAttr( wghtnode, "Nvars",      nvars          );
00420    if (fCut) delete fCut;
00421    fCut = new RuleCut();
00422    fCut->SetNvars( nvars );
00423 
00424    // read Cut
00425    void*    ch = gTools().GetChild( wghtnode );
00426    UInt_t   i = 0;
00427    UInt_t   ui;
00428    Double_t d;
00429    Char_t   c;
00430    while (ch) {
00431       gTools().ReadAttr( ch, "Selector", ui );
00432       fCut->SetSelector( i, ui );
00433       gTools().ReadAttr( ch, "Min",      d );
00434       fCut->SetCutMin  ( i, d );
00435       gTools().ReadAttr( ch, "Max",      d );
00436       fCut->SetCutMax  ( i, d );
00437       gTools().ReadAttr( ch, "DoMin",    c );
00438       fCut->SetCutDoMin( i, (c == 'T' ? kTRUE : kFALSE ) );
00439       gTools().ReadAttr( ch, "DoMax",    c );
00440       fCut->SetCutDoMax( i, (c == 'T' ? kTRUE : kFALSE ) );
00441 
00442       i++;
00443       ch = gTools().GetNextChild(ch);
00444    }
00445 
00446    // sanity check
00447    if (i != nvars) Log() << kFATAL << "<ReadFromXML> Mismatch in number of cuts: " << i << " != " << nvars << Endl;
00448 }
00449 
00450 //_______________________________________________________________________
00451 void TMVA::Rule::ReadRaw( istream& istr )
00452 {
00453    // read function (format is the same as written by PrintRaw)
00454 
00455    TString dummy;
00456    UInt_t nvars;
00457    istr >> dummy
00458         >> fImportance
00459         >> fImportanceRef
00460         >> fCoefficient
00461         >> fSupport
00462         >> fSigma
00463         >> fNorm
00464         >> fSSB
00465         >> fSSBNeve;
00466    // coverity[tainted_data_argument]
00467    istr >> dummy >> nvars;
00468    Double_t cutmin,cutmax;
00469    UInt_t   sel,idum;
00470    Char_t   bA, bB;
00471    //
00472    if (fCut) delete fCut;
00473    fCut = new RuleCut();
00474    fCut->SetNvars( nvars );
00475    for ( UInt_t i=0; i<nvars; i++) {
00476       istr >> dummy >> idum; // get 'Node' and index
00477       istr >> dummy;         // get ':'
00478       istr >> sel >> cutmin >> cutmax >> bA >> bB;
00479       fCut->SetSelector(i,sel);
00480       fCut->SetCutMin(i,cutmin);
00481       fCut->SetCutMax(i,cutmax);
00482       fCut->SetCutDoMin(i,(bA=='T' ? kTRUE:kFALSE));
00483       fCut->SetCutDoMax(i,(bB=='T' ? kTRUE:kFALSE));
00484    }
00485 }

Generated on Tue Jul 5 15:25:35 2011 for ROOT_528-00b_version by  doxygen 1.5.1