MethodKNN.cxx

// @(#)root/tmva $Id: MethodKNN.cxx 36966 2010-11-26 09:50:13Z evt $
// Author: Rustem Ospanov

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodKNN                                                             *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation                                                            *
 *                                                                                *
 * Author:                                                                        *
 *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodKNN                                                            //
//                                                                      //
// Analysis of k-nearest neighbor                                       //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

// C/C++
#include <cmath>
#include <string>
#include <sstream>   // std::stringstream, used in the weight I/O below
#include <cstdlib>

// ROOT
#include "TFile.h"
#include "TMath.h"
#include "TTree.h"

// TMVA
#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodKNN.h"
#include "TMVA/Ranking.h"
#include "TMVA/Tools.h"

REGISTER_METHOD(KNN)

ClassImp(TMVA::MethodKNN)

//_______________________________________________________________________
TMVA::MethodKNN::MethodKNN( const TString& jobName,
                            const TString& methodTitle,
                            DataSetInfo& theData,
                            const TString& theOption,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption, theTargetDir)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
   // standard constructor
}

//_______________________________________________________________________
TMVA::MethodKNN::MethodKNN( DataSetInfo& theData,
                            const TString& theWeightFile,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase( Types::kKNN, theData, theWeightFile, theTargetDir)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
   // constructor from weight file
}

//_______________________________________________________________________
TMVA::MethodKNN::~MethodKNN()
{
   // destructor
   if (fModule) delete fModule;
}

//_______________________________________________________________________
void TMVA::MethodKNN::DeclareOptions()
{
   // MethodKNN options

   // Default values (set via DeclareOptionRef below):
   // fnkNN         = 20;      // number of k-nearest neighbors
   // fBalanceDepth = 6;       // number of binary tree levels used for tree balancing
   // fScaleFrac    = 0.8;     // fraction of events used to compute variable width
   // fSigmaFact    = 1.0;     // scale factor for Gaussian sigma
   // fKernel       = "Gaus";  // use polynomial (1-x^3)^3 or Gaussian kernel
   // fTrim         = false;   // use equal number of signal and background events
   // fUseKernel    = false;   // use polynomial kernel weight function
   // fUseWeight    = true;    // count events using weights
   // fUseLDA       = false;   // use local linear discriminant analysis

   DeclareOptionRef(fnkNN         = 20,     "nkNN",         "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6,      "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac    = 0.80,   "ScaleFrac",    "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact    = 1.0,    "SigmaFact",    "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel       = "Gaus", "Kernel",       "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim         = kFALSE, "Trim",         "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel    = kFALSE, "UseKernel",    "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight    = kTRUE,  "UseWeight",    "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA       = kFALSE, "UseLDA",       "Use local linear discriminant - experimental feature");
}
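
// A minimal booking sketch for the options declared above (illustrative only,
// not part of this file; the Factory setup, file name, and option values are
// assumptions):
//
//    TFile* outputFile = TFile::Open("TMVA.root", "RECREATE");
//    TMVA::Factory factory("TMVAClassification", outputFile, "AnalysisType=Classification");
//    factory.BookMethod(TMVA::Types::kKNN, "KNN",
//                       "nkNN=20:ScaleFrac=0.8:SigmaFact=1.0:Kernel=Gaus:"
//                       "UseKernel=F:UseWeight=T:Trim=F");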

//_______________________________________________________________________
void TMVA::MethodKNN::DeclareCompatibilityOptions() {
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}

//_______________________________________________________________________
void TMVA::MethodKNN::ProcessOptions()
{
   // process the options specified by the user
   if (!(fnkNN > 0)) {
      fnkNN = 10;
      Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;
   }
   if (!(fBalanceDepth > 0)) {
      fBalanceDepth = 6;
      Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
   }

   Log() << kVERBOSE
         << "kNN options:\n"
         << "  kNN = " << fnkNN << "\n"
         << "  UseKernel = " << fUseKernel << "\n"
         << "  SigmaFact = " << fSigmaFact << "\n"
         << "  ScaleFrac = " << fScaleFrac << "\n"
         << "  Kernel = " << fKernel << "\n"
         << "  Trim = " << fTrim << "\n"
         << "  Optimize = " << fBalanceDepth << Endl;
}

//_______________________________________________________________________
Bool_t TMVA::MethodKNN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
{
   // kNN can handle classification with 2 classes and regression with one regression-target
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kRegression) return kTRUE;
   return kFALSE;
}

//_______________________________________________________________________
void TMVA::MethodKNN::Init()
{
   // Initialization

   // fScaleFrac <= 0.0: do not scale input variables
   // fScaleFrac >= 1.0: use all event coordinates to scale input variables

   fModule = new kNN::ModulekNN();
   fSumOfWeightsS = 0;
   fSumOfWeightsB = 0;
}

//_______________________________________________________________________
void TMVA::MethodKNN::MakeKNN()
{
   // create kNN
   if (!fModule) {
      Log() << kFATAL << "ModulekNN is not created" << Endl;
   }

   fModule->Clear();

   std::string option;
   if (fScaleFrac > 0.0) {
      option += "metric";
   }
   if (fTrim) {
      option += "trim";
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // create binary tree
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac),
                 option);
}

//_______________________________________________________________________
void TMVA::MethodKNN::Train()
{
   // kNN training
   Log() << kINFO << "<Train> start..." << Endl;

   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
      fScaleFrac = 0.0;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event*   evt_   = GetEvent(ievt);
      Double_t       weight = evt_->GetWeight();

      // skip this event if negative weights are to be ignored in training
      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      kNN::VarVec vvec(GetNVariables(), 0.0);
      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      Short_t event_type = 0;

      if (DataInfo().IsSignal(evt_)) { // signal type = 1
         fSumOfWeightsS += weight;
         event_type = 1;
      }
      else { // background type = 2
         fSumOfWeightsB += weight;
         event_type = 2;
      }

      //
      // Create event and add classification variables, weight, type and regression variables
      //
      kNN::Event event_knn(vvec, weight, event_type);
      event_knn.SetTargets(evt_->GetTargets());
      fEvent.push_back(event_knn);
   }
   Log() << kINFO
         << "Sum of signal weights: " << fSumOfWeightsS << Endl
         << "Sum of background weights: " << fSumOfWeightsB << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

//_______________________________________________________________________
Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // Compute classifier response

   // cannot determine error
   NoErrorCalc(err, errUpper);

   //
   // Define local variables
   //
   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();
   const Double_t weight = ev->GetWeight();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, pad with two
   // events to avoid Monte-Carlo events with zero distance;
   // most of the CPU time is spent in this recursive function
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list has unexpected size" << Endl;
      return -100.0;
   }

   if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);

   //
   // Set flags for kernel option=Gaus, Poln
   //
   Bool_t use_gaus = false, use_poln = false;
   if (fUseKernel) {
      if      (fKernel == "Gaus") use_gaus = true;
      else if (fKernel == "Poln") use_poln = true;
   }

   //
   // Compute radius for polynomial kernel
   //
   Double_t kradius = -1.0;
   if (use_poln) {
      kradius = MethodKNN::getKernelRadius(rlist);

      if (!(kradius > 0.0)) {
         Log() << kFATAL << "kNN radius is not positive" << Endl;
         return -100.0;
      }

      kradius = 1.0/TMath::Sqrt(kradius);
   }

   //
   // Compute RMS of variable differences for Gaussian sigma
   //
   std::vector<Double_t> rms_vec;
   if (use_gaus) {
      rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);

      if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
         Log() << kFATAL << "Failed to compute RMS vector" << Endl;
         return -100.0;
      }
   }

   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);

      // warn about a Monte-Carlo event with zero distance:
      // this happens when the query event is also in the learning sample
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // get event weight and scale weight by kernel function
      Double_t evweight = node.GetWeight();
      if      (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
      else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);

      if (fUseWeight) weight_all += evweight;
      else          ++weight_all;

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         if (fUseWeight) weight_sig += evweight;
         else          ++weight_sig;
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         if (fUseWeight) weight_bac += evweight;
         else          ++weight_bac;
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      // use only fnkNN events
      ++count_all;

      if (count_all >= knn) {
         break;
      }
   }

   // check that total number of events or total weight sum is positive
   if (!(count_all > 0)) {
      Log() << kFATAL << "kNN result list size is not positive" << Endl;
      return -100.0;
   }

   // check that the number of accepted events matches k
   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }

   // check that total weight is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
      return -100.0;
   }

   return weight_sig/weight_all;
}
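
// The value returned above is the (kernel-weighted) signal fraction among the
// k nearest neighbors,
//
//    MVA = sum_{i in kNN, type=signal} w_i / sum_{i in kNN} w_i ,
//
// and therefore lies in [0,1]. For example, with k = 20, unit weights and no
// kernel, 14 signal neighbors give a response of 14/20 = 0.7.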

//_______________________________________________________________________
const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
{
   //
   // Return vector of averages for target values of k-nearest neighbors.
   // Use a local copy of the regression vector instead of working through
   // the member pointer.
   //
   if( fRegressionReturnVal == 0 )
      fRegressionReturnVal = new std::vector<Float_t>;
   else
      fRegressionReturnVal->clear();

   //
   // Define local variables
   //
   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);
   std::vector<float> reg_vec;

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, pad with two
   // events to avoid Monte-Carlo events with zero distance;
   // most of the CPU time is spent in this recursive function
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list has unexpected size" << Endl;
      return *fRegressionReturnVal;
   }

   // compute regression values
   Double_t weight_all = 0;
   UInt_t count_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();

      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0);
      }

      for(UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }

      if (fUseWeight) weight_all += weight;
      else          ++weight_all;

      // use only fnkNN events
      ++count_all;

      if (count_all == knn) {
         break;
      }
   }

   // check that the total weight sum is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }

   // copy result
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
}
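
// The regression output above is the per-target weighted average over the k
// nearest neighbors,
//
//    t_j = sum_{i in kNN} w_i * t_{ij} / sum_{i in kNN} w_i ,
//
// with w_i = 1 for every neighbor when UseWeight is off.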

//_______________________________________________________________________
const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
   // no ranking available
   return 0;
}

//_______________________________________________________________________
void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
   // write weights to XML

   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr(wght,"NEvents",fEvent.size());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

      std::stringstream s("");
      s.precision( 16 );
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }

      void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
      gTools().AddAttr(evt,"Type", event->GetType());
      gTools().AddAttr(evt,"Weight", event->GetWeight());
   }
}

//_______________________________________________________________________
void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {

   void* ch = gTools().GetChild(wghtnode); // first event
   UInt_t nvar = 0, ntgt = 0;
   gTools().ReadAttr( wghtnode, "NVar", nvar );
   gTools().ReadAttr( wghtnode, "NTgt", ntgt );

   Short_t evtType(0);
   Double_t evtWeight(0);

   while (ch) {
      // build event
      kNN::VarVec vvec(nvar, 0);
      kNN::VarVec tvec(ntgt, 0);

      gTools().ReadAttr( ch, "Type",   evtType   );
      gTools().ReadAttr( ch, "Weight", evtWeight );
      std::stringstream s( gTools().GetContent(ch) );

      for(UInt_t ivar=0; ivar<nvar; ivar++)
         s >> vvec[ivar];

      for(UInt_t itgt=0; itgt<ntgt; itgt++)
         s >> tvec[itgt];

      ch = gTools().GetNextChild(ch);

      kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
      fEvent.push_back(event_knn);
   }

   // create kd-tree (binary tree) structure
   MakeKNN();
}

//_______________________________________________________________________
void TMVA::MethodKNN::ReadWeightsFromStream(istream& is)
{
   // read the weights
   Log() << kINFO << "Starting ReadWeightsFromStream(istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

   while (!is.eof()) {
      std::string line;
      std::getline(is, line);

      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      UInt_t count = 0;
      std::string::size_type pos=0;
      while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }

      if (nvar == 0) {
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      Int_t ievent = -1, type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) {
            continue;
         }

         if (!(ipos > prev)) {
            Log() << kFATAL << "Wrong substring limits" << Endl;
         }

         std::string vstring = line.substr(prev, ipos - prev);
         if (ipos + 1 == line.size()) {
            vstring = line.substr(prev, ipos - prev + 1);
         }

         if (vstring.empty()) {
            Log() << kFATAL << "Failed to parse string" << Endl;
         }

         if (vcount == 0) {
            ievent = std::atoi(vstring.c_str());
         }
         else if (vcount == 1) {
            type = std::atoi(vstring.c_str());
         }
         else if (vcount == 2) {
            weight = std::atof(vstring.c_str());
         }
         else if (vcount - 3 < vvec.size()) {
            vvec[vcount - 3] = std::atof(vstring.c_str());
         }
         else {
            Log() << kFATAL << "Wrong variable count" << Endl;
         }

         prev = ipos + 1;
         ++vcount;
      }

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

//-------------------------------------------------------------------------------------------
void TMVA::MethodKNN::WriteWeightsToStream(TFile &rf) const
{
   // save weights to ROOT file
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      Log() << kWARNING << "MethodKNN contains no events" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   TTree *tree = new TTree("knn", "event tree");
   tree->SetDirectory(0);
   tree->Branch("event", "TMVA::kNN::Event", &event);

   Double_t size = 0.0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      (*event) = (*it);
      size += tree->Fill();
   }

   // !!! hard-coded tree name !!!
   rf.WriteTObject(tree, "knn", "Overwrite");

   // scale to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Wrote " << size << " MB and " << fEvent.size()
         << " events to ROOT file" << Endl;

   delete tree;
   delete event;
}

//-------------------------------------------------------------------------------------------
void TMVA::MethodKNN::ReadWeightsFromStream(TFile &rf)
{
   // read weights from ROOT file
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   // !!! hard-coded tree name !!!
   TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
   if (!tree) {
      Log() << kFATAL << "Failed to find knn tree" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   tree->SetBranchAddress("event", &event);

   const Int_t nevent = tree->GetEntries();

   Double_t size = 0.0;
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);
      fEvent.push_back(*event);
   }

   // scale to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Read " << size << " MB and " << fEvent.size()
         << " events from ROOT file" << Endl;

   delete event;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

//_______________________________________________________________________
void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   // write specific classifier response
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}

//_______________________________________________________________________
void TMVA::MethodKNN::GetHelpMessage() const
{
   // get help message text
   //
   // typical length of text line:
   //         "|--------------------------------------------------------------|"
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Like other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which the classification category/regression target is known." << Endl
         << "The k-NN method compares a test event to all training events using a distance" << Endl
         << "function, which is a Euclidean distance in the space defined by the input variables." << Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with the shortest distance to the test event. The" << Endl
         << "method returns the fraction of signal events among the k neighbors. It is" << Endl
         << "recommended that a histogram storing the k-NN decision variable is binned with" << Endl
         << "k+1 bins between 0 and 1." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-NN method estimates the density of signal and background events in a" << Endl
         << "neighborhood around the test event. The method assumes that the density of the" << Endl
         << "signal and background events is uniform and constant within the neighborhood." << Endl
         << "k is an adjustable parameter that determines the average size of the" << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical" << Endl
         << "fluctuations, and large values (greater than 100) might not sufficiently capture" << Endl
         << "local differences between events in the training set. The evaluation time of the" << Endl
         << "k-NN method also grows with larger values of k." << Endl;
   Log() << Endl;
   Log() << "The k-NN method assigns equal weight to all input variables. Different scales" << Endl
         << "among the input variables are compensated for via the ScaleFrac parameter: the" << Endl
         << "input variables are scaled so that the widths of the central ScaleFrac*100% of" << Endl
         << "events are equal across all input variables." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events according to their distance to the test" << Endl
         << "event." << Endl;
}
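
// A minimal sketch of the binning advice above (the histogram name is
// illustrative): with nkNN = 20, unit weights and no kernel, the response
// takes the 21 discrete values 0/20, 1/20, ..., 20/20, so k+1 = 21 bins
// between 0 and 1 match the discrete spacing of the classifier output.
//
//    TH1F hist("MVA_KNN", "kNN response", 21, 0.0, 1.0);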

//_______________________________________________________________________
Double_t TMVA::MethodKNN::PolnKernel(const Double_t value) const
{
   // polynomial kernel
   const Double_t avalue = TMath::Abs(value);

   if (!(avalue < 1.0)) {
      return 0.0;
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
}
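
// The kernel above is k(x) = (1 - |x|^3)^3 for |x| < 1 and 0 otherwise,
// where x is the neighbor distance divided by the kernel radius; e.g.
// k(0) = 1 and k(0.5) = (1 - 0.125)^3 ~ 0.67, falling smoothly to 0 at x = 1.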

//_______________________________________________________________________
Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event, const std::vector<Double_t> &svec) const
{
   // Gaussian kernel

   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
      return 0.0;
   }

   //
   // compute exponent
   //
   double sum_exp = 0.0;

   for(unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];
      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
         return 0.0;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }

   //
   // Return unnormalized(!) Gaussian function, because the normalization
   // cancels in the ratio of weights.
   //

   return std::exp(-sum_exp);
}
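
// The weight returned above is the unnormalized product of per-variable
// Gaussians,
//
//    w = exp( - sum_i (x_i - y_i)^2 / (2*sigma_i^2) ) ,
//
// with sigma_i taken from getRMS() below. The same sigma vector is used for
// every neighbor of a given query, so the dropped 1/(sigma*sqrt(2*pi))
// normalization is a common factor that cancels in the signal/total ratio.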

//_______________________________________________________________________
Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
{
   //
   // Get polynomial kernel radius
   //
   Double_t kradius = -1.0;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
      {
         if (!(lit->second > 0.0)) continue;

         if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

         ++kcount;
         if (kcount >= knn) break;
      }

   return kradius;
}

//_______________________________________________________________________
const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
   //
   // Compute the RMS of the variable differences between the k-nearest
   // neighbors and the query event; used as the Gaussian kernel sigma
   //
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
      {
         if (!(lit->second > 0.0)) continue;

         const kNN::Node<kNN::Event> *node_ = lit->first;
         const kNN::Event &event_ = node_->GetEvent();

         if (rvec.empty()) {
            rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
         }
         else if (rvec.size() != event_.GetNVar()) {
            Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
            rvec.clear();
            return rvec;
         }

         for(unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
            const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
            rvec[ivar] += diff_*diff_;
         }

         ++kcount;
         if (kcount >= knn) break;
      }

   if (kcount < 1) {
      Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
      rvec.clear();
      return rvec;
   }

   for(unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
         rvec.clear();
         return rvec;
      }

      rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
   }

   return rvec;
}
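
// Each sigma computed above is the scaled RMS of the per-variable difference
// over the k accepted neighbors,
//
//    sigma_i = |SigmaFact| * sqrt( sum_{n=1}^{k} (x_i^n - q_i)^2 / k ) ,
//
// where q is the query event, so SigmaFact directly widens or narrows the
// Gaussian kernel.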

//_______________________________________________________________________
Double_t TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
{
   LDAEvents sig_vec, bac_vec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);
}
