00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039 #include <algorithm>
00040 #include <iomanip>
00041 #include <vector>
00042
00043 #include "Riostream.h"
00044 #include "TRandom3.h"
00045 #include "TMath.h"
00046 #include "TObjString.h"
00047
00048 #include "TMVA/MethodCompositeBase.h"
00049 #include "TMVA/MethodBoost.h"
00050 #include "TMVA/MethodBase.h"
00051 #include "TMVA/Tools.h"
00052 #include "TMVA/Types.h"
00053 #include "TMVA/Factory.h"
00054 #include "TMVA/ClassifierFactory.h"
00055
00056 using std::vector;
00057
// ROOT class-implementation macro for TMVA::MethodCompositeBase
// (presumably paired with a ClassDef in the class header — verify there).
ClassImp(TMVA::MethodCompositeBase)
00059
00060
00061 TMVA::MethodCompositeBase::MethodCompositeBase( const TString& jobName,
00062 Types::EMVA methodType,
00063 const TString& methodTitle,
00064 DataSetInfo& theData,
00065 const TString& theOption,
00066 TDirectory* theTargetDir )
00067 : TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption, theTargetDir ),
00068 fMethodIndex(0)
00069 {}
00070
00071
00072 TMVA::MethodCompositeBase::MethodCompositeBase( Types::EMVA methodType,
00073 DataSetInfo& dsi,
00074 const TString& weightFile,
00075 TDirectory* theTargetDir )
00076 : TMVA::MethodBase( methodType, dsi, weightFile, theTargetDir ),
00077 fMethodIndex(0)
00078 {}
00079
00080
00081 TMVA::IMethod* TMVA::MethodCompositeBase::GetMethod( const TString &methodTitle ) const
00082 {
00083
00084 vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
00085 vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
00086
00087 for (; itrMethod != itrMethodEnd; itrMethod++) {
00088 MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
00089 if ( (mva->GetMethodName())==methodTitle ) return mva;
00090 }
00091 return 0;
00092 }
00093
00094
00095 TMVA::IMethod* TMVA::MethodCompositeBase::GetMethod( const Int_t index ) const
00096 {
00097
00098 vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
00099 if (itrMethod<fMethods.end()) return *itrMethod;
00100 else return 0;
00101 }
00102
00103
00104
00105 void TMVA::MethodCompositeBase::AddWeightsXMLTo( void* parent ) const
00106 {
00107 void* wght = gTools().AddChild(parent, "Weights");
00108 gTools().AddAttr( wght, "NMethods", fMethods.size() );
00109 for (UInt_t i=0; i< fMethods.size(); i++)
00110 {
00111 void* methxml = gTools().AddChild( wght, "Method" );
00112 MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
00113 gTools().AddAttr(methxml,"Index", i );
00114 gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
00115 gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
00116 gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
00117 gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
00118 gTools().AddAttr(methxml,"JobName", method->GetJobName());
00119 gTools().AddAttr(methxml,"Options", method->GetOptions());
00120 method->AddWeightsXMLTo(methxml);
00121 }
00122 }
00123
00124
00125 TMVA::MethodCompositeBase::~MethodCompositeBase( void )
00126 {
00127
00128 vector<IMethod*>::iterator itrMethod = fMethods.begin();
00129 for (; itrMethod != fMethods.end(); itrMethod++) {
00130 Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
00131 delete (*itrMethod);
00132 }
00133 fMethods.clear();
00134 }
00135
00136
00137 void TMVA::MethodCompositeBase::ReadWeightsFromXML( void* wghtnode )
00138 {
00139
00140 UInt_t nMethods;
00141 TString methodName, methodTypeName, jobName, optionString;
00142
00143 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
00144 fMethods.clear();
00145 fMethodWeight.clear();
00146 gTools().ReadAttr( wghtnode, "NMethods", nMethods );
00147 void* ch = gTools().GetChild(wghtnode);
00148 for (UInt_t i=0; i< nMethods; i++) {
00149 Double_t methodWeight, methodSigCut;
00150 gTools().ReadAttr( ch, "Weight", methodWeight );
00151 gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
00152 gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
00153 gTools().ReadAttr( ch, "MethodName", methodName );
00154 gTools().ReadAttr( ch, "JobName", jobName );
00155 gTools().ReadAttr( ch, "Options", optionString );
00156
00157 if (i==0){
00158
00159 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
00160 }
00161 fMethods.push_back(ClassifierFactory::Instance().Create(
00162 std::string(methodTypeName),jobName, methodName,DataInfo(),optionString));
00163
00164 fMethodWeight.push_back(methodWeight);
00165 MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
00166
00167 if(meth==0)
00168 Log() << kFATAL << "Could not read method from XML" << Endl;
00169
00170 void* methXML = gTools().GetChild(ch);
00171 meth->SetupMethod();
00172 meth->ReadWeightsFromXML(methXML);
00173 meth->SetMsgType(kWARNING);
00174 meth->ParseOptions();
00175 meth->ProcessSetup();
00176 meth->CheckSetup();
00177 meth->SetSignalReferenceCut(methodSigCut);
00178
00179 ch = gTools().GetNextChild(ch);
00180 }
00181
00182 }
00183
00184
00185 void TMVA::MethodCompositeBase::ReadWeightsFromStream( istream& istr )
00186 {
00187
00188 TString var, dummy;
00189 TString methodName, methodTitle=GetMethodName(),
00190 jobName=GetJobName(),optionString=GetOptions();
00191 UInt_t methodNum; Double_t methodWeight;
00192
00193
00194 istr >> dummy >> methodNum;
00195 Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
00196 for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
00197 fMethods.clear();
00198 fMethodWeight.clear();
00199 for (UInt_t i=0; i<methodNum; i++) {
00200 istr >> dummy >> methodName >> dummy >> fMethodIndex >> dummy >> methodWeight;
00201 if ((UInt_t)fMethodIndex != i) {
00202 Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
00203 << fMethodIndex << " i=" << i
00204 << " MethodName " << methodName
00205 << " dummy " << dummy
00206 << " MethodWeight= " << methodWeight
00207 << Endl;
00208 }
00209 if (GetMethodType() != Types::kBoost || i==0) {
00210 istr >> dummy >> jobName;
00211 istr >> dummy >> methodTitle;
00212 istr >> dummy >> optionString;
00213 if (GetMethodType() == Types::kBoost)
00214 ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
00215 }
00216 else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fMethodIndex);
00217 fMethods.push_back(ClassifierFactory::Instance().Create( std::string(methodName), jobName,
00218 methodTitle,DataInfo(), optionString) );
00219 fMethodWeight.push_back( methodWeight );
00220 if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
00221 m->ReadWeightsFromStream(istr);
00222 }
00223 }
00224
00225
00226 Double_t TMVA::MethodCompositeBase::GetMvaValue( Double_t* err, Double_t* errUpper )
00227 {
00228
00229 Double_t mvaValue = 0;
00230 for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
00231
00232
00233 NoErrorCalc(err, errUpper);
00234
00235 return mvaValue;
00236 }