00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047 #include "TMVA/Event.h"
00048 #include "TMVA/RuleCut.h"
00049 #include "TMVA/Rule.h"
00050 #include "TMVA/RuleFit.h"
00051 #include "TMVA/RuleEnsemble.h"
00052 #include "TMVA/MethodRuleFit.h"
00053 #include "TMVA/Tools.h"
00054
00055
00056 TMVA::Rule::Rule( RuleEnsemble *re,
00057 const std::vector< const Node * >& nodes )
00058 : fCut ( 0 )
00059 , fNorm ( 1.0 )
00060 , fSupport ( 0.0 )
00061 , fSigma ( 0.0 )
00062 , fCoefficient ( 0.0 )
00063 , fImportance ( 0.0 )
00064 , fImportanceRef ( 1.0 )
00065 , fRuleEnsemble ( re )
00066 , fSSB ( 0 )
00067 , fSSBNeve ( 0 )
00068 , fLogger( new MsgLogger("RuleFit") )
00069 {
00070
00071
00072
00073
00074
00075
00076
00077
00078 fCut = new RuleCut( nodes );
00079 fSSB = fCut->GetPurity();
00080 fSSBNeve = fCut->GetCutNeve();
00081 }
00082
00083
00084 TMVA::Rule::Rule( RuleEnsemble *re )
00085 : fCut ( 0 )
00086 , fNorm ( 1.0 )
00087 , fSupport ( 0.0 )
00088 , fSigma ( 0.0 )
00089 , fCoefficient ( 0.0 )
00090 , fImportance ( 0.0 )
00091 , fImportanceRef ( 1.0 )
00092 , fRuleEnsemble ( re )
00093 , fSSB ( 0 )
00094 , fSSBNeve ( 0 )
00095 , fLogger( new MsgLogger("RuleFit") )
00096 {
00097
00098 }
00099
00100
00101 TMVA::Rule::Rule()
00102 : fCut ( 0 )
00103 , fNorm ( 1.0 )
00104 , fSupport ( 0.0 )
00105 , fSigma ( 0.0 )
00106 , fCoefficient ( 0.0 )
00107 , fImportance ( 0.0 )
00108 , fImportanceRef ( 1.0 )
00109 , fRuleEnsemble ( 0 )
00110 , fSSB ( 0 )
00111 , fSSBNeve ( 0 )
00112 , fLogger( new MsgLogger("RuleFit") )
00113 {
00114
00115 }
00116
00117
00118 TMVA::Rule::~Rule()
00119 {
00120
00121 delete fCut;
00122 delete fLogger;
00123 }
00124
00125
00126 Bool_t TMVA::Rule::ContainsVariable(UInt_t iv) const
00127 {
00128
00129 Bool_t found = kFALSE;
00130 Bool_t doneLoop = kFALSE;
00131 UInt_t nvars = fCut->GetNvars();
00132 UInt_t i = 0;
00133
00134 while (!doneLoop) {
00135 found = (fCut->GetSelector(i) == iv);
00136 i++;
00137 doneLoop = (found || (i==nvars));
00138 }
00139 return found;
00140 }
00141
00142
00143 void TMVA::Rule::SetMsgType( EMsgType t )
00144 {
00145 fLogger->SetMinType(t);
00146 }
00147
00148
00149
00150 Bool_t TMVA::Rule::Equal( const Rule& other, Bool_t useCutValue, Double_t mindist ) const
00151 {
00152
00153
00154
00155
00156
00157
00158
00159
00160 Bool_t rval=kFALSE;
00161 if (mindist<0) useCutValue=kFALSE;
00162 Double_t d = RuleDist( other, useCutValue );
00163
00164 if (useCutValue) rval = ( (!(d<0)) && (d<mindist) );
00165 else rval = (!(d<0));
00166
00167 return rval;
00168 }
00169
00170
00171 Double_t TMVA::Rule::RuleDist( const Rule& other, Bool_t useCutValue ) const
00172 {
00173
00174
00175
00176
00177
00178 if (fCut->GetNvars()!=other.GetRuleCut()->GetNvars()) return -1.0;
00179
00180 const UInt_t nvars = fCut->GetNvars();
00181
00182 Int_t sel;
00183 Double_t rms;
00184 Double_t smin;
00185 Double_t smax;
00186 Double_t vminA,vmaxA;
00187 Double_t vminB,vmaxB;
00188
00189
00190
00191
00192
00193
00194 UInt_t in = 0;
00195 Double_t sumdc2 = 0;
00196 Bool_t equal = true;
00197
00198 const RuleCut *otherCut = other.GetRuleCut();
00199 while ((equal) && (in<nvars)) {
00200
00201 equal = ( (fCut->GetSelector(in) == (otherCut->GetSelector(in))) &&
00202 (fCut->GetCutDoMin(in) == (otherCut->GetCutDoMin(in))) &&
00203 (fCut->GetCutDoMax(in) == (otherCut->GetCutDoMax(in))) );
00204
00205 if (equal) {
00206 if (useCutValue) {
00207 sel = fCut->GetSelector(in);
00208 vminA = fCut->GetCutMin(in);
00209 vmaxA = fCut->GetCutMax(in);
00210 vminB = other.GetRuleCut()->GetCutMin(in);
00211 vmaxB = other.GetRuleCut()->GetCutMax(in);
00212
00213 rms = fRuleEnsemble->GetRuleFit()->GetMethodBase()->GetRMS(sel);
00214 smin=0;
00215 smax=0;
00216 if (fCut->GetCutDoMin(in))
00217 smin = ( rms>0 ? (vminA-vminB)/rms : 0 );
00218 if (fCut->GetCutDoMax(in))
00219 smax = ( rms>0 ? (vmaxA-vmaxB)/rms : 0 );
00220 sumdc2 += smin*smin + smax*smax;
00221
00222 }
00223 }
00224 in++;
00225 }
00226 if (!useCutValue) sumdc2 = (equal ? 0.0:-1.0);
00227 else sumdc2 = (equal ? sqrt(sumdc2) : -1.0);
00228
00229 return sumdc2;
00230 }
00231
00232
00233 Bool_t TMVA::Rule::operator==( const Rule& other ) const
00234 {
00235
00236
00237 return this->Equal( other, kTRUE, 1e-3 );
00238 }
00239
00240
00241 Bool_t TMVA::Rule::operator<( const Rule& other ) const
00242 {
00243
00244 return (fImportance < other.GetImportance());
00245 }
00246
00247
00248 ostream& TMVA::operator<< ( ostream& os, const Rule& rule )
00249 {
00250
00251 rule.Print( os );
00252 return os;
00253 }
00254
00255
00256 const TString & TMVA::Rule::GetVarName( Int_t i ) const
00257 {
00258
00259
00260 return fRuleEnsemble->GetMethodBase()->GetInputLabel(i);
00261 }
00262
00263
00264 void TMVA::Rule::Copy( const Rule& other )
00265 {
00266
00267 if(this != &other) {
00268 SetRuleEnsemble( other.GetRuleEnsemble() );
00269 fCut = new RuleCut( *(other.GetRuleCut()) );
00270 fSSB = other.GetSSB();
00271 fSSBNeve = other.GetSSBNeve();
00272 SetCoefficient(other.GetCoefficient());
00273 SetSupport( other.GetSupport() );
00274 SetSigma( other.GetSigma() );
00275 SetNorm( other.GetNorm() );
00276 CalcImportance();
00277 SetImportanceRef( other.GetImportanceRef() );
00278 }
00279 }
00280
00281
00282 void TMVA::Rule::Print( ostream& os ) const
00283 {
00284
00285 const UInt_t nvars = fCut->GetNvars();
00286 if (nvars<1) os << " *** WARNING - <EMPTY RULE> ***" << std::endl;
00287
00288 Int_t sel;
00289 Double_t valmin, valmax;
00290
00291 os << " Importance = " << Form("%1.4f", fImportance/fImportanceRef) << std::endl;
00292 os << " Coefficient = " << Form("%1.4f", fCoefficient) << std::endl;
00293 os << " Support = " << Form("%1.4f", fSupport) << std::endl;
00294 os << " S/(S+B) = " << Form("%1.4f", fSSB) << std::endl;
00295
00296 for ( UInt_t i=0; i<nvars; i++) {
00297 os << " ";
00298 sel = fCut->GetSelector(i);
00299 valmin = fCut->GetCutMin(i);
00300 valmax = fCut->GetCutMax(i);
00301
00302 os << Form("* Cut %2d",i+1) << " : " << std::flush;
00303 if (fCut->GetCutDoMin(i)) os << Form("%10.3g",valmin) << " < " << std::flush;
00304 else os << " " << std::flush;
00305 os << GetVarName(sel) << std::flush;
00306 if (fCut->GetCutDoMax(i)) os << " < " << Form("%10.3g",valmax) << std::flush;
00307 else os << " " << std::flush;
00308 os << std::endl;
00309 }
00310 }
00311
00312
00313 void TMVA::Rule::PrintLogger(const char *title) const
00314 {
00315
00316 const UInt_t nvars = fCut->GetNvars();
00317 if (nvars<1) Log() << kWARNING << "BUG TRAP: EMPTY RULE!!!" << Endl;
00318
00319 Int_t sel;
00320 Double_t valmin, valmax;
00321
00322 if (title) Log() << kINFO << title;
00323 Log() << kINFO
00324 << "Importance = " << Form("%1.4f", fImportance/fImportanceRef) << Endl;
00325
00326 for ( UInt_t i=0; i<nvars; i++) {
00327
00328 Log() << kINFO << " ";
00329 sel = fCut->GetSelector(i);
00330 valmin = fCut->GetCutMin(i);
00331 valmax = fCut->GetCutMax(i);
00332
00333 Log() << kINFO << Form("Cut %2d",i+1) << " : ";
00334 if (fCut->GetCutDoMin(i)) Log() << kINFO << Form("%10.3g",valmin) << " < ";
00335 else Log() << kINFO << " ";
00336 Log() << kINFO << GetVarName(sel);
00337 if (fCut->GetCutDoMax(i)) Log() << kINFO << " < " << Form("%10.3g",valmax);
00338 else Log() << kINFO << " ";
00339 Log() << Endl;
00340 }
00341 }
00342
00343
00344 void TMVA::Rule::PrintRaw( ostream& os ) const
00345 {
00346
00347 Int_t dp = os.precision();
00348 const UInt_t nvars = fCut->GetNvars();
00349 os << "Parameters: "
00350 << std::setprecision(10)
00351 << fImportance << " "
00352 << fImportanceRef << " "
00353 << fCoefficient << " "
00354 << fSupport << " "
00355 << fSigma << " "
00356 << fNorm << " "
00357 << fSSB << " "
00358 << fSSBNeve << " "
00359 << std::endl; \
00360 os << "N(cuts): " << nvars << std::endl;
00361 for ( UInt_t i=0; i<nvars; i++) {
00362 os << "Cut " << i << " : " << std::flush;
00363 os << fCut->GetSelector(i)
00364 << std::setprecision(10)
00365 << " " << fCut->GetCutMin(i)
00366 << " " << fCut->GetCutMax(i)
00367 << " " << (fCut->GetCutDoMin(i) ? "T":"F")
00368 << " " << (fCut->GetCutDoMax(i) ? "T":"F")
00369 << std::endl;
00370 }
00371 os << std::setprecision(dp);
00372 }
00373
00374
00375 void* TMVA::Rule::AddXMLTo( void* parent ) const
00376 {
00377 void* rule = gTools().AddChild( parent, "Rule" );
00378 const UInt_t nvars = fCut->GetNvars();
00379
00380 gTools().AddAttr( rule, "Importance", fImportance );
00381 gTools().AddAttr( rule, "Ref", fImportanceRef );
00382 gTools().AddAttr( rule, "Coeff", fCoefficient );
00383 gTools().AddAttr( rule, "Support", fSupport );
00384 gTools().AddAttr( rule, "Sigma", fSigma );
00385 gTools().AddAttr( rule, "Norm", fNorm );
00386 gTools().AddAttr( rule, "SSB", fSSB );
00387 gTools().AddAttr( rule, "SSBNeve", fSSBNeve );
00388 gTools().AddAttr( rule, "Nvars", nvars );
00389
00390 for (UInt_t i=0; i<nvars; i++) {
00391 void* cut = gTools().AddChild( rule, "Cut" );
00392 gTools().AddAttr( cut, "Selector", fCut->GetSelector(i) );
00393 gTools().AddAttr( cut, "Min", fCut->GetCutMin(i) );
00394 gTools().AddAttr( cut, "Max", fCut->GetCutMax(i) );
00395 gTools().AddAttr( cut, "DoMin", (fCut->GetCutDoMin(i) ? "T":"F") );
00396 gTools().AddAttr( cut, "DoMax", (fCut->GetCutDoMax(i) ? "T":"F") );
00397 }
00398
00399 return rule;
00400 }
00401
00402
00403 void TMVA::Rule::ReadFromXML( void* wghtnode )
00404 {
00405
00406 TString nodeName = TString( gTools().GetName(wghtnode) );
00407 if (nodeName != "Rule") Log() << kFATAL << "<ReadFromXML> Unexpected node name: " << nodeName << Endl;
00408
00409 gTools().ReadAttr( wghtnode, "Importance", fImportance );
00410 gTools().ReadAttr( wghtnode, "Ref", fImportanceRef );
00411 gTools().ReadAttr( wghtnode, "Coeff", fCoefficient );
00412 gTools().ReadAttr( wghtnode, "Support", fSupport );
00413 gTools().ReadAttr( wghtnode, "Sigma", fSigma );
00414 gTools().ReadAttr( wghtnode, "Norm", fNorm );
00415 gTools().ReadAttr( wghtnode, "SSB", fSSB );
00416 gTools().ReadAttr( wghtnode, "SSBNeve", fSSBNeve );
00417
00418 UInt_t nvars;
00419 gTools().ReadAttr( wghtnode, "Nvars", nvars );
00420 if (fCut) delete fCut;
00421 fCut = new RuleCut();
00422 fCut->SetNvars( nvars );
00423
00424
00425 void* ch = gTools().GetChild( wghtnode );
00426 UInt_t i = 0;
00427 UInt_t ui;
00428 Double_t d;
00429 Char_t c;
00430 while (ch) {
00431 gTools().ReadAttr( ch, "Selector", ui );
00432 fCut->SetSelector( i, ui );
00433 gTools().ReadAttr( ch, "Min", d );
00434 fCut->SetCutMin ( i, d );
00435 gTools().ReadAttr( ch, "Max", d );
00436 fCut->SetCutMax ( i, d );
00437 gTools().ReadAttr( ch, "DoMin", c );
00438 fCut->SetCutDoMin( i, (c == 'T' ? kTRUE : kFALSE ) );
00439 gTools().ReadAttr( ch, "DoMax", c );
00440 fCut->SetCutDoMax( i, (c == 'T' ? kTRUE : kFALSE ) );
00441
00442 i++;
00443 ch = gTools().GetNextChild(ch);
00444 }
00445
00446
00447 if (i != nvars) Log() << kFATAL << "<ReadFromXML> Mismatch in number of cuts: " << i << " != " << nvars << Endl;
00448 }
00449
00450
00451 void TMVA::Rule::ReadRaw( istream& istr )
00452 {
00453
00454
00455 TString dummy;
00456 UInt_t nvars;
00457 istr >> dummy
00458 >> fImportance
00459 >> fImportanceRef
00460 >> fCoefficient
00461 >> fSupport
00462 >> fSigma
00463 >> fNorm
00464 >> fSSB
00465 >> fSSBNeve;
00466
00467 istr >> dummy >> nvars;
00468 Double_t cutmin,cutmax;
00469 UInt_t sel,idum;
00470 Char_t bA, bB;
00471
00472 if (fCut) delete fCut;
00473 fCut = new RuleCut();
00474 fCut->SetNvars( nvars );
00475 for ( UInt_t i=0; i<nvars; i++) {
00476 istr >> dummy >> idum;
00477 istr >> dummy;
00478 istr >> sel >> cutmin >> cutmax >> bA >> bB;
00479 fCut->SetSelector(i,sel);
00480 fCut->SetCutMin(i,cutmin);
00481 fCut->SetCutMax(i,cutmax);
00482 fCut->SetCutDoMin(i,(bA=='T' ? kTRUE:kFALSE));
00483 fCut->SetCutDoMax(i,(bB=='T' ? kTRUE:kFALSE));
00484 }
00485 }