MethodTMlpANN.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: MethodTMlpANN.cxx 37154 2010-12-01 15:42:33Z evt $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
00003 /**********************************************************************************
00004  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00005  * Package: TMVA                                                                  *
00006  * Class  : MethodTMlpANN                                                         *
00007  * Web    : http://tmva.sourceforge.net                                           *
00008  *                                                                                *
00009  * Description:                                                                   *
00010  *      Implementation (see header for description)                               *
00011  *                                                                                *
00012  * Authors (alphabetical):                                                        *
00013  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00014  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00015  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00016  *                                                                                *
00017  * Copyright (c) 2005:                                                            *
00018  *      CERN, Switzerland                                                         *
00019  *      U. of Victoria, Canada                                                    *
00020  *      MPI-K Heidelberg, Germany                                                 *
00021  *                                                                                *
00022  * Redistribution and use in source and binary forms, with or without             *
00023  * modification, are permitted according to the terms listed in LICENSE           *
00024  * (http://tmva.sourceforge.net/LICENSE)                                          *
00025  **********************************************************************************/
00026 
00027 //_______________________________________________________________________
00028 /* Begin_Html
00029 
00030   This is the TMVA TMultiLayerPerceptron interface class. It provides the
00031   training and testing of the ROOT internal MLP class in the TMVA framework.<br>
00032 
00033   Available learning methods:<br>
00034   <ul>
00035   <li>Stochastic      </li>
00036   <li>Batch           </li>
00037   <li>SteepestDescent </li>
00038   <li>RibierePolak    </li>
00039   <li>FletcherReeves  </li>
00040   <li>BFGS            </li>
00041   </ul>
00042 End_Html */
00043 //
00044 //  See the TMultiLayerPerceptron class description
00045 //  for details on this ANN.
00046 //
00047 //_______________________________________________________________________
00048 
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "Riostream.h"
#include "TEventList.h"
#include "TLeaf.h"
#include "TMultiLayerPerceptron.h"
#include "TObjString.h"
#include "TROOT.h"

#include "TMVA/Config.h"
#include "TMVA/MethodTMlpANN.h"

#include "TMVA/ClassifierFactory.h"
#ifndef ROOT_TMVA_Tools
#include "TMVA/Tools.h"
#endif
00067 
00068 // some additional TMlpANN options
00069 const Bool_t EnforceNormalization__=kTRUE;
00070 #if ROOT_VERSION_CODE > ROOT_VERSION(5,13,06)
00071 const TMultiLayerPerceptron::ELearningMethod LearningMethod__= TMultiLayerPerceptron::kStochastic;
00072 // const TMultiLayerPerceptron::ELearningMethod LearningMethod__= TMultiLayerPerceptron::kBatch;
00073 #else
00074 const TMultiLayerPerceptron::LearningMethod LearningMethod__= TMultiLayerPerceptron::kStochastic;
00075 #endif
00076 
00077 REGISTER_METHOD(TMlpANN)
00078 
00079 ClassImp(TMVA::MethodTMlpANN)
00080 
00081 //_______________________________________________________________________
00082 TMVA::MethodTMlpANN::MethodTMlpANN( const TString& jobName,
00083                                     const TString& methodTitle,
00084                                     DataSetInfo& theData,
00085                                     const TString& theOption,
00086                                     TDirectory* theTargetDir) :
00087    TMVA::MethodBase( jobName, Types::kTMlpANN, methodTitle, theData, theOption, theTargetDir ),
00088    fMLP(0),
00089    fNcycles(100),
00090    fValidationFraction(0.5),
00091    fLearningMethod( "" )
00092 {
00093    // standard constructor
00094 }
00095 
00096 //_______________________________________________________________________
00097 TMVA::MethodTMlpANN::MethodTMlpANN( DataSetInfo& theData,
00098                                     const TString& theWeightFile,
00099                                     TDirectory* theTargetDir ) :
00100    TMVA::MethodBase( Types::kTMlpANN, theData, theWeightFile, theTargetDir ),
00101    fMLP(0),
00102    fNcycles(100),
00103    fValidationFraction(0.5),
00104    fLearningMethod( "" )
00105 {
00106    // constructor from weight file
00107 }
00108 
00109 //_______________________________________________________________________
00110 Bool_t TMVA::MethodTMlpANN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses,
00111                                              UInt_t /*numberTargets*/ )
00112 {
00113    // TMlpANN can handle classification with 2 classes
00114    if (type == Types::kClassification && numberClasses == 2) return kTRUE;
00115    return kFALSE;
00116 }
00117 
00118 
00119 //_______________________________________________________________________
00120 void TMVA::MethodTMlpANN::Init( void )
00121 {
00122    // default initialisations
00123 }
00124 
00125 //_______________________________________________________________________
00126 TMVA::MethodTMlpANN::~MethodTMlpANN( void )
00127 {
00128    // destructor 
00129    if (fMLP) delete fMLP;
00130 }
00131 
00132 //_______________________________________________________________________
00133 void TMVA::MethodTMlpANN::CreateMLPOptions( TString layerSpec )
00134 {
00135    // translates options from option string into TMlpANN language
00136 
00137    fHiddenLayer = ":";
00138 
00139    while (layerSpec.Length()>0) {
00140       TString sToAdd="";
00141       if (layerSpec.First(',')<0) {
00142          sToAdd = layerSpec;
00143          layerSpec = "";
00144       }
00145       else {
00146          sToAdd = layerSpec(0,layerSpec.First(','));
00147          layerSpec = layerSpec(layerSpec.First(',')+1,layerSpec.Length());
00148       }
00149       int nNodes = 0;
00150       if (sToAdd.BeginsWith("N")) { sToAdd.Remove(0,1); nNodes = GetNvar(); }
00151       nNodes += atoi(sToAdd);
00152       fHiddenLayer = Form( "%s%i:", (const char*)fHiddenLayer, nNodes );
00153    }
00154 
00155    // set input vars
00156    std::vector<TString>::iterator itrVar    = (*fInputVars).begin();
00157    std::vector<TString>::iterator itrVarEnd = (*fInputVars).end();
00158    fMLPBuildOptions = "";
00159    for (; itrVar != itrVarEnd; itrVar++) {
00160       if (EnforceNormalization__) fMLPBuildOptions += "@";
00161       TString myVar = *itrVar; ;
00162       fMLPBuildOptions += myVar;
00163       fMLPBuildOptions += ",";
00164    }
00165    fMLPBuildOptions.Chop(); // remove last ","
00166 
00167    // prepare final options for MLP kernel
00168    fMLPBuildOptions += fHiddenLayer;
00169    fMLPBuildOptions += "type";
00170 
00171    Log() << kINFO << "Use " << fNcycles << " training cycles" << Endl;
00172    Log() << kINFO << "Use configuration (nodes per hidden layer): " << fHiddenLayer << Endl;
00173 }
00174 
00175 //_______________________________________________________________________
00176 void TMVA::MethodTMlpANN::DeclareOptions()
00177 {
00178    // define the options (their key words) that can be set in the option string
00179    // know options:
00180    // NCycles       <integer>    Number of training cycles (too many cycles could overtrain the network)
00181    // HiddenLayers  <string>     Layout of the hidden layers (nodes per layer)
00182    //   * specifiactions for each hidden layer are separated by commata
00183    //   * for each layer the number of nodes can be either absolut (simply a number)
00184    //        or relative to the number of input nodes to the neural net (N)
00185    //   * there is always a single node in the output layer
00186    //   example: a net with 6 input nodes and "Hiddenlayers=N-1,N-2" has 6,5,4,1 nodes in the
00187    //   layers 1,2,3,4, repectively
00188    DeclareOptionRef( fNcycles    = 200,       "NCycles",      "Number of training cycles" );
00189    DeclareOptionRef( fLayerSpec  = "N,N-1",   "HiddenLayers", "Specification of hidden layer architecture (N stands for number of variables; any integers may also be used)" );
00190 
00191    DeclareOptionRef( fValidationFraction = 0.5, "ValidationFraction",
00192                      "Fraction of events in training tree used for cross validation" );
00193 
00194    DeclareOptionRef( fLearningMethod = "Stochastic", "LearningMethod", "Learning method" );
00195    AddPreDefVal( TString("Stochastic") );
00196    AddPreDefVal( TString("Batch") );
00197    AddPreDefVal( TString("SteepestDescent") );
00198    AddPreDefVal( TString("RibierePolak") );
00199    AddPreDefVal( TString("FletcherReeves") );
00200    AddPreDefVal( TString("BFGS") );
00201 }
00202 
00203 //_______________________________________________________________________
00204 void TMVA::MethodTMlpANN::ProcessOptions()
00205 {
00206    // builds the neural network as specified by the user
00207    CreateMLPOptions(fLayerSpec);
00208 
00209    if (IgnoreEventsWithNegWeightsInTraining()) {
00210       Log() << kFATAL << "Mechanism to ignore events with negative weights in training not available for method"
00211             << GetMethodTypeName()
00212             << " --> please remove \"IgnoreNegWeightsInTraining\" option from booking string."
00213             << Endl;
00214    }
00215 }
00216 
00217 //_______________________________________________________________________
00218 Double_t TMVA::MethodTMlpANN::GetMvaValue( Double_t* err, Double_t* errUpper )
00219 {
00220    // calculate the value of the neural net for the current event
00221    const Event* ev = GetEvent();
00222    static Double_t* d = new Double_t[Data()->GetNVariables()];
00223    for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
00224       d[ivar] = (Double_t)ev->GetValue(ivar);
00225    }
00226    Double_t mvaVal = fMLP->Evaluate(0,d);
00227 
00228    // cannot determine error
00229    NoErrorCalc(err, errUpper);
00230 
00231    return mvaVal;
00232 }
00233 
00234 //_______________________________________________________________________
00235 void TMVA::MethodTMlpANN::Train( void )
00236 {
00237    // performs TMlpANN training
00238    // available learning methods:
00239    //
00240    //       TMultiLayerPerceptron::kStochastic
00241    //       TMultiLayerPerceptron::kBatch
00242    //       TMultiLayerPerceptron::kSteepestDescent
00243    //       TMultiLayerPerceptron::kRibierePolak
00244    //       TMultiLayerPerceptron::kFletcherReeves
00245    //       TMultiLayerPerceptron::kBFGS
00246    //
00247    // TMultiLayerPerceptron wants test and training tree at once
00248    // so merge the training and testing trees from the MVA factory first:
00249 
00250    Int_t type;
00251    Float_t weight;
00252    const Long_t basketsize = 128000;
00253    Float_t* vArr = new Float_t[GetNvar()]; 
00254 
00255    TTree *localTrainingTree = new TTree( "TMLPtrain", "Local training tree for TMlpANN" );
00256    localTrainingTree->Branch( "type",       &type,        "type/I",        basketsize );
00257    localTrainingTree->Branch( "weight",     &weight,      "weight/F",      basketsize );
00258    
00259    for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00260       const char* myVar = GetInternalVarName(ivar).Data();
00261       localTrainingTree->Branch( myVar, &vArr[ivar], Form("Var%02i/F", ivar), basketsize );
00262    }
00263    
00264    for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
00265       const Event *ev = GetEvent(ievt);
00266       for (UInt_t i=0; i<GetNvar(); i++) {
00267          vArr[i] = ev->GetValue( i );
00268       }
00269       type   = DataInfo().IsSignal( ev ) ? 1 : 0;
00270       weight = ev->GetWeight();
00271       localTrainingTree->Fill();
00272    }
00273 
00274    // These are the event lists for the mlp train method
00275    // first events in the tree are for training
00276    // the rest for internal testing (cross validation)...
00277    // NOTE: the training events are ordered: first part is signal, second part background
00278    TString trainList = "Entry$<";
00279    trainList += 1.0-fValidationFraction;
00280    trainList += "*";
00281    trainList += (Int_t)Data()->GetNEvtSigTrain();
00282    trainList += " || (Entry$>";
00283    trainList += (Int_t)Data()->GetNEvtSigTrain();
00284    trainList += " && Entry$<";
00285    trainList += (Int_t)(Data()->GetNEvtSigTrain() + (1.0 - fValidationFraction)*Data()->GetNEvtBkgdTrain());
00286    trainList += ")";
00287    TString testList  = TString("!(") + trainList + ")";
00288 
00289    // print the requirements
00290    Log() << kINFO << "Requirement for training   events: \"" << trainList << "\"" << Endl;
00291    Log() << kINFO << "Requirement for validation events: \"" << testList << "\"" << Endl;
00292 
00293    // localTrainingTree->Print();
00294 
00295    // create NN
00296    if (fMLP != 0) { delete fMLP; fMLP = 0; }
00297    fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(),
00298                                      localTrainingTree,
00299                                      trainList,
00300                                      testList );
00301    fMLP->SetEventWeight( "weight" );
00302 
00303    // set learning method
00304 #if ROOT_VERSION_CODE > ROOT_VERSION(5,13,06)
00305    TMultiLayerPerceptron::ELearningMethod learningMethod = TMultiLayerPerceptron::kStochastic;
00306 #else
00307    TMultiLayerPerceptron::LearningMethod  learningMethod = TMultiLayerPerceptron::kStochastic;
00308 #endif
00309 
00310    fLearningMethod.ToLower();
00311    if      (fLearningMethod == "stochastic"      ) learningMethod = TMultiLayerPerceptron::kStochastic;
00312    else if (fLearningMethod == "batch"           ) learningMethod = TMultiLayerPerceptron::kBatch;
00313    else if (fLearningMethod == "steepestdescent" ) learningMethod = TMultiLayerPerceptron::kSteepestDescent;
00314    else if (fLearningMethod == "ribierepolak"    ) learningMethod = TMultiLayerPerceptron::kRibierePolak;
00315    else if (fLearningMethod == "fletcherreeves"  ) learningMethod = TMultiLayerPerceptron::kFletcherReeves;
00316    else if (fLearningMethod == "bfgs"            ) learningMethod = TMultiLayerPerceptron::kBFGS;
00317    else {
00318       Log() << kFATAL << "Unknown Learning Method: \"" << fLearningMethod << "\"" << Endl;
00319    }
00320    fMLP->SetLearningMethod( learningMethod );
00321 
00322    // train NN
00323    fMLP->Train(fNcycles, "text,update=50" );
00324 
00325    // write weights to File;
00326    // this is not nice, but fMLP gets deleted at the end of Train()
00327    delete localTrainingTree;
00328    delete [] vArr;
00329 }
00330 
00331 
00332 //_______________________________________________________________________
00333 void TMVA::MethodTMlpANN::AddWeightsXMLTo( void* parent ) const
00334 {
00335    // write weights to xml file
00336 
00337    // first the architecture
00338    void *wght = gTools().AddChild(parent, "Weights");
00339    void* arch = gTools().AddChild( wght, "Architecture" );
00340    gTools().AddAttr( arch, "BuildOptions", fMLPBuildOptions.Data() );
00341 
00342    // dump weights first in temporary txt file, read from there into xml
00343    fMLP->DumpWeights( "weights/TMlp.nn.weights.temp" );
00344    std::ifstream inf( "weights/TMlp.nn.weights.temp" );
00345    char temp[256];
00346    TString data("");
00347    void *ch=NULL;
00348    while (inf.getline(temp,256)) {
00349       TString dummy(temp);
00350       //std::cout << dummy << std::endl; // remove annoying debug printout with std::cout
00351       if (dummy.BeginsWith('#')) {
00352          if (ch!=0) gTools().AddRawLine( ch, data.Data() );
00353          dummy = dummy.Strip(TString::kLeading, '#');
00354          dummy = dummy(0,dummy.First(' '));
00355          ch = gTools().AddChild(wght, dummy);
00356          data.Resize(0);
00357          continue;
00358       }
00359       data += (dummy + " ");
00360    }
00361    if (ch != 0) gTools().AddRawLine( ch, data.Data() );
00362 
00363    inf.close();
00364 }
00365 
00366 //_______________________________________________________________________
00367 void  TMVA::MethodTMlpANN::ReadWeightsFromXML( void* wghtnode )
00368 {
00369    // rebuild temporary textfile from xml weightfile and load this
00370    // file into MLP
00371    void* ch = gTools().GetChild(wghtnode);
00372    gTools().ReadAttr( ch, "BuildOptions", fMLPBuildOptions );
00373 
00374    ch = gTools().GetNextChild(ch);
00375    const char* fname = "weights/TMlp.nn.weights.temp";
00376    std::ofstream fout( fname );
00377    double temp1=0,temp2=0;
00378    while (ch) {
00379       const char* nodecontent = gTools().GetContent(ch);
00380       std::stringstream content(nodecontent);
00381       if (strcmp(gTools().GetName(ch),"input")==0) {
00382          fout << "#input normalization" << std::endl;
00383          while ((content >> temp1) &&(content >> temp2)) {
00384             fout << temp1 << " " << temp2 << std::endl;
00385          }
00386       }
00387       if (strcmp(gTools().GetName(ch),"output")==0) {
00388          fout << "#output normalization" << std::endl;
00389          while ((content >> temp1) &&(content >> temp2)) {
00390             fout << temp1 << " " << temp2 << std::endl;
00391          }
00392       }
00393       if (strcmp(gTools().GetName(ch),"neurons")==0) {
00394          fout << "#neurons weights" << std::endl;         
00395          while (content >> temp1) {
00396             fout << temp1 << std::endl;
00397          }
00398       }
00399       if (strcmp(gTools().GetName(ch),"synapses")==0) {
00400          fout << "#synapses weights" ;         
00401          while (content >> temp1) {
00402             fout << std::endl << temp1 ;                
00403          }
00404       }
00405       ch = gTools().GetNextChild(ch);
00406    }
00407    fout.close();;
00408 
00409    // Here we create a dummy tree necessary to create a minimal NN
00410    // to be used for testing, evaluation and application
00411    static Double_t* d = new Double_t[Data()->GetNVariables()] ;
00412    static Int_t type;
00413 
00414    gROOT->cd();
00415    TTree * dummyTree = new TTree("dummy","Empty dummy tree", 1);
00416    for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
00417       TString vn = DataInfo().GetVariableInfo(ivar).GetInternalName();
00418       dummyTree->Branch(Form("%s",vn.Data()), d+ivar, Form("%s/D",vn.Data()));
00419    }
00420    dummyTree->Branch("type", &type, "type/I");
00421 
00422    if (fMLP != 0) { delete fMLP; fMLP = 0; }
00423    fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );
00424    fMLP->LoadWeights( fname );
00425 }
00426  
00427 //_______________________________________________________________________
00428 void  TMVA::MethodTMlpANN::ReadWeightsFromStream( istream& istr )
00429 {
00430    // read weights from stream
00431    // since the MLP can not read from the stream, we
00432    // 1st: write the weights to temporary file
00433    std::ofstream fout( "./TMlp.nn.weights.temp" );
00434    fout << istr.rdbuf();
00435    fout.close();
00436    // 2nd: load the weights from the temporary file into the MLP
00437    // the MLP is already build
00438    Log() << kINFO << "Load TMLP weights into " << fMLP << Endl;
00439 
00440    Double_t* d = new Double_t[Data()->GetNVariables()] ; 
00441    static Int_t type;
00442    gROOT->cd();
00443    TTree * dummyTree = new TTree("dummy","Empty dummy tree", 1);
00444    for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
00445       TString vn = DataInfo().GetVariableInfo(ivar).GetLabel();
00446       dummyTree->Branch(Form("%s",vn.Data()), d+ivar, Form("%s/D",vn.Data()));
00447    }
00448    dummyTree->Branch("type", &type, "type/I");
00449 
00450    if (fMLP != 0) { delete fMLP; fMLP = 0; }
00451    fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );
00452 
00453    fMLP->LoadWeights( "./TMlp.nn.weights.temp" );
00454    // here we can delete the temporary file
00455    // how?
00456    delete [] d;
00457 }
00458 
00459 //_______________________________________________________________________
00460 void TMVA::MethodTMlpANN::MakeClass( const TString& theClassFileName ) const
00461 {
00462    // create reader class for classifier -> overwrites base class function
00463    // create specific class for TMultiLayerPerceptron
00464 
00465    // the default consists of
00466    TString classFileName = "";
00467    if (theClassFileName == "")
00468       classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class";
00469    else
00470       classFileName = theClassFileName;
00471 
00472    Log() << kINFO << "Creating specific (TMultiLayerPerceptron) standalone response class: " << Endl;
00473    fMLP->Export( classFileName.Data() );
00474 }
00475 
00476 //_______________________________________________________________________
00477 void TMVA::MethodTMlpANN::MakeClassSpecific( std::ostream& /*fout*/, const TString& /*className*/ ) const
00478 {
00479    // write specific classifier response
00480    // nothing to do here - all taken care of by TMultiLayerPerceptron
00481 }
00482 
00483 //_______________________________________________________________________
00484 void TMVA::MethodTMlpANN::GetHelpMessage() const
00485 {
00486    // get help message text
00487    //
00488    // typical length of text line: 
00489    //         "|--------------------------------------------------------------|"
00490    Log() << Endl;
00491    Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
00492    Log() << Endl;
00493    Log() << "This feed-forward multilayer perceptron neural network is the " << Endl;
00494    Log() << "standard implementation distributed with ROOT (class TMultiLayerPerceptron)." << Endl;
00495    Log() << Endl;
00496    Log() << "Detailed information is available here:" << Endl;
00497    if (gConfig().WriteOptionsReference()) {
00498       Log() << "<a href=\"http://root.cern.ch/root/html/TMultiLayerPerceptron.html\">";
00499       Log() << "http://root.cern.ch/root/html/TMultiLayerPerceptron.html</a>" << Endl;
00500    }
00501    else Log() << "http://root.cern.ch/root/html/TMultiLayerPerceptron.html" << Endl;
00502    Log() << Endl;
00503 }

Generated on Tue Jul 5 15:25:04 2011 for ROOT_528-00b_version by  doxygen 1.5.1