00001 // @(#)root/tmva $Id: Rule.h 29195 2009-06-24 10:39:49Z brun $ 00002 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss 00003 00004 /********************************************************************************** 00005 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 00006 * Package: TMVA * 00007 * Class : Rule * 00008 * * 00009 * Description: * 00010 * A class describung a 'rule' * 00011 * Each internal node of a tree defines a rule from all the parental nodes. * 00012 * A rule consists of atleast 2 nodes. * 00013 * Input: a decision tree (in the constructor) * 00014 * its coefficient * 00015 * * 00016 * * 00017 * Authors (alphabetical): * 00018 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA * 00019 * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. * 00020 * * 00021 * Copyright (c) 2005: * 00022 * CERN, Switzerland * 00023 * Iowa State U. * 00024 * MPI-K Heidelberg, Germany * 00025 * * 00026 * Redistribution and use in source and binary forms, with or without * 00027 * modification, are permitted according to the terms listed in LICENSE * 00028 * (http://tmva.sourceforge.net/LICENSE) * 00029 **********************************************************************************/ 00030 00031 #ifndef ROOT_TMVA_Rule 00032 #define ROOT_TMVA_Rule 00033 00034 #ifndef ROOT_TMath 00035 #include "TMath.h" 00036 #endif 00037 00038 #ifndef ROOT_TMVA_DecisionTree 00039 #include "TMVA/DecisionTree.h" 00040 #endif 00041 #ifndef ROOT_TMVA_Event 00042 #include "TMVA/Event.h" 00043 #endif 00044 #ifndef ROOT_TMVA_RuleCut 00045 #include "TMVA/RuleCut.h" 00046 #endif 00047 00048 namespace TMVA { 00049 00050 class RuleEnsemble; 00051 class MsgLogger; 00052 class Rule; 00053 00054 ostream& operator<<( ostream& os, const Rule & rule ); 00055 00056 class Rule { 00057 00058 // ouput operator for a Rule 00059 friend ostream& operator<< ( ostream& os, const Rule & rule ); 00060 00061 public: 00062 00063 // main constructor 00064 Rule( RuleEnsemble *re, const std::vector< const TMVA::Node * > & nodes ); 00065 00066 // main constructor 00067 Rule( RuleEnsemble *re ); 00068 00069 // copy constructor 00070 Rule( const Rule & other ) { Copy( other ); } 00071 00072 // empty constructor 00073 Rule(); 00074 00075 virtual ~Rule(); 00076 00077 // set message type 00078 void SetMsgType( EMsgType t ); 00079 00080 // set RuleEnsemble ptr 00081 void SetRuleEnsemble( const RuleEnsemble *re ) { fRuleEnsemble = re; } 00082 00083 // set RuleCut ptr 00084 void SetRuleCut( RuleCut *rc ) { fCut = rc; } 00085 00086 // set Rule norm 00087 void SetNorm(Double_t norm) { fNorm = (norm>0 ? 1.0/norm:1.0); } 00088 00089 // set coefficient 00090 void SetCoefficient(Double_t v) { fCoefficient=v; } 00091 00092 // set support 00093 void SetSupport(Double_t v) { fSupport=v; fSigma = TMath::Sqrt(v*(1.0-v));} 00094 00095 // set s/(s+b) 00096 void SetSSB(Double_t v) { fSSB=v; } 00097 00098 // set N(eve) accepted by rule 00099 void SetSSBNeve(Double_t v) { fSSBNeve=v; } 00100 00101 // set reference importance 00102 void SetImportanceRef(Double_t v) { fImportanceRef=(v>0 ? v:1.0); } 00103 00104 // calculate importance 00105 void CalcImportance() { fImportance = TMath::Abs(fCoefficient)*fSigma; } 00106 00107 // get the relative importance 00108 Double_t GetRelImportance() const { return fImportance/fImportanceRef; } 00109 00110 // evaluate the Rule for the given Event using the coefficient 00111 // inline Double_t EvalEvent( const Event& e, Bool_t norm ) const; 00112 00113 // evaluate the Rule for the given Event, not using normalization or the coefficent 00114 inline Bool_t EvalEvent( const Event& e ) const; 00115 00116 // test if two rules are equal 00117 Bool_t Equal( const Rule & other, Bool_t useCutValue, Double_t maxdist ) const; 00118 00119 // get distance between two equal (ie apart from the cut values) rules 00120 Double_t RuleDist( const Rule & other, Bool_t useCutValue ) const; 00121 00122 // returns true if the trained S/(S+B) of the last node is > 0.5 00123 Double_t GetSSB() const { return fSSB; } 00124 Double_t GetSSBNeve() const { return fSSBNeve; } 00125 Bool_t IsSignalRule() const { return (fSSB>0.5); } 00126 00127 // copy operator 00128 void operator=( const Rule & other ) { Copy( other ); } 00129 00130 // identical operator 00131 Bool_t operator==( const Rule & other ) const; 00132 00133 Bool_t operator<( const Rule & other ) const; 00134 00135 // get number of variables used in Rule 00136 UInt_t GetNumVarsUsed() const { return fCut->GetNvars(); } 00137 00138 // get number of cuts in Rule 00139 UInt_t GetNcuts() const { return fCut->GetNcuts(); } 00140 00141 // check if variable is used by the rule 00142 Bool_t ContainsVariable(UInt_t iv) const; 00143 00144 // accessors 00145 const RuleCut* GetRuleCut() const { return fCut; } 00146 const RuleEnsemble* GetRuleEnsemble() const { return fRuleEnsemble; } 00147 Double_t GetCoefficient() const { return fCoefficient; } 00148 Double_t GetSupport() const { return fSupport; } 00149 Double_t GetSigma() const { return fSigma; } 00150 Double_t GetNorm() const { return fNorm; } 00151 Double_t GetImportance() const { return fImportance; } 00152 Double_t GetImportanceRef() const { return fImportanceRef; } 00153 00154 // print the rule using flogger 00155 void PrintLogger( const char *title=0 ) const; 00156 00157 // print just the raw info, used for weight file generation 00158 void PrintRaw ( ostream& os ) const; // obsolete 00159 void* AddXMLTo ( void* parent ) const; 00160 00161 void ReadRaw ( istream& os ); // obsolete 00162 void ReadFromXML( void* wghtnode ); 00163 00164 private: 00165 00166 // set sigma - don't use this as non private! 00167 void SetSigma(Double_t v) { fSigma=v; } 00168 00169 // print info about the Rule 00170 void Print( ostream& os ) const; 00171 00172 // copy from another rule 00173 void Copy( const Rule & other ); 00174 00175 // get the name of variable with index i 00176 const TString & GetVarName( Int_t i) const; 00177 00178 RuleCut* fCut; // all cuts associated with the rule 00179 Double_t fNorm; // normalization - usually 1.0/t(k) 00180 Double_t fSupport; // s(k) 00181 Double_t fSigma; // t(k) = sqrt(s*(1-s)) 00182 Double_t fCoefficient; // rule coeff. a(k) 00183 Double_t fImportance; // importance of rule 00184 Double_t fImportanceRef; // importance ref 00185 const RuleEnsemble* fRuleEnsemble; // pointer to parent RuleEnsemble 00186 Double_t fSSB; // S/(S+B) for rule 00187 Double_t fSSBNeve; // N(events) reaching the last node in reevaluation 00188 00189 mutable MsgLogger* fLogger; //! message logger 00190 MsgLogger& Log() const { return *fLogger; } 00191 00192 }; 00193 00194 } // end of TMVA namespace 00195 00196 //_______________________________________________________________________ 00197 inline Bool_t TMVA::Rule::EvalEvent( const TMVA::Event& e ) const 00198 { 00199 // Checks if event is accepted by rule. 00200 // Return true if yes and false if not. 00201 // 00202 return fCut->EvalEvent(e); 00203 } 00204 00205 #endif