MethodCompositeBase.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: MethodCompositeBase.cxx 36966 2010-11-26 09:50:13Z evt $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : MethodCompositeBase                                                   *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Virtual base class for composite MVA methods                              *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
00016  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00017  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00018  *      Or Cohen        <orcohenor@gmail.com>    - Weizmann Inst., Israel         *
00019  *                                                                                *
00020  * Copyright (c) 2005:                                                            *
00021  *      CERN, Switzerland                                                         *
00022  *      U. of Victoria, Canada                                                    *
00023  *      MPI-K Heidelberg, Germany                                                 *
00024  *      LAPP, Annecy, France                                                      *
00025  *                                                                                *
00026  * Redistribution and use in source and binary forms, with or without             *
00027  * modification, are permitted according to the terms listed in LICENSE           *
00028  * (http://tmva.sourceforge.net/LICENSE)                                          *
00029  **********************************************************************************/
00030 
00031 //_______________________________________________________________________
00032 //
00033 // This is a virtual base class for methods that combine more than one  //
00034 // classifier. Training of the individual classifiers is done by the    //
00035 // derived classes, while saving and loading the weight files and the   //
00036 // evaluation are done here.                                            //
00037 //_______________________________________________________________________
00038 
00039 #include <algorithm>
00040 #include <iomanip>
00041 #include <vector>
00042 
00043 #include "Riostream.h"
00044 #include "TRandom3.h"
00045 #include "TMath.h"
00046 #include "TObjString.h"
00047 
00048 #include "TMVA/MethodCompositeBase.h"
00049 #include "TMVA/MethodBoost.h"
00050 #include "TMVA/MethodBase.h"
00051 #include "TMVA/Tools.h"
00052 #include "TMVA/Types.h"
00053 #include "TMVA/Factory.h"
00054 #include "TMVA/ClassifierFactory.h"
00055 
00056 using std::vector;
00057 
// ROOT class-implementation macro for TMVA::MethodCompositeBase
// (presumably paired with a ClassDef in the header, which is not visible here -- verify)
00058 ClassImp(TMVA::MethodCompositeBase)
00059 
00060 //_______________________________________________________________________
00061 TMVA::MethodCompositeBase::MethodCompositeBase( const TString& jobName, 
00062                                                 Types::EMVA methodType,
00063                                                 const TString& methodTitle,
00064                                                 DataSetInfo& theData,
00065                                                 const TString& theOption,
00066                                                 TDirectory* theTargetDir )
00067    : TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption, theTargetDir ),
00068      fMethodIndex(0)
00069 {}
00070 
00071 //_______________________________________________________________________
00072 TMVA::MethodCompositeBase::MethodCompositeBase( Types::EMVA methodType,
00073                                                 DataSetInfo& dsi,
00074                                                 const TString& weightFile, 
00075                                                 TDirectory* theTargetDir )
00076    : TMVA::MethodBase( methodType, dsi, weightFile, theTargetDir ),
00077      fMethodIndex(0)
00078 {}
00079 
00080 //_______________________________________________________________________
00081 TMVA::IMethod* TMVA::MethodCompositeBase::GetMethod( const TString &methodTitle ) const
00082 {
00083    // returns pointer to MVA that corresponds to given method title
00084    vector<IMethod*>::const_iterator itrMethod    = fMethods.begin();
00085    vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
00086 
00087    for (; itrMethod != itrMethodEnd; itrMethod++) {
00088       MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);    
00089       if ( (mva->GetMethodName())==methodTitle ) return mva;
00090    }
00091    return 0;
00092 }
00093 
00094 //_______________________________________________________________________
00095 TMVA::IMethod* TMVA::MethodCompositeBase::GetMethod( const Int_t index ) const
00096 {
00097    // returns pointer to MVA that corresponds to given method index
00098    vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
00099    if (itrMethod<fMethods.end()) return *itrMethod;
00100    else                          return 0;
00101 }
00102 
00103 
00104 //_______________________________________________________________________
00105 void TMVA::MethodCompositeBase::AddWeightsXMLTo( void* parent ) const 
00106 {
00107    void* wght = gTools().AddChild(parent, "Weights");
00108    gTools().AddAttr( wght, "NMethods",   fMethods.size()   );
00109    for (UInt_t i=0; i< fMethods.size(); i++) 
00110    {
00111       void* methxml = gTools().AddChild( wght, "Method" );
00112       MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
00113       gTools().AddAttr(methxml,"Index",          i ); 
00114       gTools().AddAttr(methxml,"Weight",         fMethodWeight[i]); 
00115       gTools().AddAttr(methxml,"MethodSigCut",   method->GetSignalReferenceCut());
00116       gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
00117       gTools().AddAttr(methxml,"MethodName",     method->GetMethodName()   ); 
00118       gTools().AddAttr(methxml,"JobName",        method->GetJobName());
00119       gTools().AddAttr(methxml,"Options",        method->GetOptions()); 
00120       method->AddWeightsXMLTo(methxml);
00121    }
00122 }
00123 
00124 //_______________________________________________________________________
00125 TMVA::MethodCompositeBase::~MethodCompositeBase( void )
00126 {
00127    // delete methods
00128    vector<IMethod*>::iterator itrMethod = fMethods.begin();
00129    for (; itrMethod != fMethods.end(); itrMethod++) {
00130       Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;    
00131       delete (*itrMethod);
00132    }
00133    fMethods.clear();
00134 }
00135 
00136 //_______________________________________________________________________
00137 void TMVA::MethodCompositeBase::ReadWeightsFromXML( void* wghtnode ) 
00138 {
00139    // XML streamer
00140    UInt_t nMethods;
00141    TString methodName, methodTypeName, jobName, optionString;
00142 
00143    for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
00144    fMethods.clear();
00145    fMethodWeight.clear();
00146    gTools().ReadAttr( wghtnode, "NMethods",  nMethods );
00147    void* ch = gTools().GetChild(wghtnode);
00148    for (UInt_t i=0; i< nMethods; i++) {
00149       Double_t methodWeight, methodSigCut;
00150       gTools().ReadAttr( ch, "Weight",   methodWeight   );
00151       gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
00152       gTools().ReadAttr( ch, "MethodTypeName",  methodTypeName );
00153       gTools().ReadAttr( ch, "MethodName",  methodName );
00154       gTools().ReadAttr( ch, "JobName",  jobName );
00155       gTools().ReadAttr( ch, "Options",  optionString );
00156 
00157       if (i==0){
00158          // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
00159          ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName,  optionString );
00160       }
00161       fMethods.push_back(ClassifierFactory::Instance().Create(
00162          std::string(methodTypeName),jobName, methodName,DataInfo(),optionString));
00163 
00164       fMethodWeight.push_back(methodWeight);
00165       MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
00166 
00167       if(meth==0)
00168          Log() << kFATAL << "Could not read method from XML" << Endl;
00169 
00170       void* methXML = gTools().GetChild(ch);
00171       meth->SetupMethod();
00172       meth->ReadWeightsFromXML(methXML);
00173       meth->SetMsgType(kWARNING);
00174       meth->ParseOptions();
00175       meth->ProcessSetup();
00176       meth->CheckSetup();
00177       meth->SetSignalReferenceCut(methodSigCut);
00178 
00179       ch = gTools().GetNextChild(ch);
00180    }
00181    //Log() << kINFO << "Reading methods from XML done " << Endl;
00182 }
00183 
00184 //_______________________________________________________________________
00185 void  TMVA::MethodCompositeBase::ReadWeightsFromStream( istream& istr )
00186 {
00187    // text streamer
00188    TString var, dummy;
00189    TString methodName, methodTitle=GetMethodName(),
00190     jobName=GetJobName(),optionString=GetOptions();
00191    UInt_t methodNum; Double_t methodWeight;
00192    // and read the Weights (BDT coefficients)
00193    // coverity[tainted_data_argument]
00194    istr >> dummy >> methodNum;
00195    Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
00196    for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
00197    fMethods.clear();
00198    fMethodWeight.clear();
00199    for (UInt_t i=0; i<methodNum; i++) {
00200       istr >> dummy >> methodName >>  dummy >> fMethodIndex >> dummy >> methodWeight;
00201       if ((UInt_t)fMethodIndex != i) {
00202          Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
00203                << fMethodIndex << " i=" << i
00204                << " MethodName " << methodName
00205                << " dummy " << dummy
00206                << " MethodWeight= " << methodWeight
00207                << Endl;
00208       }
00209       if (GetMethodType() != Types::kBoost || i==0) {
00210          istr >> dummy >> jobName;
00211          istr >> dummy >> methodTitle;
00212          istr >> dummy >> optionString;
00213          if (GetMethodType() == Types::kBoost)
00214             ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle,  optionString );
00215       }
00216       else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fMethodIndex);
00217       fMethods.push_back(ClassifierFactory::Instance().Create( std::string(methodName), jobName,
00218                                                                methodTitle,DataInfo(), optionString) );
00219       fMethodWeight.push_back( methodWeight );
00220       if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
00221          m->ReadWeightsFromStream(istr);
00222    }
00223 }
00224 
00225 //_______________________________________________________________________
00226 Double_t TMVA::MethodCompositeBase::GetMvaValue( Double_t* err, Double_t* errUpper )
00227 {
00228    // return composite MVA response
00229    Double_t mvaValue = 0;
00230    for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
00231 
00232    // cannot determine error
00233    NoErrorCalc(err, errUpper);
00234 
00235    return mvaValue;
00236 }

Generated on Tue Jul 5 15:25:02 2011 for ROOT_528-00b_version by  doxygen 1.5.1